2015-03-22 2 views
1

Я пытаюсь реализовать алгоритм быстрой сортировки, используя numba в Python.Как ускорить сортировку с помощью numba?

Это, кажется, намного медленнее функции сортировки numpy.

Как я могу улучшить его? Мой код здесь:

import numba as nb 

@nb.autojit 
def quick_sort(list_): 
    """ 
    Iterative version of quick sort 
    """ 
    #temp_stack = [] 
    #temp_stack.append((left,right)) 

    max_depth = 1000 

    left = 0 
    right = list_.shape[0]-1 

    i_stack_pos = 0 
    a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32) 
    a_temp_stack[i_stack_pos,0] = left 
    a_temp_stack[i_stack_pos,1] = right 
    i_stack_pos+=1 
    #Main loop to pop and push items until stack is empty 

    while i_stack_pos>0: 

     i_stack_pos-=1 
     right = a_temp_stack[ i_stack_pos, 1 ] 
     left = a_temp_stack[ i_stack_pos, 0 ] 

     piv = partition(list_,left,right) 
     #If items in the left of the pivot push them to the stack 
     if piv-1 > left: 
      #temp_stack.append((left,piv-1)) 

      a_temp_stack[ i_stack_pos, 0 ] = left 
      a_temp_stack[ i_stack_pos, 1 ] = piv-1 
      i_stack_pos+=1 
     #If items in the right of the pivot push them to the stack 
     if piv+1 < right: 
      a_temp_stack[ i_stack_pos, 0 ] = piv+1 
      a_temp_stack[ i_stack_pos, 1 ] = right 
      i_stack_pos+=1 

@nb.autojit(nopython=True) 
def partition(list_, left, right): 
    """ 
    Partition method 
    """ 
    #Pivot first element in the array 
    piv = list_[left] 
    i = left + 1 
    j = right 

    while 1: 
     while i <= j and list_[i] <= piv: 
      i +=1 
     while j >= i and list_[j] >= piv: 
      j -=1 
     if j <= i: 
      break 
     #Exchange items 
     list_[i], list_[j] = list_[j], list_[i] 
    #Exchange pivot to the right position 
    list_[left], list_[j] = list_[j], list_[left] 
    return j 

Мой тестовый код здесь:

x = np.random.random_integers(0,1000,1000000) 
    y = x.copy() 

    quick_sort(y) 

    z = np.sort(x) 

    np.testing.assert_array_equal(z, y) 

    y = x.copy() 
    with Timer('nb'): 
     numba_fns.quick_sort(y) 

    with Timer('np'): 
     x = np.sort(x) 

UPDATE:

Я переписал функцию, чтобы заставить циклическую часть кода для запуска в nopython Режим. Цикл while не вызывает сбоя nopython. Тем не менее, я не получил каких-либо улучшений производительности:

@nb.autojit 
def quick_sort2(list_): 
    """ 
    Iterative version of quick sort 
    """ 

    max_depth = 1000 

    left  = 0 
    right  = list_.shape[0]-1 

    i_stack_pos = 0 
    a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32) 
    a_temp_stack[i_stack_pos,0] = left 
    a_temp_stack[i_stack_pos,1] = right 
    i_stack_pos+=1 
    #Main loop to pop and push items until stack is empty 

    return _quick_sort2(list_, a_temp_stack, left, right) 

@nb.autojit(nopython=True) 
def _quick_sort2(list_, a_temp_stack, left, right): 

    i_stack_pos = 1 
    while i_stack_pos>0: 

     i_stack_pos-=1 
     right = a_temp_stack[ i_stack_pos, 1 ] 
     left = a_temp_stack[ i_stack_pos, 0 ] 

     piv = partition(list_,left,right) 
     #If items in the left of the pivot push them to the stack 
     if piv-1 > left:    
      a_temp_stack[ i_stack_pos, 0 ] = left 
      a_temp_stack[ i_stack_pos, 1 ] = piv-1 
      i_stack_pos+=1 
     if piv+1 < right: 
      a_temp_stack[ i_stack_pos, 0 ] = piv+1 
      a_temp_stack[ i_stack_pos, 1 ] = right 
      i_stack_pos+=1 

