2016-08-17 4 views
2

Я использую метод тензор, чтобы сделать градиентную приличную классификацию.Поезд только некоторых переменных в тензорном потоке

train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 

здесь cost функциязатраты, которые я использовал в оптимизации. После запуска Graph в сессии, график может быть подан как:

sess.run(train_op, feed_dict) 

И с этим, все переменные в функции затрат будет обновляться для того, чтобы свести к минимуму расходы.

Вот мой вопрос. Как я могу обновить только некоторые переменные в функции стоимости при обучении ..? Есть ли способ конвертировать созданные переменные в константы или что-то в этом роде?

+0

Если вы определили свою собственную функцию стоимости, вы можете трудно написать переменные, которые вы хотите постоянно, а не обновлять их. Я не знаю, понимаете ли вы, что я имею в виду. – CoMartel

+3

Вы можете указать список переменных в 'GradientDescentOptimizer.minimize()' как 'var_list' (также см. Https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html#usage и https : //www.tensorflow.org/versions/r0.10/api_docs/python/train.html#Optimizer.minimize), делает ли это то, что вы хотите? – fwalch

+2

См. Http://stackoverflow.com/questions/35298326/freeze-some-variables-scopes-in-tensorflow-stop-gradient-vs-passing-variables?rq=1 – jean

ответ

0

Есть несколько хороших ответов, этот вопрос уже должен быть закрыт: stackoverflow Quora

Просто, чтобы избежать еще щелчка для людей получать здесь:

Минимизировать функцию tensorflow оптимизатора принимает var_list аргумента для этой цели:

first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 
            "scope/prefix/for/first/vars") 
first_train_op = optimizer.minimize(cost, var_list=first_train_vars) 

second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 
             "scope/prefix/for/second/vars")      
second_train_op = optimizer.minimize(cost, var_list=second_train_vars) 

Я принял это как от mrry

Чтобы получить список имен вы должны использовать вместо "scope/prefix/for/second/vars" вы можете использовать:

tf.get_default_graph().get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 
Смежные вопросы