2016-02-02 3 views
26

Этот вопрос касается доступа к отдельным элементам в тензоре, например [[1,2,3]]. Мне нужно получить доступ к внутреннему элементу [1,2,3] (это может быть выполнено с использованием .eval() или sess.run()), но это займет больше времени, когда размер тензора огромен)Tensorflow python: Доступ к отдельным элементам в тензоре

Есть ли любой способ сделать то же самое быстрее?

Спасибо заранее.

ответ

0

Я подозреваю, что остальная часть вычислений требует времени, а не доступа к одному элементу.

Также может потребоваться копия из любой памяти, поэтому, если она находится на графической карте, ее необходимо скопировать обратно в оперативную память сначала, а затем получить доступ к вашему элементу. Если это так, вы можете пропустить его, добавив операцию tenorflow, чтобы перенести первый элемент, и только верните это.

36

Существует два основных способа доступа к подмножествам элементов в тензоре, каждый из которых должен работать для вашего примера.

  1. Используйте оператор индексирования (на основе tf.slice()), чтобы извлечь непрерывный кусок от тензора.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 
    
    output = input[0, :] 
    print sess.run(output) # ==> [1 2 3] 
    

    Оператор индексирования поддерживает многие из тех же характеристик среза, что и NumPy.

  2. Используйте опцию tf.gather() для выбора несмежного среза из тензора.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 
    
    output = tf.gather(input, 0) 
    print sess.run(output) # ==> [1 2 3] 
    
    output = tf.gather(input, [0, 2]) 
    print sess.run(output) # ==> [[1 2 3] [7 8 9]] 
    

    Обратите внимание, что tf.gather() только позволяет вам выбрать целые кусочки в 0-ом измерении (целые строки в примере матрицы), так что вам может понадобиться tf.reshape() или tf.transpose() ваш вклад, чтобы получить соответствующие элементы.

+1

«.. вам может понадобиться tf.reshape() или tf.transpose() ваш вход для получения соответствующих элементов». -> или используйте 'tf.gather_nd'? –

1

Вы просто не можете получить значение 0-го элемента [[1,2,3]] без пробега() - нин или Eval() - Инж операцию, которая будет получать его. Поскольку перед запуском или «eval» у вас есть только описание того, как получить этот внутренний элемент (поскольку TF использует символические графики/вычисления). Поэтому, даже если вы будете использовать tf.gather/tf.slice, вам все равно придется получать значения этих операций через eval/run. См. Ответ @ mrry.

Смежные вопросы