У меня разные области применения, и они имеют переменные с одинаковыми именами, но с разными значениями. Я хочу поменять значения этих переменных между областями. Пример:Tensorflow: как обменивать переменные между областями и задавать переменные в области от другого
with tf.variable_scope('sc1'):
a1 = tf.Variable(0, name='test_var1')
b1 = tf.Variable(1, name='test_var2')
with tf.variable_scope('sc2'):
a2 = tf.Variable(2, name='test_var1')
b2 = tf.Variable(3, name='test_var2')
Я хочу установить a2
к 0, b2
1, a1
до 2 и b1
до 3.
Я думал о получении необходимых переменных с tf.get_collection_ref
, но я не могу видеть, как Я могу изменить область действия переменной, поэтому, вероятно, мне нужно изменить значения переменных. В этом случае мне нужно сохранить одно значение во временной переменной, а затем удалить эту временную переменную. Я не уверен, что это сработает, и это кажется слишком сложным. Есть ли простой способ сделать это?
UPD1: Также мне нужно установить все переменные в одну коллекцию из другой коллекции. Я думаю, что это аналогичная проблема. Например, в приведенном выше коде установите a2
равным 0 и b2
1.
UPD2: Этот код не работает:
with tf.variable_scope('sc1'):
a1 = tf.get_variable(name='test_var1', initializer=0.)
b1 = tf.Variable(0, name='test_var2')
with tf.variable_scope('sc2'):
a2 = tf.get_variable(name='test_var1', initializer=1.)
b2 = tf.Variable(1, name='test_var2')
def swap_tf_scopes(col1, col2):
col1_dict = {}
col2_dict = {}
for curr_var in col1:
curr_var_name = curr_var.name.split('/')[-1]
col1_dict[curr_var_name] = curr_var
for curr_var in col2:
curr_var_name = curr_var.name.split('/')[-1]
curr_col1_var = col1_dict[curr_var_name]
tmp_t = tf.identity(curr_col1_var)
assign1 = curr_col1_var.assign(curr_var)
assign2 = curr_var.assign(tmp_t)
return [assign1, assign2]
col1 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='sc1')
col2 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='sc2')
tf_ops_t = swap_tf_collections(col1, col2)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(tf_ops_t)
print sess.run(col1) #prints [0.0, 1] but I expect [1.0, 1]
print sess.run(col2) #prints [1.0, 1] but I expect [0.0, 0]
Предполагается, что на основе быстрого считывания: запись в одну переменную перезаписывает ввод другой записи. 'tf.identity' недостаточно для принудительной копии данных Тензора. Попробуйте что-то вроде 'tmp_t = curr_col1_var + 0.0'. Надеюсь, это поможет! –