2016-06-02 2 views
2

Я пытаюсь реализовать рекуррентный тензор состояния, используя tf.scan. Код у меня на данный момент это:Попытка реализовать повторяющуюся сеть с tf.scan()

import tensorflow as tf 
import math 
import numpy as np 

INPUTS = 10 
HIDDEN_1 = 20 
BATCH_SIZE = 3 


def iterate_state(prev_state_tuple, input): 
    with tf.name_scope('h1'): 
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     matmuladd = tf.matmul(inputs, weights) + biases 
     unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple) 
     prev_state = unpacked_state 
     state = 0.9* prev_state + 0.1*matmuladd 
     output = tf.nn.relu(state) 
     return tf.concat(0,[state, output]) 

def data_iter(): 
    while True: 
     idxs = np.random.rand(BATCH_SIZE, INPUTS) 
     yield idxs 

with tf.Graph().as_default(): 
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS)) 
    with tf.variable_scope('states'): 
     initial_state = tf.zeros([HIDDEN_1], 
           name='initial_state') 
     initial_out = tf.zeros([HIDDEN_1], 
           name='initial_out') 
     concat_tensor = tf.concat(0,[initial_state, initial_out]) 
     states, output = tf.scan(iterate_state, inputs, 
            initializer=concat_tensor, name='states') 

    sess = tf.Session() 
    # Run the Op to initialize the variables. 
    sess.run(tf.initialize_all_variables()) 
    iter_ = data_iter() 
    for i in xrange(0, 2): 
     print ("iteration: ",i) 
     input_data = iter_.next() 
     out,st = sess.run([output,states], feed_dict={ inputs: input_data}) 

Однако, я получаю эту ошибку при выполнении этого:

Traceback (most recent call last): 
    File "cycles_in_graphs_with_scan.py", line 37, in <module> 
    initializer=concat_tensor, name='states') 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__ 
    raise TypeError("'Tensor' object is not iterable.") 
TypeError: 'Tensor' object is not iterable. 
(tensorflow)[email protected] ~/projects/stuff $ python cycles_in_graphs_with_scan.py 
Traceback (most recent call last): 
    File "cycles_in_graphs_with_scan.py", line 37, in <module> 
    initializer=concat_tensor, name='states') 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__ 
    raise TypeError("'Tensor' object is not iterable.") 
TypeError: 'Tensor' object is not iterable. 

Я уже пытался с pack/unpack и concat/split, но я получаю эту же ошибку.

Любые идеи, как решить эту проблему?

ответ

3

Вы получаете сообщение об ошибке, потому что tf.scan() возвращает одногоtf.Tensor, поэтому линия:

states, output = tf.scan(...) 

... не destructure (распаковка) тензор вернулся из tf.scan() в два значения (states и outputs). Фактически, код пытается обработать результат tf.scan() в виде списка длины 2, и присвоить первый элемент для states и второй элемент в output, но — в отличие от списка Python или кортеж — tf.Tensor не поддерживает это.

Вместо этого вам необходимо извлечь значения из результата tf.scan() вручную. Например, с помощью tf.split():

scan_result = tf.scan(...) 
# Assumes values are packed together along `split_dim`. 
states, output = tf.split(split_dim, 2, scan_result) 

В качестве альтернативы, вы можете использовать tf.slice() или tf.unpack() извлечь соответствующие states и output значения.

Смежные вопросы