1

Недавно я обновил tesnorflow от Rev8 до Rev12. В Rev8 по умолчанию флаг «state_is_tuple» в rnn_cell.LSTMCell установлен в False, поэтому я инициализировал свою ячейку LSTM со списком, см. Код ниже.Как инициализировать LSTMCell с кортежем

#model definition 
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim) 
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state) 


#init_state place holder and feed_dict 
def add_placeholders(self): 
    self.init_state = tf.placeholder("float", [None, self.cell_size]) 

def get_feed_dict(self, data, label): 
    feed_dict = {self.input_data: data, 
      self.input_label: reg_label, 
      self.init_state: np.zeros((self.config.batch_size, self.cell_size))} 
    return feed_dict 

В Rev12, флаг по умолчанию «state_is_tuple» установлено значение Правда, для того, чтобы сделать свой старый код работать, я должен был явно включить флаг значение False. Тем не менее, теперь я получил предупреждение от tensorflow говоря: «Использование каскадного состояние медленнее и вскоре устареет Использование state_is_tuple = True»

Я пытался инициализировать LSTM клетки с кортеж, изменив определение шаблонного для self.init_state к следующему:

self.init_state = tf.placeholder("float", (None, self.cell_size)) 

, но теперь я получил сообщение об ошибке, говорящее:

«объект„тензорный“не Iterable»

Кто-нибудь знает, как сделать эту работу?

+1

К сожалению, кортеж сложной структуры. Вы * должны * явно сделать «init_state» заполнителем? Было бы гораздо лучше использовать 'cell.zero_state' вместо этого. Не беспокойтесь, вы можете передать состояние в 'feed_dict' через пробежки – martianwars

ответ

1

Подача «нулевого состояния» на LSTM намного проще, используя cell.zero_state. Вам не нужно явно определять начальное состояние в качестве заполнителя. Определите его как тензор и подавайте, если потребуется. Вот как это работает,

lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim) 
self.initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32) 
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state) 

Если вы хотите, чтобы кормить некоторое другое значение в качестве начального состояния, Скажем next_state = states[-1] к примеру, вычислить его в сессии и передать его в feed_dict как -

feed_dict[self.initial_state] = next_state 

В контексте вашего вопроса, lstm_cell.zero_state() должно быть достаточно.


Несвязанный, но помните, что вы можете пройти как тензоры, так и заполнители в словаре фида! Вот как работает self.initial_state в приведенном выше примере. Посмотрите на рабочий стол PTB Tutorial.