2017-02-01 2 views
2

Я могу понять следующее поведение numpy.как работает numpy.where?

>>> a 
array([[ 0. , 0. , 0. ], 
     [ 0. , 0.7, 0. ], 
     [ 0. , 0.3, 0.5], 
     [ 0.6, 0. , 0.8], 
     [ 0.7, 0. , 0. ]]) 
>>> argmax_overlaps = a.argmax(axis=1) 
>>> argmax_overlaps 
array([0, 1, 2, 2, 0]) 
>>> max_overlaps = a[np.arange(5),argmax_overlaps] 
>>> max_overlaps 
array([ 0. , 0.7, 0.5, 0.8, 0.7]) 
>>> gt_argmax_overlaps = a.argmax(axis=0) 
>>> gt_argmax_overlaps 
array([4, 1, 3]) 
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])] 
>>> gt_max_overlaps 
array([ 0.7, 0.7, 0.8]) 
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps) 
>>> gt_argmax_overlaps 
(array([1, 3, 4]), array([1, 2, 0])) 

я понял, 0,7, 0,7 и 0,8 является [1,1], а [3,2] и [4,0], так что я получил кортеж (array[1,3,4] and array[1,2,0]) каждый массив из которых, состоящих из 0-я и 1-я индексы этих трех элементов. Затем я попробовал другие примеры, чтобы понять, что мое понимание верное.

>>> np.where(a == [0.3]) 
(array([2]), array([1])) 

0,3 находится в [2,1], поэтому результат выглядит так, как я ожидал. Затем я попробовал

>>> np.where(a == [0.3, 0.5]) 
(array([], dtype=int64),) 

?? Я ожидал увидеть (массив ([2,2]), массив ([2,3])). Почему я вижу результат выше?

>>> np.where(a == [0.7, 0.7, 0.8]) 
(array([1, 3, 4]), array([1, 2, 0])) 
>>> np.where(a == [0.8,0.7,0.7]) 
(array([1]), array([1])) 

Не могу понять и второй результат. Может ли кто-нибудь объяснить это мне? Благодарю.

+1

Используйте 'np.where ((a == 0.3) | (a == 0.5))' и 'np.where ((a == 0.7) | (a == 0.8))', чтобы получить правильный результат , Однако я не знаю, почему 'np.where (a == [0.7, 0.7, 0.8])' работает, а 'np.where (a == [0.7.0.8])' выдает 'DeprecationWarning'. Похож на ошибку. – Khris

+1

Когда 'where' дает неожиданные индексы, посмотрите на массив условий. 'where' просто сообщает вам, где этот массив' True'. – hpaulj

ответ

1

Первое, что нужно знать, это то, что np.where(a == [whatever]) просто показывает вам индексы, где a == [whatever] - это правда. Поэтому вы можете получить подсказку, посмотрев на значение a == [whatever]. В вашем случае, что «работает»:

>>> a == [0.7, 0.7, 0.8] 
array([[False, False, False], 
     [False, True, False], 
     [False, False, False], 
     [False, False, True], 
     [ True, False, False]], dtype=bool) 

Вы не получаете то, что считаете себя. Вы считаете, что запрашивают индексы каждого элемента отдельно, но вместо этого он получает позиции, где значения соответствуют в том же положении в строке. В основном, что это сравнение делает, говорит «для каждой строки, скажите мне, равен ли первый элемент 0.7, независимо от того, равен ли он второй, а третий - 0,8». Затем он возвращает индексы этих совпадающих позиций. Другими словами, сравнение выполняется между целыми строками, а не только отдельными значениями. Для вашего последнего примера:

>>> a == [0.8,0.7,0.7] 
array([[False, False, False], 
     [False, True, False], 
     [False, False, False], 
     [False, False, False], 
     [False, False, False]], dtype=bool) 

Теперь вы получите другой результат. Он не запрашивает «индексы, где a имеет значение 0,8», он запрашивает только индексы, в которых есть 0,8 в начале строки, а также 0,7 в любой из двух последующих позиций.

Этот тип сравнения строк может быть выполнен только в том случае, если значение, которое вы сравниваете, соответствует форме одной строки a. Поэтому, когда вы пытаетесь создать список из двух элементов, он возвращает пустой набор, потому что он пытается сравнить список как скалярное значение с отдельными значениями в вашем массиве.

Результат состоит в том, что вы не можете использовать == в списке значений и ожидать, что он просто скажет вам, где происходит какое-либо из значений. Равенство будет соответствовать значению и позиции (если значение, которое вы сравниваете, имеет ту же форму, что и строка вашего массива), или попытается сравнить весь список как скаляр (если форма не совпадает) , Если вы хотите найти значения независимо друг от друга, что вам нужно сделать что-то вроде того, что предложило Khris в комментарии:

np.where((a==0.3)|(a==0.5)) 

То есть, вам нужно сделать две (или более) отдельных сравнение с отдельными значениями, а не одно сравнение со списком значений.

+0

Вау, это был случай. python умный и странный :) –

Смежные вопросы