Я пытаюсь реализовать алгоритм быстрой сортировки, используя numba в Python.Как ускорить сортировку с помощью numba?
Это, кажется, намного медленнее функции сортировки numpy.
Как я могу улучшить его? Мой код здесь:
import numba as nb
@nb.autojit
def quick_sort(list_):
"""
Iterative version of quick sort
"""
#temp_stack = []
#temp_stack.append((left,right))
max_depth = 1000
left = 0
right = list_.shape[0]-1
i_stack_pos = 0
a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32)
a_temp_stack[i_stack_pos,0] = left
a_temp_stack[i_stack_pos,1] = right
i_stack_pos+=1
#Main loop to pop and push items until stack is empty
while i_stack_pos>0:
i_stack_pos-=1
right = a_temp_stack[ i_stack_pos, 1 ]
left = a_temp_stack[ i_stack_pos, 0 ]
piv = partition(list_,left,right)
#If items in the left of the pivot push them to the stack
if piv-1 > left:
#temp_stack.append((left,piv-1))
a_temp_stack[ i_stack_pos, 0 ] = left
a_temp_stack[ i_stack_pos, 1 ] = piv-1
i_stack_pos+=1
#If items in the right of the pivot push them to the stack
if piv+1 < right:
a_temp_stack[ i_stack_pos, 0 ] = piv+1
a_temp_stack[ i_stack_pos, 1 ] = right
i_stack_pos+=1
@nb.autojit(nopython=True)
def partition(list_, left, right):
"""
Partition method
"""
#Pivot first element in the array
piv = list_[left]
i = left + 1
j = right
while 1:
while i <= j and list_[i] <= piv:
i +=1
while j >= i and list_[j] >= piv:
j -=1
if j <= i:
break
#Exchange items
list_[i], list_[j] = list_[j], list_[i]
#Exchange pivot to the right position
list_[left], list_[j] = list_[j], list_[left]
return j
Мой тестовый код здесь:
x = np.random.random_integers(0,1000,1000000)
y = x.copy()
quick_sort(y)
z = np.sort(x)
np.testing.assert_array_equal(z, y)
y = x.copy()
with Timer('nb'):
numba_fns.quick_sort(y)
with Timer('np'):
x = np.sort(x)
UPDATE:
Я переписал функцию, чтобы заставить циклическую часть кода для запуска в nopython Режим. Цикл while не вызывает сбоя nopython. Тем не менее, я не получил каких-либо улучшений производительности:
@nb.autojit
def quick_sort2(list_):
"""
Iterative version of quick sort
"""
max_depth = 1000
left = 0
right = list_.shape[0]-1
i_stack_pos = 0
a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32)
a_temp_stack[i_stack_pos,0] = left
a_temp_stack[i_stack_pos,1] = right
i_stack_pos+=1
#Main loop to pop and push items until stack is empty
return _quick_sort2(list_, a_temp_stack, left, right)
@nb.autojit(nopython=True)
def _quick_sort2(list_, a_temp_stack, left, right):
i_stack_pos = 1
while i_stack_pos>0:
i_stack_pos-=1
right = a_temp_stack[ i_stack_pos, 1 ]
left = a_temp_stack[ i_stack_pos, 0 ]
piv = partition(list_,left,right)
#If items in the left of the pivot push them to the stack
if piv-1 > left:
a_temp_stack[ i_stack_pos, 0 ] = left
a_temp_stack[ i_stack_pos, 1 ] = piv-1
i_stack_pos+=1
if piv+1 < right:
a_temp_stack[ i_stack_pos, 0 ] = piv+1
a_temp_stack[ i_stack_pos, 1 ] = right
i_stack_pos+=1
@nb.autojit(nopython=True)
def partition(list_, left, right):
"""
Partition method
"""
#Pivot first element in the array
piv = list_[left]
i = left + 1
j = right
while 1:
while i <= j and list_[i] <= piv:
i +=1
while j >= i and list_[j] >= piv:
j -=1
if j <= i:
break
#Exchange items
list_[i], list_[j] = list_[j], list_[i]
#Exchange pivot to the right position
list_[left], list_[j] = list_[j], list_[left]
return j
Даже с JIT компилятором, это, вероятно, маловероятно, что вы собираетесь бить алгоритм, реализованный в [прямой C] (https://github.com/numpy/numpy/blob/master /numpy/core/src/npysort/quicksort.c.src). Возможно также, что ваш код возвращается в [режим объекта] (http://numba.pydata.org/numba-doc/0.17.0/user/troubleshoot.html#the-compiled-code-is-too- медленный). – Seth
Вы делаете это упражнение 'numba' или потому, что вам нужна быстрая сортировка? – hpaulj
Это не упражнение. Мне нужна быстрая сортировка в функции numba, которую я пишу. Функция вызывает quicksort несколько раз. Для запуска в режиме nopython функция не может использовать функцию сортировки numpy, поэтому мне нужно написать свой собственный. – Ginger