Короткий ответ: Возможно, вы хотите checkpoint files (permalink).
Длинный ответ:
Давайте проясним об установке здесь. Я предполагаю, что у вас есть два устройства: A и B, и вы тренируетесь по A и выполняете вывод на B. Периодически вы хотите обновить параметры на устройстве, выполняющем вывод, с новыми параметрами, найденными во время обучения на Другие. Урок, упомянутый выше, является хорошим местом для начала. Он показывает вам, как работают объекты tf.train.Saver
, и здесь вам не нужно ничего сложнее.
Вот пример:
import tensorflow as tf
def build_net(graph, device):
with graph.as_default():
with graph.device(device):
# Input placeholders
inputs = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float32, [None, 10])
# Initialization
w0 = tf.get_variable('w0', shape=[784,256], initializer=tf.contrib.layers.xavier_initializer())
w1 = tf.get_variable('w1', shape=[256,256], initializer=tf.contrib.layers.xavier_initializer())
w2 = tf.get_variable('w2', shape=[256,10], initializer=tf.contrib.layers.xavier_initializer())
b0 = tf.Variable(tf.zeros([256]))
b1 = tf.Variable(tf.zeros([256]))
b2 = tf.Variable(tf.zeros([10]))
# Inference network
h1 = tf.nn.relu(tf.matmul(inputs, w0)+b0)
h2 = tf.nn.relu(tf.matmul(h1,w1)+b1)
output = tf.nn.softmax(tf.matmul(h2,w2)+b2)
# Training network
cross_entropy = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(output), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Your checkpoint function
saver = tf.train.Saver()
return tf.initialize_all_variables(), inputs, labels, output, optimizer, saver
Код программы обучения:
def programA_main():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build training network on device A
graphA = tf.Graph()
init, inputs, labels, _, training_net, saver = build_net(graphA, '/cpu:0')
with tf.Session(graph=graphA) as sess:
sess.run(init)
for step in xrange(1,10000):
batch = mnist.train.next_batch(50)
sess.run(training_net, feed_dict={inputs: batch[0], labels: batch[1]})
if step%100==0:
saver.save(sess, '/tmp/graph.checkpoint')
print 'saved checkpoint'
... и код программы вывода:
def programB_main():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build inference network on device B
graphB = tf.Graph()
init, inputs, _, inference_net, _, saver = build_net(graphB, '/cpu:0')
with tf.Session(graph=graphB) as sess:
batch = mnist.test.next_batch(50)
saver.restore(sess, '/tmp/graph.checkpoint')
print 'loaded checkpoint'
out = sess.run(inference_net, feed_dict={inputs: batch[0]})
print out[0]
import time; time.sleep(2)
saver.restore(sess, '/tmp/graph.checkpoint')
print 'loaded checkpoint'
out = sess.run(inference_net, feed_dict={inputs: batch[0]})
print out[1]
Если огонь программа обучения, а затем программа вывода, вы увидите, что программа вывода выводит два разных выхода (из той же входной партии). Это результат того, что он подбирает параметры, которые проверила контрольная программа.
Теперь эта программа, очевидно, не является вашей конечной точкой. Мы не выполняем никакой реальной синхронизации, и вам нужно будет решить, что означает «периодический» в отношении контрольной точки. Но это должно дать вам представление о том, как синхронизировать параметры из одной сети в другую.
Последнее предупреждение: это не означает, что две сети обязательно детерминированы. Существуют известные недетерминированные элементы в TensorFlow (например, this), поэтому будьте осторожны, если вам нужно точно тот же ответ.Но это трудная правда о запуске на нескольких устройствах.
Удачи вам!
Почему бы не строить несколько графиков параллельно, а не копировать существующие? –
Этот вопрос довольно неоднозначный. Вы спрашиваете об обновлении структуры данных TensorFlow 'Graph' in situ [(жесткий)] (http://stackoverflow.com/questions/37610757/how-to-remove-nodes-from-tensorflow-graph/37620231#37620231) ? Или вы спрашиваете, как обновлять параметры в одном графике из другого [(не так уж плохо)] (https://www.tensorflow.org/versions/master/how_tos/variables/index.html#saving-and-restoring) без изменения структуры? Или это связано с контролем версий в нейронных сетях (что является проблемой разработки программного обеспечения в более широком смысле)? – rdadolf
@rdadolf второй. Мне просто нужно сохранить копию одних и тех же моделей на разных машинах и время от времени синхронизировать параметры. – MBZ