2015-05-21 2 views
1

Возможно ли вернуть значения, вычисленные в функции сканирования, не возвращая их обратно в функцию сканирования.Python - theano.scan() - значения возвращаемой функции без обратной связи с контуром

например.

import theano 
import theano.tensor as T 
import numpy as np 

theano.config.exception_verbosity='high' 
theano.config.optimizer='None' 

def f(seq_v, prev_v): 
    return seq_v*prev_v#, prev_v+1 

a = T.dvector('a') 

ini = T.constant(1, dtype=theano.config.floatX) 

result, updates = theano.scan(fn=f, 
           outputs_info=[ini], 
           sequences=[a], 
           non_sequences=None) 

fn = theano.function(inputs=[a], outputs=result) 

A = np.arange(1,5) 
out = fn(A) 

print 'Values:\nf:\t{}'.format(out) 

который дает

Values: 
f: [ 1. 2. 6. 24.] 

Однако, я хотел бы выход обоих значений в f() без подачи последнего значения обратно в функцию сканирования:

def f(seq_v, prev_v): 
    return seq_v*prev_v, prev_v+1 

, чтобы дать что-то например:

Values: 
f: [ [1. , 2.] [2. , 3.] [6. , 4.] [24. , 5.] ] 

(я просто хотел бы отметить, что эта проблема является тривиальной, но я хотел бы использовать эту идею для отладки функций сканирования и проверки выходных значений)

ответ

1

Вы должны указать outputs_info быть None для выхода, который вы не хотите возвращаться в f. Для получения дополнительной информации см. the scan documentation. Ниже приведен пример, который должен делать то, что вы хотите.

import theano 
import theano.tensor as T 
import numpy as np 

theano.config.exception_verbosity='high' 
theano.config.optimizer='None' 

def f(seq_v, prev_v): 
    return seq_v*prev_v, seq_v+1 

a = T.vector('a') 

ini = T.constant(1, dtype=theano.config.floatX) 

result, updates = theano.scan(fn=f, 
           outputs_info=[ini,None], 
           sequences=[a]) 

fn = theano.function(inputs=[a], outputs=result) 

A = np.arange(1,5, dtype=T.config.floatX) 
out = fn(A) 

print('Values:\nf:\t{}'.format(out)) 

Выход:

Values: 
f: [array([ 1., 2., 6., 24.], dtype=float32), array([ 2., 3., 4., 5.], dtype=float32)] 
Смежные вопросы