Я пытаюсь найти самый быстрый способ получить функциональность оператора numpy 'where' в массиве 2D numpy; а именно, получение индексов, в которых выполняется условие. Это намного медленнее, чем другие языки, которые я использовал (например, IDL, Matlab).Самый быстрый способ найти индексы состояния в массиве numpy
У меня есть cythonized функция, которая проходит через массив в вложенных циклах. Скорей всего на порядок, но я хотел бы увеличить производительность еще больше, если это возможно.
TEST.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0
t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
Моя cython_where.pyx программа:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
assert data.dtype == DTYPE1
cdef int xmax = data.shape[0]
cdef int ymax = data.shape[1]
cdef unsigned int x, y
cdef int count = 0
cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
for x in xrange(xmax):
for y in xrange(ymax):
if(data[x,y] == val):
xind[count] = x
yind[count] = y
count += 1
return tuple([xind[0:count],yind[0:count]]),count
Выход test.py: cython_test]$ python TEST.py 0.0139019489288 0.0982608795166
Я также попытался NumPy-х argwhere
, который с таким же быстром, как where
. Я довольно новичок в numpy и cython, поэтому, если у вас есть другие идеи, чтобы действительно повысить производительность, я все уши!
Как говорится в названии, я хочу, чтобы самый быстрый способ найти индексы 2D-массива при условии условия (например, arr == 2). Я уже улучшил работу с numpy where с помощью cythonization, как я объяснил выше. –
Вы говорите о numpy.where, но документацию numpy.where даете в качестве примера: ix = np.in1d (x.ravel(), goodvalues) .reshape (x.shape) для извлечения индексов. Вы попробовали? Это лучше? Или (a == 10) .nonzero()? –
@ P.Brunet, я пробовал это, и он немного медленнее обычного np.where (x == val). Я не уверен, почему вы использовали бы этот метод, если не будете испытывать несколько значений. –