2016-08-03 2 views
2

У меня есть придуманную версию сложной сети:Почему эта петля тензорного потока требует столько памяти?

import tensorflow as tf 

a = tf.ones([1000]) 
b = tf.ones([1000]) 

for i in range(int(1e6)): 
    a = a * b 

Моя интуиция, что это должно требовать очень мало памяти. Просто пространство для начального распределения массива и строка команд, которая использует узлы и перезаписывает память, хранящуюся в тензоре «a» на каждом шаге. Но использование памяти растет довольно быстро.

Что происходит здесь и как уменьшить использование памяти при вычислении тензора и перезаписать его несколько раз?

Edit:

Благодаря предложениям Ярослава решение оказалось использованием while_loop, чтобы минимизировать число узлов на графе. Это отлично работает и намного быстрее, требует гораздо меньше памяти, и все это содержится в графике.

import tensorflow as tf 

a = tf.ones([1000]) 
b = tf.ones([1000]) 

cond = lambda _i, _1, _2: tf.less(_i, int(1e6)) 
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b] 

i = tf.constant(0) 
output = tf.while_loop(cond, body, [i, a, b]) 

with tf.Session() as sess: 
    result = sess.run(output) 
    print(result) 

ответ

6

Ваша команда a*b переводится tf.mul(a, b), что эквивалентно tf.mul(a, b, g=tf.get_default_graph()). Эта команда добавляет узел Mul к текущему объекту Graph, поэтому вы пытаетесь добавить 1 миллион Mul узлов к текущему графику. Это также проблематично, поскольку вы не можете сериализовать объект Graph размером более 2 ГБ, есть некоторые проверки, которые могут потерпеть неудачу, если вы имеете дело с таким большим графиком.

Я рекомендую читать Programming Models for Deep Learning от MXNet. TensorFlow - это «символическое» программирование в их терминологии, и вы относитесь к нему как к императивному.

Чтобы получить то, что вы хотите с помощью цикла Python можно построить умножение цит один раз, и запустить его повторно, используя feed_dict кормить обновления

mul_op = a*b 
result = sess.run(a) 
for i in range(int(1e6)): 
    result = sess.run(mul_op, feed_dict={a: result}) 

Для большей эффективности можно использовать tf.Variable объекты и var.assign, чтобы избежать Python < -> Передача данных TensorFlow

+0

Это имеет смысл. Я хотел бы только назвать sess.run один раз, и обрабатывать все вычисления из ввода в вывод за один шаг, так как я возвращаюсь к большому графику. Возможно ли сделать этот цикл в графике без добавления дополнительных узлов? – jstaker7

+0

Если вы сохраняете входы в * b в объектах 'tf.Variable', вы должны изолировать его зависимости, поэтому вы можете сделать' sess.run' на этом узле миллион раз, не вычисляя ничего на графике –

+0

Не уверен Я понимаю. Скажем, 'b' - это обучаемая весовая матрица, которая обновляется во время back-prop. Если я вызову sess.run более одного раза, я не только возьму дополнительные накладные расходы за несколько вызовов sess.run, я отключу это вычисление от вычисления градиента, и вам нужно будет сделать какую-то грязную работу, чтобы убедиться, что она правильно обновляется. Правильны ли эти предположения? Наверное, я бы хотел, чтобы был лучший способ справиться с этим графиком. – jstaker7

Смежные вопросы