2016-05-05 2 views
0

В этой ассоциативной бумаге lstm, http://arxiv.org/abs/1602.03032, они просят переставить сложный тензор.Внедрение перестановки комплексных чисел в TensorFlow

Они предоставили свой код здесь: https://github.com/mohammadpz/Associative_LSTM/blob/master/bricks.py#L79

Я пытаюсь повторить это в tensorflow. Вот что я сделал:

# shape: C x F/2 
# output = self.permutations: [num_copies x cell_size] 
permutations = [] 
indices = numpy.arange(self._dim/2) #[1 ,2 ,3 ...64] 
for i in range(self._num_copies): 
    numpy.random.shuffle(indices) #[4, 48, 32, ...64] 
    permutations.append(numpy.concatenate(
     [indices, 
     [ind + self._dim/2 for ind in indices]])) 
    #you're appending a row with two columns -- a permutation in the first column, and the same permutation + dim/2 for imaginary 
# C x F (numpy) 
self.permutations = tf.constant(numpy.vstack(permutations), dtype = tf.int32) #This is a permutation tensor that has the stored permutations 
# output = self.permutations: [num_copies x cell_size] 

def permute(complex_tensor): #complex tensor is [batch_size x cell_size] 
gather_tensor = tf.gather_nd(complex_tensor, self.permutations) 
return gather_tensor 

В принципе, у меня вопрос: насколько эффективно это может быть сделано в TensorFlow? В любом случае, чтобы размер размера партии был зафиксирован complex tensor?

Кроме того, есть gather_nd лучший способ сделать это? Или лучше сделать цикл for и перебрать по каждой строке в self.permutations с помощью tf.gather?

def permute(self, complex_tensor): 
inputs_permuted = [] 
for i in range(self.permutations.get_shape()[0].value): 
    inputs_permuted.append(
    tf.gather(complex_tensor, self.permutations[i])) 
return tf.concat(0, inputs_permuted) 

Я думал, что gather_nd будет гораздо более эффективным.

ответ

0

Nevermind, я понял, трюк состоит в том, чтобы просто переставить исходный тензор ввода, используя транспонирование tf. Это позволит вам сделать tf.gather на всей матрице. Затем вы можете объединить матрицы вместе. Извините, если это потрачено впустую.

+0

Можете ли вы поделиться некоторым кодом для этого случайно? – bge0

+0

Насколько я могу сказать, 'perm = []' просто переставляет измерение, а не каждую строку – bge0

Смежные вопросы