@nb.autojit(nopython=True) 
def partition(list_, left, right): 
    """ 
    Partition method 
    """ 
    #Pivot first element in the array 
    piv = list_[left] 
    i = left + 1 
    j = right 

    while 1: 
     while i <= j and list_[i] <= piv: 
      i +=1 
     while j >= i and list_[j] >= piv: 
      j -=1 
     if j <= i: 
      break 
     #Exchange items 
     list_[i], list_[j] = list_[j], list_[i] 
    #Exchange pivot to the right position 
    list_[left], list_[j] = list_[j], list_[left] 
    return j 
+2

Даже с JIT компилятором, это, вероятно, маловероятно, что вы собираетесь бить алгоритм, реализованный в [прямой C] ​​(https://github.com/numpy/numpy/blob/master /numpy/core/src/npysort/quicksort.c.src). Возможно также, что ваш код возвращается в [режим объекта] (http://numba.pydata.org/numba-doc/0.17.0/user/troubleshoot.html#the-compiled-code-is-too- медленный). – Seth

+2

Вы делаете это упражнение 'numba' или потому, что вам нужна быстрая сортировка? – hpaulj

+0

Это не упражнение. Мне нужна быстрая сортировка в функции numba, которую я пишу. Функция вызывает quicksort несколько раз. Для запуска в режиме nopython функция не может использовать функцию сортировки numpy, поэтому мне нужно написать свой собственный. – Ginger

ответ

3

Одно небольшое предложение, которое может помочь (но, как вы правильно сказали, в комментариях к вашему вопросу, вы будете бороться, чтобы победить чистый C реализация):

Вы хотите убедиться, что большинство из них выполнено в режиме «nopython» (@jit(nopython=True)). Добавьте это перед вашими функциями и посмотрите, где он сломается. Также позвоните по телефону inspect_types() и проверьте, правильно ли он их идентифицирует.

Единственное, что выделяется в коде, который, как представляется, может быть в режиме объекта (в отличие от режима nopython), - это выделение массива numpy. Хотя numba может компилировать циклы отдельно в режиме nopython, я не знаю, может ли это сделать для while-loops. Вызовите inspect_types.

Мой обычный workround для создания массивов numpy при обеспечении остального в режиме nopython заключается в создании оберточной функции.

@nb.jit(nopython=True) # make sure it can be done in nopython mode 
def _quick_sort_impl(list_,output_array): 
    ...most of your code goes here... 

@nb.jit 
def quick_sort(list_): 
    # this code won't compile in nopython mode, but it's 
    # short and isolated 
    max_depth = 1000 
    a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32) 
    _quick_sort_impl(list_,a_temp_stack) 
+2

Итак ... быстрый тест предполагает, что он выполняет цикл while в режиме nopython, поэтому мое «улучшение» не имеет заметных различий. Единственное, что я вижу, это то, что он заканчивает компиляцию (и выбор) между 4 различными версиями 'partion' (на основе типа integer' left' и 'right'). Я не могу поверить, что это поможет слишком много. – DavidW

3

В общем, если вы не заставляйте режим nopython у вас есть высокие шансы не получить никакого улучшения производительности. Приводя из the docs about nopython mode:

[nopython] режим производит самый высокий код производительности, но требует, чтобы коренные типы всех значений в функции можно сделать вывод, и что никакие новые объекты не выделяются

Поэтому ваш вызов np.ndarray вызывает режим объекта и, следовательно, замедляет работу кода. Попробуйте выделить массив работы из-за пределы функции, как:

def quick_sort(list_): 

    max_depth = 1000 
    temp_stack_ = np.array((max_depth, 2), dtype=np.int32) 

    _quick_sort(list_, temp_stack_) 

... 

@numba.jit(nopython=True) 
def _quick_sort(list_, temp_stack_): 
    ... 
Смежные вопросы