2017-01-17 4 views
0

Ниже приведен простой mnist учебник (т.е. один слой SoftMax) с сайта Tensorflow, который я пытался расширить с многопоточной шаг обучения:Tensorflow и нарезание резьб

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
import threading 

# Training loop executed in each thread 
def training_func(): 
    while True: 
     batch = mnist.train.next_batch(100) 
     global_step_val,_ = sess.run([global_step, train_step], feed_dict={x: batch[0], y_: batch[1]}) 
     print("global step: %d" % global_step_val) 
     if global_step_val >= 4000: 
     break 

# create session and graph 
sess = tf.Session() 

x = tf.placeholder(tf.float32, shape=[None, 784]) 
y_ = tf.placeholder(tf.float32, shape=[None, 10]) 

W = tf.Variable(tf.zeros([784,10])) 
b = tf.Variable(tf.zeros([10])) 
global_step = tf.Variable(0, name="global_step") 
y = tf.matmul(x,W) + b 

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_)) 

inc = global_step.assign_add(1) 
with tf.control_dependencies([inc]): 
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

# initialize graph and create mnist loader 
sess.run(tf.global_variables_initializer()) 
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 

# create workers and execute threads 
workers = [] 
for _ in range(8): 
    t = threading.Thread(target=training_func) 
    t.start() 
    workers.append(t) 

for t in workers: 
    t.join() 

# evaluate accuracy of the model 
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}, 
    session=sess)) 

Я должен что-то отсутствует, поскольку 8, как ниже, дают непоследовательные результаты (точность около 0,1), когда с 1 нитью получается только ожидаемая точность (около 0,92). Кто-нибудь знает о моей ошибке? Благодаря!

+2

Вы понимаете, что графики TF скомпилированы и выполнены сильным параллельным движком. Если вы посмотрите на использование ЦП при однопоточном тренинге, вы увидите, что все ядра получают нагрузку, а не только одну. Что вы хотите сделать, нанизав обучение? Я ожидаю, что проблемы, которые вы видите, происходят из нескольких потоков, обновляющих веса без какого-либо контроля и переписывания изменений друг друга. –

+0

Моей целью было бы ускорить дорогостоящее обучение. Я понимаю, что TF действительно параллельна, но также и то, что ускорение можно получить при многопоточности - например, в приведенном выше примере, диапазон (1) дает 15-20% использования для всех ядер, тогда как диапазон (16) приводит к использованию 60-80%. –

+0

Я подозревал, что моя проблема исходит из неконтролируемых одновременных обновлений веса. Однако [этот код учебника TF] (https://github.com/tensorflow/tensorflow/blob/r0.12/tensorflow/models/embedding/word2vec_optimized.py) делает что-то аналогичное моему примеру (l.319 - l. 340), но я не понимаю, почему это работает в их случае. Может быть, их обучение op (word2vec.neg_train) управляет этими параллельными обновлениями внутри? –

ответ

1

Обратите внимание: к сожалению, threading с python не создает реальный параллелизм из-за GIL. Итак, что происходит здесь, так это то, что у вас будет несколько потоков, которые все работают на одном процессоре, где на самом деле они работают последовательно. Поэтому я бы предложил использовать координатор в Tensorflow. Более подробная информации о координаторе можно найти здесь:

https://www.tensorflow.org/programmers_guide/threading_and_queues
https://www.tensorflow.org/programmers_guide/reading_data

Наконец, я хотел бы предложить вам сказать:

with tf.device('/cpu:0'): 
    your code should go here... 'for the first thread' 

Затем использовать другой процессор для других потоков и так далее ... Надеюсь, что этот ответ найдет вас хорошо!

+0

Только что увидел ответ и еще не протестировал его, но информация выглядит актуальной, спасибо за ответ. Я сейчас упрекаю, и я выберу ответ, если он решает проблему (когда у меня есть время проверить :)) –

Смежные вопросы