2016-11-24 6 views
4

Я пробовал Tensorflow's rnn example. С некоторыми проблемами в начале я мог бы запустить пример, чтобы обучить ptb, и теперь у меня есть обученная модель.Как использовать пример модели PTB от Tensorflow?

Как я могу точно использовать модель для создания предложений без необходимости тренировки каждый раз?

Я бегу это с помощью команды python ptb_word_lm.py --data_path=/home/data/ --model medium --save_path=/home/medium

Есть ли пример где-то о том, как использовать обученную модель, чтобы сделать предложения?

+0

Можете ли вы рассказать/показать мне, что результат по команде «--save_path/дом/средний»? – LKM

ответ

5

1.Add следующий код на последней строке PTBModel:__init__() функции:

self._output_probs = tf.nn.softmax(logits) 

2.Add следующая функция в PTBModel:

@property 
def output_probs(self): 
    return self._output_probs 

3.Try, чтобы запустить следующий код:

raw_data = reader.ptb_raw_data(FLAGS.data_path) 
train_data, valid_data, test_data, vocabulary, word_to_id, id_to_word = raw_data 

eval_config = get_config() 
eval_config.batch_size = 1 
eval_config.num_steps = 1 

sess = tf.Session() 

initializer = tf.random_uniform_initializer(-eval_config.init_scale, 
              eval_config.init_scale) 
with tf.variable_scope("model", reuse=None, initializer=initializer): 
    mtest = PTBModel(is_training=False, config=eval_config) 

sess.run(tf.initialize_all_variables()) 

saver = tf.train.Saver() 

ckpt = tf.train.get_checkpoint_state('/home/medium') # __YOUR__MODEL__SAVE__PATH__ 
if ckpt and gfile.Exists(ckpt.model_checkpoint_path): 
    msg = 'Reading model parameters from %s' % ckpt.model_checkpoint_path 
    print(msg) 
    saver.restore(sess, ckpt.model_checkpoint_path) 

def pick_from_weight(weight, pows=1.0): 
    weight = weight**pows 
    t = np.cumsum(weight) 
    s = np.sum(weight) 
    return int(np.searchsorted(t, np.random.rand(1) * s)) 

while True: 
    number_of_sentences = 10 # generate 10 sentences one time 
    sentence_cnt = 0 
    text = '\n' 
    end_of_sentence_char = word_to_id['<eos>'] 
    input_char = np.array([[end_of_sentence_char]]) 
    state = sess.run(mtest.initial_state) 
    while sentence_cnt < number_of_sentences: 
     feed_dict = {mtest.input_data: input_char, 
        mtest.initial_state: state} 
     probs, state = sess.run([mtest.output_probs, mtest.final_state], 
             feed_dict=feed_dict) 
     sampled_char = pick_from_weight(probs[0]) 
     if sampled_char == end_of_sentence_char: 
      text += '.\n' 
      sentence_cnt += 1 
     else: 
      text += ' ' + id_to_word[sampled_char] 
     input_char = np.array([[sampled_char]]) 
    print(text) 
    raw_input('press any key to continue ...') 
+0

Я получаю сообщение об ошибке: при запуске этого кода объект 'PTBModel' не имеет атрибута '_output_probs''. – smith

Смежные вопросы