Оказывается, что чистый цикл Python может быть намного быстрее, чем индексирование NumPy (или вызовы np.where) в этом случае.
Рассмотрим следующие варианты:
import numpy as np
import collections
import itertools as IT
shape = (2600,5200)
# shape = (26,52)
emiss_data = np.random.random(shape)
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)
def using_where():
max = np.max
where = np.where
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
return MAX_EMISS
def using_index():
max = np.max
MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
return MAX_EMISS
def using_max():
MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
return MAX_EMISS
def using_loop():
result = collections.defaultdict(list)
for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()):
result[idx].append(val)
return [max(result[idx]) for idx in UNIQ_IDS]
def using_sort():
uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
vals = uind.argsort()
count = np.bincount(uind)
start = 0
end = 0
out = np.empty(count.shape[0])
for ind, x in np.ndenumerate(count):
end += x
out[ind] = np.max(np.take(emiss_data, vals[start:end]))
start += x
return out
def using_split():
uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
vals = uind.argsort()
count = np.bincount(uind)
return [np.take(emiss_data, item).max()
for item in np.split(vals, count.cumsum())[:-1]]
for func in (using_index, using_max, using_loop, using_sort, using_split):
assert using_where() == func()
Вот ориентиры, с shape = (2600,5200)
:
In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop
In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop
In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop
In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop
In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop
In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop
Таким образом using_loop
(чистый Python) оказывается больше, чем 11x быстрее, чем using_where
.
Я не совсем уверен, почему чистый Python быстрее, чем NumPy здесь. Я предполагаю, что чистая версия Python zips (да, каламбур) через оба массива один раз. Он использует тот факт, что, несмотря на все причудливое индексирование, , мы действительно просто хотим посетить каждое значение после. Таким образом, он борется с тем, что нужно точно определить, в какую группу попало каждое значение в emiss_data
. Но это просто расплывчатая спекуляция. Я не знал, что это будет быстрее, пока я не сравню их.
А вы вычисление 'UNIQ_IDS' в этом скрипте или это предопределено? – Daniel
UNIQ_IDS предопределен ... список ints len = 800.Это всего лишь фрагмент кода, извините за путаницу. –