2017-01-15 3 views
0

В документации онлайн говорится, что move_average и moving_variance являются как model_variables, так и tf.model_variables() возвращает тензоры типа local_variables. Означает ли это, что переменные model_variables не сохраняются, когда я сохраняю свое состояние?переменные модели в Batch_norm Tensorflow

Я пытаюсь применить нормализацию партии к нескольким трехмерным сверткам и полностью связанным слоям. Я тренировал свою сеть с помощью batch_norm и сохранял файл контрольной точки, но когда я пошел на восстановление сохраненного состояния, он сказал, что move_mean не может быть найден. Точная ошибка заключалась в том, что когда TF отправлял восстановленное значение в move_mean, форма тензора lhs [] не могла быть согласована с формой rhs [20].

График восстанавливается отлично, когда я не добавляю batch_norm вокруг своих слоев. Я планирую добавить глобальную переменную в конце обучения, которая сохранит мои значения moving_mean и moving_variance. Это способ, которым TF предназначался для меня использовать batch_norm?

Спасибо!

ответ

1

Переменные moving_mean и moving_variance не были сохранены в моем сохраненном состоянии, потому что я установил update_collections в значение по умолчанию. Поскольку я никогда не включал контрольную зависимость, когда я запускал слои, эти переменные никогда не обновлялись.

Код для включения является:

from tensorflow.python import control_flow_ops 

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 
if update_ops: 
    updates = tf.tuple(update_ops) 
    total_loss = control_flow_ops.with_dependencies(updates, total_loss) 

Или установите

updates_collection=None 

для обновления в месте.

См. the API description и current github discussion для получения дополнительной информации.

Смежные вопросы