2017-02-12 2 views
3

Я ищу способ TensorFlow реализации чего-то подобного функции list.index() на Python.Как найти индекс первого элемента соответствия в TensorFlow

Учитывая матрицу и значение, которое нужно найти, я хочу знать первое вхождение значения в каждой строке матрицы.

Например,

m is a <batch_size, 100> matrix of integers 
val = 23 

result = [0] * batch_size 
for i, row_elems in enumerate(m): 
    result[i] = row_elems.index(val) 

Я не могу предположить, что «вал» появляется только один раз в каждой строке, в противном случае я бы реализовать его с помощью tf.argmax (м == VAL). В моем случае важно получить индекс с первым возникновением «val», а не любым.

ответ

6

Кажется, что tf.argmax работает как np.argmax (согласно the test), который вернет первый индекс, если имеется несколько вхождений максимального значения. Вы можете использовать tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1), чтобы получить то, что хотите. Однако в настоящее время поведение tf.argmax не определено в случае множественных вхождений максимального значения.

Если вас беспокоит неопределенное поведение, вы можете применить tf.argmin к возвращаемому значению tf.where, как предложил @Igor Цветков. Например,

# test with tensorflow r1.0 
import tensorflow as tf 

val = 3 
m = tf.placeholder(tf.int32) 
m_feed = [[0 , 0, val, 0, val], 
      [val, 0, val, val, 0], 
      [0 , val, 0, 0, 0]] 

tmp_indices = tf.where(tf.equal(m, val)) 
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0]) 

with tf.Session() as sess: 
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1] 

tf.segment_min Обратите внимание, что повысит InvalidArgumentError, когда есть строка, не содержащая val. В вашем коде row_elems.index(val) тоже вызовет исключение, если row_elems не содержит val.

+0

Это очень полезно! что, если мы хотим обновить val, чтобы стать new_val? Я задал этот вопрос здесь: https://stackoverflow.com/questions/45684445/tensorflow-update-first-matching-element-in-each-row – reese0106

1

Выглядит немного некрасиво, но работает (при условии, m и val являются тензорами):

idx = list() 
for t in tf.unpack(m, axis=0): 
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val)))) 
idx = tf.pack(idx, axis=0) 

EDIT: Как Yaroslav Bulatov упоминалось, вы могли бы достичь того же результата с tf.map_fn:

def index1d(t): 
    return tf.reduce_min(tf.where(tf.equal(t, val))) 

idx = tf.map_fn(index1d, m, dtype=tf.int64) 
+1

'map_fn' может сделать это без распаковки –

Смежные вопросы