2016-07-12 4 views
1

Я в настоящее время использую matplotlib.pyplot визуализировать некоторые 2D данных:pyplot.imshow прямоугольники

from matplotlib import pyplot as plt 
import numpy as np 
A=np.matrix("1 2 1;3 0 3;1 2 0") # 3x3 matrix with 2D data 
plt.imshow(A, interpolation="nearest") # draws one square per matrix entry 
plt.show() 

Теперь я переехал данные из квадратов в прямоугольники, то есть у меня есть два дополнительные массивы, например:

grid_x = np.array([0.0, 1.0, 4.0, 5.0]) # points on the x-axis 
grid_x = np.array([0.0, 2.5, 4.0, 5.0]) # points on the y-axis 

теперь я хочу сетку с прямоугольниками:

  • верхний левый угол: (grid_x[i], grid_y[j])
  • правый нижний угол: (grid_x[i+1], grid_y[j+1])
  • данные (цвет): A[i,j]

Что такое простой способ построить данные на новой сетке? imshow, по-видимому, можно использовать, я посмотрел на pcolormesh, но его сбивает с сеткой как 2D-массив, используя две матрицы, такие как np.mgrid[0:5:0.5,0:5:0.5] для обычной сетки и построение чего-то похожего на нерегулярное.

Что такое простой способ визуализации прямоугольников?

+0

Что такое значение 'grid_x'? Второй «grid_x» должен быть «grid_y»? –

+0

массивы '' grid'' являются шагами по осям x/y. Таким образом, точка '' A [i, j] '' соответствует позиции '' grid_x [i], grid_y [j] '' на сетке (или прямоугольник '' grid_x [i: i + 1], grid_y [ J: J + 1], ''). – allo

+0

Если шаги сетки определяют точки, у вас есть '' n + 1'' записи в '' grid_x'' для прямоугольников '' n'' слева направо. '' imshow'' отображает только один квадрат для каждой точки, поэтому вы можете видеть записи '' A'' либо как верхние левые точки квадратов, либо как средние точки, если они равноудалены, Не важно. – allo

ответ

1
import matplotlib.pyplot as plt 
from matplotlib.patches import Rectangle 
import matplotlib.cm as cm 
from matplotlib.collections import PatchCollection 
import numpy as np 

A = np.matrix("1 2 1;3 0 3;1 2 0;4 1 2") # 4x3 matrix with 2D data 

grid_x0 = np.array([0.0, 1.0, 4.0, 6.7]) 
grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4]) 

grid_x1, grid_y1 = np.meshgrid(grid_x0, grid_y0) 
grid_x2 = grid_x1[:-1, :-1].flat 
grid_y2 = grid_y1[:-1, :-1].flat 
widths = np.tile(np.diff(grid_x0)[np.newaxis], (len(grid_y0)-1, 1)).flat 
heights = np.tile(np.diff(grid_y0)[np.newaxis].T, (1, len(grid_x0)-1)).flat 

fig = plt.figure() 
ax = fig.add_subplot(111) 
ptchs = [] 
for x0, y0, w, h in zip(grid_x2, grid_y2, widths, heights): 
    ptchs.append(Rectangle(
     (x0, y0), w, h, 
    )) 
p = PatchCollection(ptchs, cmap=cm.viridis, alpha=0.4) 
p.set_array(np.ravel(A)) 
ax.add_collection(p) 
plt.xlim([0, 8]) 
plt.ylim([0, 13]) 
plt.show() 

enter image description here

Вот еще один способ, с помощью изображения и R-дерево и imshow с colorbar, вам необходимо изменить x-ticks и y-ticks (Есть много SO Q & А о том, как это сделать).

from rtree import index 
import matplotlib.pyplot as plt 
import numpy as np 

eps = 1e-3 

A = np.matrix("1 2 1;3 0 3;1 2 0;4 1 2") # 4x3 matrix with 2D data 
grid_x0 = np.array([0.0, 1.0, 4.0, 6.7]) 
grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4]) 

grid_x1, grid_y1 = np.meshgrid(grid_x0, grid_y0) 
grid_x2 = grid_x1[:-1, :-1].flat 
grid_y2 = grid_y1[:-1, :-1].flat 
grid_x3 = grid_x1[1:, 1:].flat 
grid_y3 = grid_y1[1:, 1:].flat 

fig = plt.figure() 

rows = 100 
cols = 200 
im = np.zeros((rows, cols), dtype=np.int8) 
grid_j = np.linspace(grid_x0[0], grid_x0[-1], cols) 
grid_i = np.linspace(grid_y0[0], grid_y0[-1], rows) 
j, i = np.meshgrid(grid_j, grid_i) 

i = i.flat 
j = j.flat 

idx = index.Index() 

for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)): 
    idx.insert(m, (x0, y0, x1, y1)) 


for k, (i0, j0) in enumerate(zip(i, j)): 
    ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps))) 

    im[np.unravel_index(k, im.shape)] = A[np.unravel_index(ind, A.shape)] 
plt.imshow(im) 
plt.colorbar() 
plt.show() 

enter image description here

+0

Вы просто рисуете два прямоугольника, а сетка 3x3 должна иметь 9 (со значениями из матрицы '' A'' в моем примере). вы никогда не рисуете недиагональные точки, используя '' zip'' на двух осях. – allo

