Я реализовал двунаправленный RNN в TensorFlow, используя BasicLSTMCell
и . Я вычисляю потери с использованием seq2seq.sequence_loss_by_example
после конкатенации выходов, которые я получаю. Мое приложение - следующий предиктор символов.Получение чрезвычайно низких потерь в двунаправленном RNN?
Я получаю чрезвычайно низкийcost
, (~ 50 раз меньше, чем однонаправленный RNN). Я подозреваю, что совершил ошибку на этапе seq2seq.sequence_loss_by_example
.
Вот моя модель -
# Model begins
cell_fn = rnn_cell.BasicLSTMCell
cell = fw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
cell2 = bw_cell = cell_fn(args.rnn_size, state_is_tuple=True)
input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
initial_state = fw_cell.zero_state(args.batch_size, tf.float32)
initial_state2 = bw_cell.zero_state(args.batch_size, tf.float32)
with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w", [2*args.rnn_size, args.vocab_size])
softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
input_embeddings = tf.nn.embedding_lookup(embedding, input_data)
inputs = tf.unpack(input_embeddings, axis=1)
outputs, last_state, last_state2 = rnn.bidirectional_rnn(fw_cell,
bw_cell,
inputs,
initial_state_fw=initial_state,
initial_state_bw=initial_state2,
dtype=tf.float32)
output = tf.reshape(tf.concat(1, outputs), [-1, 2*args.rnn_size])
logits = tf.matmul(output, softmax_w) + softmax_b
probs = tf.nn.softmax(logits)
loss = seq2seq.sequence_loss_by_example([logits],
[tf.reshape(targets, [-1])],
[tf.ones([args.batch_size * args.seq_length])],
args.vocab_size)
cost = tf.reduce_sum(loss)/args.batch_size/args.seq_length
lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
args.grad_clip)
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.apply_gradients(zip(grads, tvars))
Буду рад предоставить дополнительную информацию, если необходимо – martianwars