2014-01-04 4 views
15

У меня есть несколько значений, которые определены на той же нерегулярной сетке (x, y, z), что я хочу интерполировать на новую сетку (x1, y1, z1). то есть у меня есть f(x, y, z), g(x, y, z), h(x, y, z) и я хочу рассчитать f(x1, y1, z1), g(x1, y1, z1), h(x1, y1, z1).Ускорение scipy griddata для множественных интерполяций между двумя нерегулярными сетками

На данный момент я делаю это, используя scipy.interpolate.griddata, и он хорошо работает. Однако, поскольку мне приходится выполнять каждую интерполяцию по отдельности и есть много точек, она довольно медленная, с большим количеством дублирования в вычислении (например, поиск ближайших точек, настройка сетки и т. Д.).

Есть ли способ ускорить расчет и уменьшить дублированные вычисления? то есть что-то вдоль линий определения двух сеток, а затем изменение значений для интерполяции?

+0

метод интерполяции Что вы используете, то есть '' nearest', linear' ...? Кроме того, сколько у вас очков в нерегулярной сетке? – Jaime

+0

Я использую линейную интерполяцию (ближайшая не будет достаточно). Исходная сетка (x, y, z) состоит из 3,5 миллионов точек. Новая сетка (x1, y1, z1) состоит из примерно 300 000 точек. Линейная интерполяция занимает ~ 30 с на ноутбуке с процессором i7 со здоровым объемом ОЗУ. У меня есть 6 наборов значений для интерполяции, так что это является основным узким местом для меня. –

ответ

27

Есть несколько вещей происходит каждый раз, когда вы делаете вызов scipy.interpolate.griddata:

  1. Во-первых, призыв к sp.spatial.qhull.Dealunay сделан триангуляции нерегулярные сетки координат.
  2. Затем для каждой точки новой сетки выполняется поиск триангуляции, чтобы найти треугольник (на самом деле, в котором симплекс, который в вашем трехмерном случае будет находиться в тетраэдре).
  3. Вычислены барицентрические координаты каждой новой точки сетки относительно вершин охватывающего симплекса.
  4. Интерполированные значения вычисляются для этой точки сетки с использованием барицентрических координат и значений функции в вершинах прилагаемого симплекса.

Первые три шага одинаковы для всех интерполяций, так что, если вы могли бы хранить, для каждой новой точки сетки, индексы вершин вмещающего симплекс и веса для интерполяции, вы бы свести к минимуму количество вычислений много. К сожалению, это не так легко сделать непосредственно с функциональными возможностями, доступными, хотя это действительно возможно:

import scipy.interpolate as spint 
import scipy.spatial.qhull as qhull 
import itertools 

def interp_weights(xyz, uvw): 
    tri = qhull.Delaunay(xyz) 
    simplex = tri.find_simplex(uvw) 
    vertices = np.take(tri.simplices, simplex, axis=0) 
    temp = np.take(tri.transform, simplex, axis=0) 
    delta = uvw - temp[:, d] 
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta) 
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True))) 

def interpolate(values, vtx, wts): 
    return np.einsum('nj,nj->n', np.take(values, vtx), wts) 

Функция interp_weights делает расчеты для первых трех шагов, которые я перечислил выше.Тогда функция interpolate использует эти значения calcualted сделать шаг-очень быстро:

m, n, d = 3.5e4, 3e3, 3 
# make sure no new grid point is extrapolated 
bounding_cube = np.array(list(itertools.product([0, 1], repeat=d))) 
xyz = np.vstack((bounding_cube, 
       np.random.rand(m - len(bounding_cube), d))) 
f = np.random.rand(m) 
g = np.random.rand(m) 
uvw = np.random.rand(n, d) 

In [2]: vtx, wts = interp_weights(xyz, uvw) 

In [3]: np.allclose(interpolate(f, vtx, wts), spint.griddata(xyz, f, uvw)) 
Out[3]: True 

In [4]: %timeit spint.griddata(xyz, f, uvw) 
1 loops, best of 3: 2.81 s per loop 

In [5]: %timeit interp_weights(xyz, uvw) 
1 loops, best of 3: 2.79 s per loop 

In [6]: %timeit interpolate(f, vtx, wts) 
10000 loops, best of 3: 66.4 us per loop 

In [7]: %timeit interpolate(g, vtx, wts) 
10000 loops, best of 3: 67 us per loop 

Так первым, он делает то же самое, как griddata, это хорошо. Во-вторых, настройка интерполяции, то есть вычисление vtx и wts примерно совпадает с вызовом griddata. Но в-третьих, вы можете теперь интерполировать для разных значений в одной и той же сетке практически мгновенно.

Единственное, что не предусмотрено здесь griddata, это присвоение fill_value точкам, которые необходимо экстраполировать. Вы можете сделать это путем проверки точек, для которых по крайней мере один из весов является отрицательным, например:

def interpolate(values, vtx, wts, fill_value=np.nan): 
    ret = np.einsum('nj,nj->n', np.take(values, vtx), wts) 
    ret[np.any(wts < 0, axis=1)] = fill_value 
    return ret 
+2

