Мы можем быть немного умнее об индексации и экономить около ~ 4 раза в цене.
Первая позволяет строить некоторые данные правильной формы:
seed = np.random.randint(0, 100, (200,206))
data = np.random.randint(0, 100, (4e5,206))
seed[:, 0] = np.arange(200)
data[:, 0] = np.random.randint(0, 200, 4e5)
diam = np.empty(200)
Время оригинального ответа:
%%timeit
for i in range(200):
diam[i] = spd.cdist(seed[np.newaxis, i, 1:], data[data[:, 0]==i][:,1:]).max()
1 loops, best of 3: 1.35 s per loop
ответ moarningsun в:
%%timeit
seed_repeated = seed[data[:,0]]
dist_to_center = np.sqrt(np.sum((data[:,1:]-seed_repeated[:,1:])**2, axis=1))
diam = np.zeros(len(seed))
np.maximum.at(diam, data[:,0], dist_to_center)
1 loops, best of 3: 1.33 s per loop
ответ Divakar в:
%%timeit
data_sorted = data[data[:, 0].argsort()]
seed_ext = np.repeat(seed,np.bincount(data_sorted[:,0]),axis=0)
dists = np.sqrt(((data_sorted[:,1:] - seed_ext[:,1:])**2).sum(1))
shift_idx = np.append(0,np.nonzero(np.diff(data_sorted[:,0]))[0]+1)
diam_out = np.maximum.reduceat(dists,shift_idx)
1 loops, best of 3: 1.65 s per loop
Как мы видим, на самом деле ничего не получилось с помощью векторизованных решений, помимо большего объема памяти.Чтобы избежать этого, мы должны вернуться к первоначальному ответу, который действительно правильный способ сделать это, и вместо того, чтобы попытаться уменьшить количество индексации:
%%timeit
idx = data[:,0].argsort()
bins = np.bincount(data[:,0])
counter = 0
for i in range(200):
data_slice = idx[counter: counter+bins[i]]
diam[i] = spd.cdist(seed[None, i, 1:], data[data_slice, 1:]).max()
counter += bins[i]
1 loops, best of 3: 281 ms per loop
Дважды проверьте ответ:
np.allclose(diam, dam_out)
True
Это проблема с допущением, что петли python плохи. Они часто бывают, но не во всех ситуациях.
Это на самом деле довольно разумный код. Вы для цикла относительно малы по сравнению с количеством вычислений, которое делается внутри 'cdist'. Поскольку 'cdist' - довольно оптимальная скорость, вряд ли будет большой. – Daniel
@Ophion - повторяющиеся данные линейного поиска [:, 0] == i' можно избежать, чтобы снизить сложность от O (n ** 2) до O (n log (n)) или даже O (n). –
@moarningsun Правда, но то, что возможно и что доступно, - это две разные вещи, особенно учитывая, что это O (n * m) не O (n^2) и n << m. До сих пор никакое решение не было быстрее, чем у OP, и все они имеют значительно больший объем издержек памяти. – Daniel