2016-11-22 3 views
0

Я следую за блоком wildml по классификации текста, используя тензор. Я изменил код, чтобы сохранить граф Защиты следующим образом:Ошибка Tensorflow при восстановлении графа def из файла .pb

tf.train.write_graph(sess.graph_def,'./DeepLearn/model/','train.pb', as_text=False) 

Позже в отдельном файл я восстанавливающий граф следующим образом:

with tf.gfile.FastGFile(os.path.join('./DeepLearn/model/','train.pb'), 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    _ = tf.import_graph_def(graph_def, name='') 
with tf.Session() as sess: 
    t = sess.graph.get_tensor_by_name('embedding/W:0') 
    sess.run(t) 

Когда я пытаюсь запустить тензор и получить его значение , я получаю следующую ошибку:

tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value embedding/W 

Какова возможная причина этой ошибки. Тензор должен быть инициализирован, так как я восстанавливаю его из сохраненного графика.

+0

'sess.run (tf.initialize_all_variables())'? – sygi

+0

Но я загружаю тензор из ранее сохраненного графика, поэтому я не думаю, что мне нужно его инициализировать с помощью этого утверждения. – Nitin

+1

Вам все равно нужно инициализировать переменные, так как чтение graphdef восстанавливает только сам граф, а не значения переменных. Если вы хотите восстановить значения переменных, которые нужно загрузить с контрольной точки. –

ответ

0

Спасибо Alexandre! Да, мне нужно загрузить как график (из .pb-файла), так и веса (из файла контрольных точек.). Использовал следующий образец кода (взятый из блога), и это сработало для меня.

with tf.Session() as persisted_sess: 
    print("load graph") 
    with gfile.FastGFile("/tmp/load/test.pb",'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     persisted_sess.graph.as_default() 
     tf.import_graph_def(graph_def, name='') 
    persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0") 
    tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result) 
    try: 
     saver = tf.train.Saver(tf.all_variables()) 
    except:pass 
     print("load data") 
    saver.restore(persisted_sess, "checkpoint.data") # now OK 
    print(persisted_result.eval()) 
    print("DONE")