Отлично, именно то, что я был после! Огромное спасибо. Было бы неплохо, если бы такая функциональность была включена в scipy для будущих версий griddata. –

+0

отлично работает для меня! Он также использует гораздо меньше памяти, чем scipy.itnerpolate.griddata, когда вы запускаете несколько раз на моей машине. – Matthias123

+0

Кроме того, 'griddata' вмещает отсутствующие значения/дыры в функции -' nan', которая не работает с этим решением? – FooBar

0

Вы можете попробовать использовать Pandas, так как он обеспечивает высокопроизводительные структуры данных.

Это правда, что метод интерполяции - это обертка интерполяции НО, возможно, с улучшенными структурами вы получаете лучшую скорость.

import pandas as pd; 
wp = pd.Panel(randn(2, 5, 4)); 
wp.interpolate(); 

interpolate() заполняет значения NaN в наборе данных панели с помощью different methods. Надеюсь, что это быстрее, чем Scipy.

Если он не работает, есть один способ повышения производительности (вместо того, чтобы использовать распараллеленную версию коды): использовать Cython и реализовать небольшую подпрограмму в C для использования внутри кода Python. Here у вас есть пример.

3

Большое спасибо Хайме за его решение (даже если я не очень понимаю, как это делается барицентрическое вычисление ...)

Здесь вы найдете пример адаптированы из его дел в 2D:

import scipy.interpolate as spint 
import scipy.spatial.qhull as qhull 
import numpy as np 

def interp_weights(xy, uv,d=2): 
    tri = qhull.Delaunay(xy) 
    simplex = tri.find_simplex(uv) 
    vertices = np.take(tri.simplices, simplex, axis=0) 
    temp = np.take(tri.transform, simplex, axis=0) 
    delta = uv - temp[:, d] 
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta) 
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True))) 

def interpolate(values, vtx, wts): 
    return np.einsum('nj,nj->n', np.take(values, vtx), wts) 

m, n = 101,201 
mi, ni = 1001,2001 

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m)) 
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi)) 

xy=np.zeros([X.shape[0]*X.shape[1],2]) 
xy[:,0]=Y.flatten() 
xy[:,1]=X.flatten() 
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2]) 
uv[:,0]=Yi.flatten() 
uv[:,1]=Xi.flatten() 

values=np.cos(2*X)*np.cos(2*Y) 

#Computed once and for all ! 
vtx, wts = interp_weights(xy, uv) 
valuesi=interpolate(values.flatten(), vtx, wts) 
valuesi=valuesi.reshape(Xi.shape[0],Xi.shape[1]) 
print "interpolation error: ",np.mean(valuesi-np.cos(2*Xi)*np.cos(2*Yi)) 
print "interpolation uncertainty: ",np.std(valuesi-np.cos(2*Xi)*np.cos(2*Yi)) 

можно прикладной преобразованию изображения, такие как отображение изображений с udge ускорить

Вы не можете использовать то же определение функции, что и новые координаты, будут меняться на каждой итерации, но вы можете вычислить триангуляцию Один раз для всех.

import scipy.interpolate as spint 
import scipy.spatial.qhull as qhull 
import numpy as np 
import time 

# Definition of the fast interpolation process. May be the Tirangulation process can be removed !! 
def interp_tri(xy): 
    tri = qhull.Delaunay(xy) 
    return tri 


def interpolate(values, tri,uv,d=2): 
    simplex = tri.find_simplex(uv) 
    vertices = np.take(tri.simplices, simplex, axis=0) 
    temp = np.take(tri.transform, simplex, axis=0) 
    delta = uv- temp[:, d] 
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta) 
    return np.einsum('nj,nj->n', np.take(values, vertices), np.hstack((bary, 1.0 - bary.sum(axis=1, keepdims=True)))) 

m, n = 101,201 
mi, ni = 101,201 

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m)) 
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi)) 

xy=np.zeros([X.shape[0]*X.shape[1],2]) 
xy[:,1]=Y.flatten() 
xy[:,0]=X.flatten() 
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2]) 
# creation of a displacement field 
uv[:,1]=0.5*Yi.flatten()+0.4 
uv[:,0]=1.5*Xi.flatten()-0.7 
values=np.zeros_like(X) 
values[50:70,90:150]=100. 

#Computed once and for all ! 
tri = interp_tri(xy) 
t0=time.time() 
for i in range(0,100): 
    values_interp_Qhull=interpolate(values.flatten(),tri,uv,2).reshape(Xi.shape[0],Xi.shape[1]) 
t_q=(time.time()-t0)/100 

t0=time.time() 
values_interp_griddata=spint.griddata(xy,values.flatten(),uv,fill_value=0).reshape(values.shape[0],values.shape[1]) 
t_g=time.time()-t0 

print "Speed-up:", t_g/t_q 
print "Mean error: ",(values_interp_Qhull-values_interp_griddata).mean() 
print "Standard deviation: ",(values_interp_Qhull-values_interp_griddata).std() 

На моем ноутбуке ускорение составляет от 20 до 40x!

Надежда, которая может помочь кому-то

+0

Функция 'interp_weights' здесь не работает,' delta = uv - temp [:, d] ', так как' d' выходит за границы на 'temp' – christopherlovell

Смежные вопросы