2016-10-18 4 views
1

Предположим, у нас есть два графика вычисления TensorFlow, G1 и G2, с сохраненными весами W1 и W2. Предположим, мы построили новый график G, просто построив G1 и G2. Как мы можем восстановить как W1, так и W2 для этого нового графика G?TensorFlow: восстановление нескольких графиков

В качестве простого примера:

import tensorflow as tf 

V1 = tf.Variable(tf.zeros([1])) 
saver_1 = tf.train.Saver() 
V2 = tf.Variable(tf.zeros([1])) 
saver_2 = tf.train.Saver() 

sess = tf.Session() 
saver_1.restore(sess, 'W1') 
saver_2.restore(sess, 'W2') 

В этом примере, saver_1 восстанавливает успешно соответствующее V1, но saver_2 терпит неудачу с NotFoundError.

ответ

2

Возможно, вы можете использовать две вкладки, где каждый хранитель ищет только одну из переменных. Если вы просто используете tf.train.Saver(), я думаю, что он будет искать все переменные, которые вы определили. Вы можете дать ему список переменных, которые нужно искать, используя tf.train.Saver([v1, ...]). Для получения дополнительной информации вы можете прочитать о конструкторе tf.train.Saver здесь: https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver

Вот простой рабочий пример. Предположим, что вы делаете ваши вычисления в файле «save_vars.py» и имеет следующий код:

import tensorflow as tf 

# Graph 1 - set v1 to have value [1.0] 
g1 = tf.Graph() 
with g1.as_default(): 
    v1 = tf.Variable(tf.zeros([1]), name="v1") 
    assign1 = v1.assign(tf.constant([1.0])) 
    init1 = tf.initialize_all_variables() 
    save1 = tf.train.Saver() 

# Graph 2 - set v2 to have value [2.0] 
g2 = tf.Graph() 
with g2.as_default(): 
    v2 = tf.Variable(tf.zeros([1]), name="v2") 
    assign2 = v2.assign(tf.constant([2.0])) 
    init2 = tf.initialize_all_variables() 
    save2 = tf.train.Saver() 

# Do the computation for graph 1 and save 
sess1 = tf.Session(graph=g1) 
sess1.run(init1) 
print sess1.run(assign1) 
save1.save(sess1, "tmp/v1.ckpt") 

# Do the computation for graph 2 and save 
sess2 = tf.Session(graph=g2) 
sess2.run(init2) 
print sess2.run(assign2) 
save2.save(sess2, "tmp/v2.ckpt") 

Если вы убедитесь, что у вас есть каталог tmp и запустить python save_vars.py, вы получите сохраненные файлы контрольных точек.

Теперь вы можете восстановить с помощью файла с именем "restore_vars.py" со следующим кодом:

import tensorflow as tf 

# The variables v1 and v2 that we want to restore 
v1 = tf.Variable(tf.zeros([1]), name="v1") 
v2 = tf.Variable(tf.zeros([1]), name="v2") 

# saver1 will only look for v1 
saver1 = tf.train.Saver([v1]) 
# saver2 will only look for v2 
saver2 = tf.train.Saver([v2]) 
with tf.Session() as sess: 
    saver1.restore(sess, "tmp/v1.ckpt") 
    saver2.restore(sess, "tmp/v2.ckpt") 
    print sess.run(v1) 
    print sess.run(v2) 

и при запуске python restore_vars.py, вывод должен быть

[1.] 
[2.] 

(по крайней мере, на моем компьютере это выход). Не стесняйтесь оставлять комментарий, если что-то неясно.

Смежные вопросы