2013-06-03 3 views
2

Я пытаюсь изучить cython; однако, я должен делать что-то неправильно. Эта небольшая часть тестового кода работает примерно в 50 раз медленнее, чем моя векторизованная версия numpy. Может кто-нибудь, пожалуйста, скажите, почему мой cython медленнее моего питона? Благодарю.cython работает медленнее, чем numpy для вычисления расстояния

Код вычисляет расстояние между точкой в ​​R^3, loc и множеством точек в R^3, точках.

import numpy as np 
cimport numpy as np 
import cython 
cimport cython 

DTYPE = np.float64 
ctypedef np.float64_t DTYPE_t 
@cython.boundscheck(False) # turn of bounds-checking for entire function 
@cython.wraparound(False) 
@cython.nonecheck(False) 
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc): 
    cdef unsigned int i 
    cdef unsigned int L = points.shape[0] 
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L) 
    for i in xrange(0,L): 
     d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2] - loc[2])**2) 
    return d 

Это числовой код, с которым он сравнивается.

from numpy import * 
N = 1e6 
points = random.uniform(0,1,(N,3)) 
loc = random.uniform(0,1,(3)) 

def distMeasureNumpy(points,loc): 
    d = points - loc 
    d = sqrt(sum(d*d,axis=1)) 
    return d 

Версия numpy/python занимает около 44 мс, а версия cython занимает около 2 секунд. Я запускаю python 2.7 на Mac OSX. Я использую команду% timeit ipython для запуска двух функций.

+0

Я не вижу ничего явно неправильного с версией cython (и я удивлен, что это так медленно, как есть). Тем не менее, вы не будете бить правильное векторное выражение numpy с помощью cython. Cython лучше (и вообще неплохо) для операций, которые не могут быть векторизованы. Кроме того, вы можете немного ускорить свою версию numpy, используя 'd = np.hypot (* d.T)'. –

+0

Запустили ли вы 'cython -a your_code.pyx' и просмотрели' your_code.html'? Это удобный способ проверить код C, созданный cython, и узнать, сколько было переведено на C и сколько еще работает на уровне Python. –

ответ

5

Звонок np.sqrt, который является вызовом функции Python, убивает вашу производительность. Вы вычисляете квадратный корень скалярного значения с плавающей запятой, поэтому вы должны использовать функцию sqrt из библиотеки математики C. Ниже приведена измененная версия вашего кода:

import numpy as np 
cimport numpy as np 
import cython 
cimport cython 

from libc.math cimport sqrt 

DTYPE = np.float64 
ctypedef np.float64_t DTYPE_t 
@cython.boundscheck(False) # turn of bounds-checking for entire function 
@cython.wraparound(False) 
@cython.nonecheck(False) 
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, 
         np.ndarray[DTYPE_t, ndim=1] loc): 
    cdef unsigned int i 
    cdef unsigned int L = points.shape[0] 
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L) 
    for i in xrange(0,L): 
     d[i] = sqrt((points[i,0] - loc[0])**2 + 
        (points[i,1] - loc[1])**2 + 
        (points[i,2] - loc[2])**2) 
    return d 

Следующие демонстрируют улучшение производительности. Ваш исходный код находится в модуле check_speed_original, а модифицированная версия в check_speed:

In [11]: import check_speed_original 

In [12]: import check_speed 

Настройка тестовых данных:

In [13]: N = 10**6 

In [14]: points = random.uniform(0,1,(N,3)) 

In [15]: loc = random.uniform(0,1,(3,)) 

Оригинальная версия занимает 1,26 секунды на моем компьютере:

In [16]: %timeit check_speed_original.distMeasureCython(points, loc) 
1 loops, best of 3: 1.26 s per loop 

Модифицированная версия принимает 4.47 миллисекунды:

In [17]: %timeit check_speed.distMeasureCython(points, loc) 
100 loops, best of 3: 4.47 ms per loop 

В случае, если кто беспокоится о том, что результаты могут быть разными:

In [18]: d1 = check_speed.distMeasureCython(points, loc) 

In [19]: d2 = check_speed_original.distMeasureCython(points, loc) 

In [20]: np.all(d1 == d2) 
Out[20]: True 
+0

Это работает! Спасибо. Я получаю время выполнения, о котором вы упомянули выше. Спасибо за подсказку о html тоже. Это мой первый вопрос, поэтому я должен закрыть его сейчас, когда он разрешен? Спасибо за помощь WW. – plancherel

3

Как уже упоминалось, это numpy.sqrt вызов в коде. Тем не менее, я думаю, что не нужно использовать cdef extern, так как Cython предоставляет эти основные библиотеки C/C++. (см. the docs). Таким образом, вы могли бы просто воспроизвести это следующим образом:

from libc.math cimport sqrt 

Просто, чтобы избавиться от накладных расходов.

+0

Хорошая точка, спасибо. Я обновил свой ответ. –

Смежные вопросы