Оба текущие ответы вроде работают путем фильтрации имя переменной, используя строку «Momentum». Но это очень хрупкое с двух сторон:
- Он может молча (повторно) инициализировать некоторые другие переменные, которые вы на самом деле не хотите сбросить! Либо просто из-за конфликта имен, либо потому, что у вас есть более сложный граф и, например, оптимизируйте разные части отдельно.
- Он будет работать только для одного конкретного оптимизатора и как вы узнаете имена, которые нужно искать для других?
- Бонус: обновление до тензорного потока может бесшумно сломайте свой код.
К счастью, абстрактный Optimizer
класса tensorflow имеет механизм для этого, эти дополнительных переменных оптимизатора называются "slots", и вы можете получить все имена слотов оптимизатора, используя get_slot_names()
метод:
opt = tf.train.MomentumOptimizer(...)
print(opt.get_slot_names())
# prints ['momentum']
И вы может получить переменную, соответствующую прорезь для конкретного (обучаемый) переменной v
с использованием метода get_slot(var, slot_name)
:
opt.get_slot(some_var, 'momentum')
Собирает все это вместе, вы можете создать цит, который инициализирует состояние оптимизатора следующим образом:
var_list = # list of vars to optimize, e.g.
# tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
opt = tf.train.MomentumOptimizer(0.1, 0.95)
step_op = opt.minimize(loss, var_list=var_list)
reset_opt_op = tf.variables_initializer([opt.get_slot(var, name) for name in opt.get_slot_names() for var in var_list])
Это будет действительно только сбросить правильные переменное и быть устойчивыми по оптимизаторам.
За исключением одного unfortunate caveat: AdamOptimizer
. Это также поддерживает счетчик того, как часто он называется. Это означает, что вы должны действительно серьезно думать о том, что вы здесь делаете, но для полноты вы можете получить дополнительные состояния как opt._get_beta_accumulators()
. Возвращенный список должен быть добавлен в список в приведенной выше строке reset_opt_op
.
Он должен быть: var_list = [вар для вара в tf.global_variables(), если 'Momentum' в var.name] –
@ MichaelPresečan фиксированный, спасибо! –