+0

см. Мой отредактированный ответ –

+0

Похоже, что я хотел. Я просто сделал что-то похожее, опубликую некоторые готовые к использованию функции в одно мгновение. – allo

1

Вот многоразовая функция, основанная на коде от @ Офир-Карми:

import matplotlib.pyplot as plt 
from matplotlib.patches import Rectangle 
from matplotlib.collections import PatchCollection 
import itertools 
import numpy as np 

def gridshow(grid_x, grid_y, data, **kwargs): 
    vmin = kwargs.pop("vmin", None) 
    vmax = kwargs.pop("vmax", None) 
    data = np.array(data).reshape(-1) 
    # there should be data for (n-1)x(m-1) cells 
    assert (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1) == data.shape[0], "Wrong number of data points. grid_x=%s, grid_y=%s, data=%s" % (grid_x.shape, grid_y.shape, data.shape) 
    ptchs = [] 
    for j, i in itertools.product(xrange(len(grid_y) - 1), xrange(len(grid_x) - 1)): 
     xy = grid_x[i], grid_y[j] 
     width = grid_x[i+1] - grid_x[i] 
     height = grid_y[j+1] - grid_y[j] 
     ptchs.append(Rectangle(xy=xy, width=width, height=height, rasterized=True, linewidth=0, linestyle="None")) 
    p = PatchCollection(ptchs, linewidth=0, **kwargs) 
    p.set_array(np.array(data)) 
    p.set_clim(vmin, vmax) 
    ax = plt.gca() 
    ax.set_aspect("equal") 
    plt.xlim([grid_x[0], grid_x[-1]]) 
    plt.ylim([grid_y[0], grid_y[-1]]) 
    ret = ax.add_collection(p) 
    plt.sci(ret) 
    return ret 

if __name__ == "__main__": 
    grid_x = np.linspace(0, 20, 21) + np.random.randn(21)/5.0 
    grid_y = np.linspace(0, 18, 19) + np.random.randn(19)/5.0 
    grid_x = np.round(grid_x, 2) 
    grid_y = np.round(grid_y, 2) 
    data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1)) 

    fig = plt.figure() 
    ax = fig.add_subplot(111) 
    gridshow(grid_x, grid_y, data, alpha=1.0) 
    plt.savefig("test.png") 

example image from the <code>gridshow</code> function

Я не совсем уверен, что производительность для больших сетей и если **kwargs следует наложить на PatchCollection. И между некоторыми прямоугольниками, кажется, 1px пробелы, вероятно, из-за плохого округления. Возможно, dx, width, height необходимо согласовать floor/ceil со следующим полным пикселом.

Другое решение с использованием rtree и imshow:

import matplotlib.pyplot as plt 
import numpy as np 
from rtree import index 

def gridshow(grid_x, grid_y, data, rows=200, cols=200, eps=1e-3, **kwargs): 
    grid_x1, grid_y1 = np.meshgrid(grid_x, grid_y) 
    grid_x2 = grid_x1[:-1, :-1].flat 
    grid_y2 = grid_y1[:-1, :-1].flat 
    grid_x3 = grid_x1[1:, 1:].flat 
    grid_y3 = grid_y1[1:, 1:].flat 

    grid_j = np.linspace(grid_x[0], grid_x[-1], cols) 
    grid_i = np.linspace(grid_y[0], grid_y[-1], rows) 
    j, i = np.meshgrid(grid_j, grid_i) 
    i = i.flat 
    j = j.flat 

    im = np.empty((rows, cols), dtype=np.float64) 

    idx = index.Index() 
    for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)): 
     idx.insert(m, (x0, y0, x1, y1)) 
    for k, (i0, j0) in enumerate(zip(i, j)): 
     ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps))) 
     im[np.unravel_index(k, im.shape)] = data[np.unravel_index(ind, data.shape)] 

    fig = plt.gca() 
    return plt.imshow(im, interpolation="nearest") 


if __name__ == "__main__": 
    grid_x = np.linspace(0, 200, 201) + np.random.randn(201)/5.0 
    grid_y = np.linspace(0, 108, 109) + np.random.randn(109)/5.0 
    grid_x = np.round(grid_x, 2) 
    grid_y = np.round(grid_y, 2) 
    data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1)) 

    fig = plt.figure() 
    ax = fig.add_subplot(111) 
    gridshow(grid_x, grid_y, data, alpha=1.0) 
    plt.savefig("test.png") 

image with imshow and rtree

+0

Вы всегда можете использовать изображение с достаточным разрешением и R-tree для изменения его значений, а затем 'imshow'. –

+0

Все еще работая над этим, индексирование не кажется правильным. Вы знаете, как включить его в функцию типа '' imshow''? Например, цветная панель не работает '' RuntimeError: вы должны сначала определить изображение, например, с помощью imshow''. Также, если это возможно, не нужно определять собственный подзаголовок для получения осей. – allo

+0

Теперь метод 'imshow'' работает для меня. Принимает 3-6 секунд для разрешения данных 200x100 здесь, становится намного хуже для больших разрешений. Однако, вероятно, этого было бы достаточно для моих целей. – allo