2016-11-23 4 views
0

Здравствуйте Я новичок в tensorflow я сделать мой график, и я пытаюсь запустить его, но я получаю эту ошибку:как изменить ранг объекта tensorflow

ValueError: Shape (3,) must have rank 2 

, которая исходит от этой линии

tf.matmul(tf.matmul(phix, tf.transpose(param)), B) 

Я проверил ранг моей переменной phix, и результат 0, я не понял почему, потому что его форма (3,3). это мой сценарий, не могли бы вы мне помочь.

import tensorflow as tf 
def phi(x, b, w, B): 
    z = tf.matmul(x,w) 
    phix = tf.cos(z) + b # attention shapes 
    phix /= tf.sqrt(float(float(int(w.get_shape()[0]))/2.)) 

    return phix, B 



def model(phix, B, param) : 
    return tf.matmul(tf.matmul(phix, tf.transpose(param)), B) 

B = tf.constant(1., shape=[1]) # constant (non trainable) 
x2 = tf.placeholder(tf.float32, shape=[3,1]) # variable 
W2 = tf.Variable(initial_value = tf.random_normal(shape=[1, 3]),trainable=False ,name="W2") 
b2 = tf.Variable(tf.random_uniform(shape=[3]),trainable=False , name="b2") 
y = tf.Variable(tf.random_uniform(shape=[3,3]),trainable=False) 
param = tf.Variable(tf.random_uniform(shape=[3])) # variable trainable 
norm = tf.sqrt(tf.reduce_sum(tf.square(param)))**2 ## attention ici c'est par ce que param est un vecteur de une dimmention 
phix, B = phi (x2,b2,W2,B) 
lamda = tf.constant(1. , shape=[3]) 
cost = tf.nn.l2_loss(y - model(phix, B, param)) + lamda * norm 

opt = tf.train.GradientDescentOptimizer(0.5).minimize(cost) 

init = tf.initialize_all_variables() 
with tf.Session() as sess: 
    sess.run(init) 

    for i in range(10): 
     sess.run(opt) 
     z4_op = sess.run(opt , feed_dict = {x2: [[1.0],[2.0],[3.0]]}) 

    print(z4_op)  

ответ

0

Чтобы использовать tf.matmul, входы должны быть 2D матрицы (в противном случае, вы можете использовать tf.mul или изменить тензоры быть 2D). Перестройка будет выглядеть следующим образом:

def model(phix, B, param): 
    one = tf.reshape(tf.matmul(phix, tf.reshape(param, [3, 1])), [3, 1]) 
    return tf.matmul(one, tf.reshape(B, [1,1])) 

Вы также должны кормить значение x2 заполнителя каждый раз, когда вы выполняете цит, что зависит от него:

for i in range(10): 
    sess.run(opt, feed_dict = {x2: [[1.0],[2.0],[3.0]]}) 
    z4_op = sess.run(opt, feed_dict = {x2: [[1.0],[2.0],[3.0]]})