2016-01-19 2 views
2

Я пытаюсь найти самый быстрый способ получить функциональность оператора numpy 'where' в массиве 2D numpy; а именно, получение индексов, в которых выполняется условие. Это намного медленнее, чем другие языки, которые я использовал (например, IDL, Matlab).Самый быстрый способ найти индексы состояния в массиве numpy

У меня есть cythonized функция, которая проходит через массив в вложенных циклах. Скорей всего на порядок, но я хотел бы увеличить производительность еще больше, если это возможно.

TEST.py:

from cython_where import * 
import time 
import numpy as np 

data = np.zeros((2600,5200)) 
data[100:200,100:200] = 10 

t0 = time.time() 
inds,ct = cython_where(data,'EQ',10) 
print time.time() - t0 

t1 = time.time() 
tmp = np.where(data == 10) 
print time.time() - t1 

Моя cython_where.pyx программа:

from __future__ import division 
import numpy as np 
cimport numpy as np 
cimport cython 

DTYPE1 = np.float 
ctypedef np.float_t DTYPE1_t 
DTYPE2 = np.int 
ctypedef np.int_t DTYPE2_t 

@cython.boundscheck(False) 
@cython.wraparound(False) 
@cython.nonecheck(False) 

def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val): 
    assert data.dtype == DTYPE1 

    cdef int xmax = data.shape[0] 
    cdef int ymax = data.shape[1] 
    cdef unsigned int x, y 
    cdef int count = 0 
    cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int) 
    cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int) 
    if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here 
    for x in xrange(xmax): 
    for y in xrange(ymax): 
     if(data[x,y] == val): 
     xind[count] = x 
     yind[count] = y 
     count += 1 

return tuple([xind[0:count],yind[0:count]]),count 

Выход test.py: cython_test]$ python TEST.py 0.0139019489288 0.0982608795166

Я также попытался NumPy-х argwhere, который с таким же быстром, как where. Я довольно новичок в numpy и cython, поэтому, если у вас есть другие идеи, чтобы действительно повысить производительность, я все уши!

+0

Как говорится в названии, я хочу, чтобы самый быстрый способ найти индексы 2D-массива при условии условия (например, arr == 2). Я уже улучшил работу с numpy where с помощью cythonization, как я объяснил выше. –

+0

Вы говорите о numpy.where, но документацию numpy.where даете в качестве примера: ix = np.in1d ​​(x.ravel(), goodvalues) .reshape (x.shape) для извлечения индексов. Вы попробовали? Это лучше? Или (a == 10) .nonzero()? –

+0

@ P.Brunet, я пробовал это, и он немного медленнее обычного np.where (x == val). Я не уверен, почему вы использовали бы этот метод, если не будете испытывать несколько значений. –

ответ

3

Статьи:

  • Numpy можно ускорить на плоский массив для усиления 4x:

    %timeit np.where(data==10) 
    1 loops, best of 3: 105 ms per loop 
    
    %timeit np.unravel_index(np.where(data.ravel()==10),data.shape) 
    10 loops, best of 3: 26.0 ms per loop 
    

Я думаю, что вы можете оптимизировать Cython код с этим, избегая вычисления k=i*ncol+j для каждой ячейки.

  • Numba дают простую альтернативу:

    from numba import jit 
    dtype=data.dtype 
    @jit(nopython=True) 
    def numbaeq(flatdata,x,nrow,ncol): 
        size=ncol*nrow 
        ix=np.empty(size,dtype=dtype) 
        jx=np.empty(size,dtype=dtype) 
        count=0 
        k=0 
        while k<size: 
        if flatdata[k]==x : 
         ix[count]=k//ncol 
         jx[count]=k%ncol 
         count+=1 
        k+=1   
        return ix[:count],jx[:count] 
    
    def whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape) 
    

, который дает:

%timeit whereequal(data,10) 
    10 loops, best of 3: 20.2 ms per loop 

Не большой оптимизации для Numba на такой проблеме, при выполнении Cython.

  • k//ncol и k%ncol могут быть вычислены одновременно с оптимизированной divmod операции.
  • Конечные шаги - это язык ассемблера и параллелизация, но это другие виды спорта.
Смежные вопросы