Вот многоразовая функция, основанная на коде от @ Офир-Карми:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import itertools
import numpy as np
def gridshow(grid_x, grid_y, data, **kwargs):
vmin = kwargs.pop("vmin", None)
vmax = kwargs.pop("vmax", None)
data = np.array(data).reshape(-1)
# there should be data for (n-1)x(m-1) cells
assert (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1) == data.shape[0], "Wrong number of data points. grid_x=%s, grid_y=%s, data=%s" % (grid_x.shape, grid_y.shape, data.shape)
ptchs = []
for j, i in itertools.product(xrange(len(grid_y) - 1), xrange(len(grid_x) - 1)):
xy = grid_x[i], grid_y[j]
width = grid_x[i+1] - grid_x[i]
height = grid_y[j+1] - grid_y[j]
ptchs.append(Rectangle(xy=xy, width=width, height=height, rasterized=True, linewidth=0, linestyle="None"))
p = PatchCollection(ptchs, linewidth=0, **kwargs)
p.set_array(np.array(data))
p.set_clim(vmin, vmax)
ax = plt.gca()
ax.set_aspect("equal")
plt.xlim([grid_x[0], grid_x[-1]])
plt.ylim([grid_y[0], grid_y[-1]])
ret = ax.add_collection(p)
plt.sci(ret)
return ret
if __name__ == "__main__":
grid_x = np.linspace(0, 20, 21) + np.random.randn(21)/5.0
grid_y = np.linspace(0, 18, 19) + np.random.randn(19)/5.0
grid_x = np.round(grid_x, 2)
grid_y = np.round(grid_y, 2)
data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1))
fig = plt.figure()
ax = fig.add_subplot(111)
gridshow(grid_x, grid_y, data, alpha=1.0)
plt.savefig("test.png")
Я не совсем уверен, что производительность для больших сетей и если **kwargs
следует наложить на PatchCollection
. И между некоторыми прямоугольниками, кажется, 1px пробелы, вероятно, из-за плохого округления. Возможно, dx, width, height
необходимо согласовать floor
/ceil
со следующим полным пикселом.
Другое решение с использованием rtree
и imshow
:
import matplotlib.pyplot as plt
import numpy as np
from rtree import index
def gridshow(grid_x, grid_y, data, rows=200, cols=200, eps=1e-3, **kwargs):
grid_x1, grid_y1 = np.meshgrid(grid_x, grid_y)
grid_x2 = grid_x1[:-1, :-1].flat
grid_y2 = grid_y1[:-1, :-1].flat
grid_x3 = grid_x1[1:, 1:].flat
grid_y3 = grid_y1[1:, 1:].flat
grid_j = np.linspace(grid_x[0], grid_x[-1], cols)
grid_i = np.linspace(grid_y[0], grid_y[-1], rows)
j, i = np.meshgrid(grid_j, grid_i)
i = i.flat
j = j.flat
im = np.empty((rows, cols), dtype=np.float64)
idx = index.Index()
for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)):
idx.insert(m, (x0, y0, x1, y1))
for k, (i0, j0) in enumerate(zip(i, j)):
ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps)))
im[np.unravel_index(k, im.shape)] = data[np.unravel_index(ind, data.shape)]
fig = plt.gca()
return plt.imshow(im, interpolation="nearest")
if __name__ == "__main__":
grid_x = np.linspace(0, 200, 201) + np.random.randn(201)/5.0
grid_y = np.linspace(0, 108, 109) + np.random.randn(109)/5.0
grid_x = np.round(grid_x, 2)
grid_y = np.round(grid_y, 2)
data = np.random.randn((grid_x.shape[0] -1) * (grid_y.shape[0] -1))
fig = plt.figure()
ax = fig.add_subplot(111)
gridshow(grid_x, grid_y, data, alpha=1.0)
plt.savefig("test.png")
Что такое значение 'grid_x'? Второй «grid_x» должен быть «grid_y»? –
массивы '' grid'' являются шагами по осям x/y. Таким образом, точка '' A [i, j] '' соответствует позиции '' grid_x [i], grid_y [j] '' на сетке (или прямоугольник '' grid_x [i: i + 1], grid_y [ J: J + 1], ''). – allo
Если шаги сетки определяют точки, у вас есть '' n + 1'' записи в '' grid_x'' для прямоугольников '' n'' слева направо. '' imshow'' отображает только один квадрат для каждой точки, поэтому вы можете видеть записи '' A'' либо как верхние левые точки квадратов, либо как средние точки, если они равноудалены, Не важно. – allo