2016-06-07 2 views
2

У меня возникают проблемы с эффективным использованием переменных областей. Я хочу определить некоторые переменные для весов, смещений и внутреннего состояния простой повторяющейся сети. Я вызываю get_saver() один раз после определения графа по умолчанию. Затем я перебираю партию выборок, используя tf.scan.Области переменных в Tensorflow

import tensorflow as tf 
import math 
import numpy as np 

INPUTS = 10 
HIDDEN_1 = 2 
BATCH_SIZE = 3 

def batch_vm2(m, x): 
    [input_size, output_size] = m.get_shape().as_list() 

    input_shape = tf.shape(x) 
    batch_rank = input_shape.get_shape()[0].value - 1 
    batch_shape = input_shape[:batch_rank] 
    output_shape = tf.concat(0, [batch_shape, [output_size]]) 

    x = tf.reshape(x, [-1, input_size]) 
    y = tf.matmul(x, m) 

    y = tf.reshape(y, output_shape) 

    return y 

def get_saver(): 
    with tf.variable_scope('h1') as scope: 
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False) 
     saver = tf.train.Saver([weights, biases, state]) 
    return saver 


def load(sess, saver, checkpoint_dir = None): 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

def iterate_state(prev_state_tuple, input): 
    with tf.variable_scope('h1') as scope: 
     scope.reuse_variables() 
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False) 
     print("input: ",input.get_shape()) 
     matmuladd = batch_vm2(weights, input) + biases 
     matmulpri = tf.Print(matmuladd,[matmuladd], message=" malmul -> ") 
     #matmulvec = tf.reshape(matmuladd, [HIDDEN_1]) 
     #state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     print("prev state: ",prev_state_tuple.get_shape()) 
     unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple) 
     prev_state = unpacked_state 
     state = state.assign(4.2*(0.9* prev_state + 0.1*matmuladd)) 
     #output = tf.nn.relu(state) 
     output = tf.nn.tanh(state) 
     state = tf.Print(state, [state], message=" state -> ") 
     output = tf.Print(output, [output], message=" output -> ") 
     #output = matmulpri 
     print(" state: ", state.get_shape()) 
     print(" output: ", output.get_shape()) 
     concat_result = tf.concat(0,[state, output]) 
     print (" concat return: ", concat_result.get_shape()) 
     return concat_result 

def data_iter(): 
    while True: 
     idxs = np.random.rand(BATCH_SIZE, INPUTS) 
     yield idxs 

with tf.Graph().as_default(): 
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS)) 

    saver = get_saver() 
    initial_state = tf.zeros([HIDDEN_1], 
          name='initial_state') 
    initial_out = tf.zeros([HIDDEN_1], 
          name='initial_out') 
    #concat_tensor = tf.concat(0,[initial_state, initial_out]) 
    concat_tensor = tf.concat(0,[initial_state, initial_out]) 
    print(" init state: ",initial_state.get_shape()) 
    print(" init out: ",initial_out.get_shape()) 
    print(" concat: ",concat_tensor.get_shape()) 
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan') 
    print ("scanout shape: ", scanout.get_shape()) 
    state, output = tf.split(1,2,scanout, name='split_scan_output') 
    print(" end state: ",state.get_shape()) 
    print(" end out: ",output.get_shape()) 

    #output,state,diagnostic = create_graph(inputs, state, prev_state) 

    sess = tf.Session() 
    # Run the Op to initialize the variables. 
    sess.run(tf.initialize_all_variables()) 
    if False: 
     load(sess, saver) 
    iter_ = data_iter() 
    for i in xrange(0, 5): 
     print ("iteration: ",i) 
     input_data = iter_.next() 
     out,st,so = sess.run([output,state,scanout], feed_dict={ inputs: input_data}) 
     saver.save(sess, 'my-model', global_step=1+i) 
     print("input vec: ", input_data) 
     print("state vec: ", st) 
     print("output vec: ", out) 
     print(" end state (runtime): ",st.shape) 
     print(" end out (runtime): ",out.shape) 
     print(" end scanout (runtime): ",so.shape) 

Моя надежда будет иметь переменные, извлеченные из get_variable внутри scan цит быть таким же, как определено внутри get_saver вызова. Однако, если я запускаю этот пример кода, я получаю следующий результат с ошибками:

(' init state: ', TensorShape([Dimension(2)])) 
(' init out: ', TensorShape([Dimension(2)])) 
(' concat: ', TensorShape([Dimension(4)])) 
Traceback (most recent call last): 
    File "cycles_in_graphs_with_scan.py", line 88, in <module> 
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan') 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 345, in scan 
    back_prop=back_prop, swap_memory=swap_memory) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1873, in while_loop 
    result = context.BuildLoop(cond, body, loop_vars) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1749, in BuildLoop 
    body_result = body(*vars_for_body_with_tensor_arrays) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 339, in compute 
    a = fn(a, elems_ta.read(i)) 
    File "cycles_in_graphs_with_scan.py", line 47, in iterate_state 
    weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 732, in get_variable 
    partitioner=partitioner, validate_shape=validate_shape) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 596, in get_variable 
    partitioner=partitioner, validate_shape=validate_shape) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 161, in get_variable 
    caching_device=caching_device, validate_shape=validate_shape) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 454, in _get_single_variable 
    " Did you mean to set reuse=None in VarScope?" % name) 
ValueError: Variable state_scan/h1/W does not exist, disallowed. Did you mean to set reuse=None in VarScope? 

любая идея, что я делаю неправильно в этом примере?

ответ

0
if False: 
    load(sess, saver) 

Эти две линии приводят к неинициализированным переменным.

Смежные вопросы