2016-10-29 3 views
7

Я следую за этим Manipulating matrix elements in tensorflow. используя tf.scatter_update. Но моя проблема: Что произойдет, если мой tf.Variable является 2D? Предположим, что:Используйте tf.scatter_update в двумерном tf.Variable

a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]]) 

Как я могу обновить, например, первый элемент каждой строки и присвоить этому значение 1?

Я пытался что-то вроде

for line in range(2): 
    sess.run(tf.scatter_update(a[line],[0],[1])) 

но неудачно (я ожидал, что), и дает мне ошибку:

TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

Как я могу исправить такого рода проблемы?

`

ответ

5

В tensorflow вы не можете обновить тензорной, но вы можете обновить переменную.

Оператор scatter_update может обновить только первое измерение переменной. Вы должны всегда передавать опорный тензор на обновление рассеяния (a вместо a[line]).

Это, как вы можете обновить первый элемент переменной:

import tensorflow as tf 

g = tf.Graph() 
with g.as_default(): 
    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]]) 
    b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]]) 

with tf.Session(graph=g) as sess: 
    sess.run(tf.initialize_all_variables()) 
    print sess.run(a) 
    print sess.run(b) 

Выход:

[[0 0 0 0] 
[0 0 0 0]] 
[[1 0 0 0] 
[1 0 0 0]] 

Но того, чтобы снова изменить весь тензор может быть быстрее просто задаём полностью новый.

Смежные вопросы