2015-03-25 1 views
1

Я пытаюсь обновить переменную Theano в функции, упрощенный, как это:Обновление переменной, которая была с приведением с theano.tensor.cast()

copy_func = theano.function(
    inputs=[idx], 
    updates=[ 
     (a_variable, T.set_subtensor(a_variable[some_ptr], another_variable[idx])) 
    ] 
) 

Моя проблема заключается в том, что я получаю ошибку

TypeError: ('update target must be a SharedVariable', Elemwise{Cast{int32}}.0) 

как я получаю эту переменную через используя следующие (в основном скопированы из deeplearning.net учебники) (another_variable инициализируется аналогично):

a_variable = theano.shared(np.asarray(data, 
           dtype=theano.config.floatX), 
       borrow=True) 
print type(a_variable) 
a_variable = T.cast(a_variable, 'int32') 
print type(a_variable) 
не

, который печатает

<class 'theano.tensor.sharedvar.TensorSharedVariable'> 
<class 'theano.tensor.var.TensorVariable'> 

, то есть переменная больше не «общий», объясняя ошибку. Это имеет смысл, так как я предполагаю, что переменная теперь просто просто литой вид исходных общих поплавков. Но как я могу обновить переменную, которая эффективно применяется?

ответ

1

Я решил это сам, и ответ был, конечно, очевидным.

Вместо перекрывая a_variable переменные с литой версией, я сохранил uncasted версии:

a_variable_casted = T.cast(a_variable, 'int32') 

Обновление теперь делается на a_variable, в то время как a_variable_casted используется для выполнения вычислений a_variable использовались для ранее.

Возможно, существует более элегантный способ сделать это, и в этом случае я бы хотел его услышать!

Смежные вопросы