2016-07-19 4 views
1

У меня возникли проблемы с установкой очень простой модели в тензорном потоке. Если у меня есть столбец входных данных, который является постоянным, мой вывод всегда сходится, чтобы получить одно и то же значение для всех строк, которое является средним для моих выходных данных, y_, даже если в x_ есть еще один столбец, который имеет достаточно информации для воспроизведения y_ точно. Вот небольшой пример.Модель Tensorflow всегда производит среднее значение

import tensorflow as tf 

def weight_variable(shape): 
    """Initialize the weights with random weights""" 
    initial = tf.truncated_normal(shape, stddev=0.1, dtype=tf.float64) 
    return tf.Variable(initial) 

#Initialize my data 
x = tf.constant([[1.0,1.0],[1.0,2.0],[1.0,3.0]], dtype=tf.float64) 
y_ = tf.constant([1.0,2.0,3.0], dtype=tf.float64) 

w = weight_variable((2,1)) 
y = tf.matmul(x,w) 

error = tf.reduce_mean(tf.square(y_ - y)) 

train_step = tf.train.AdamOptimizer(1e-5).minimize(error) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 

    #Train the model and output every 1000 iterations 
    for i in range(1000000): 
     sess.run(train_step) 
     err = sess.run(error) 

     if i % 1000 == 0: 
      print "\nerr:", err 
      print "x: ", sess.run(x) 
      print "w: ", sess.run(w) 
      print "y_: ", sess.run(y_) 
      print "y: ", sess.run(y) 

Этот пример всегда сходится к w = [2,0] и y = [2,2,2]. Это гладкая функция с минимумом при w = [0,1] и y = [1,2,3], где функция ошибки равна нулю. Почему он не сходится к этому? Я также попытался использовать градиентный спуск, и я попытался изменить скорость обучения.

ответ

3

Ваша цель y_ = tf.constant([1.0,2.0,3.0], dtype=tf.float64) имеет форму (1, 3). Выходной сигнал tf.matmul(x, w) имеет форму (3, 1). Таким образом, y_ - y имеет форму (3, 3) согласно правилам радиовещания numpy. Таким образом, вы действительно не оптимизируете функцию, которую, по вашему мнению, вы оптимизируете. Изменение y_ на следующее и дать ему шанс:

y_ = tf.constant([[1.0],[2.0],[3.0]], dtype=tf.float64) 

Это должно сходиться довольно быстро ожидаемый ответ, даже с большой скоростью обучения.

+0

Ох, это действительно проблема. Спасибо! – Tom

Смежные вопросы