2016-08-23 2 views
1

Есть несколько хороших примеров, как преобразовать массив NumPy в массив Java, но не наоборот - как преобразовать данные из объекта Java обратно в массив NumPy. У меня есть сценарий Python, как это:Быстрое преобразование массива Java в массив NumPy (Py4J)

from py4j.java_gateway import JavaGateway 
    gateway = JavaGateway()   # connect to the JVM 
    my_java = gateway.jvm.JavaClass(); # my Java object 
    .... 
    int_array=my_java.doSomething(int_array); # do something 

    my_numpy=np.zeros((size_y,size_x)); 
    for jj in range(size_y): 
     for ii in range(size_x): 
      my_numpy[jj,ii]=int_array[jj][ii]; 

my_numpy является массив Numpy, int_array это массив Java целых чисел - int[ ][ ] вид массива. Инициализирован в сценарии Python как:

int_class=gateway.jvm.int  # make int class 
    double_class=gateway.jvm.double # make double class 

    int_array = gateway.new_array(int_class,size_y,size_x) 
    double_array = gateway.new_array(double_class,size_y,size_x) 

Хотя, это работает, как это, это не самый быстрый способ и работает довольно медленно - в течение ~ массива 1000x1000, преобразование заняло более 5 минут.

Есть ли способ, как сделать это в разумные сроки?

Если я пытаюсь:

test=np.array(int_array) 

я получаю:

ValueError: invalid __array_struct__ 

ответ

0

У меня была аналогичная проблема, просто пытается построить спектральные векторы (Java массивы) я получил со стороны Java с помощью py4j. Здесь преобразование из массива Java в список Python осуществляется с помощью функции list(). Это может дать некоторые подсказки относительно того, как использовать его для заполнения Numpy массивы ...

vectors = space.getVectorsAsArray(); # Java array (MxN) 
wvl = space.getAverageWavelengths(); # Java array (N) 

wavelengths = list(wvl) 

import matplotlib.pyplot as mp 
mp.hold 
for i, dataset in enumerate(vectors): 
    mp.plot(wavelengths, list(dataset)) 

ли это быстрее, чем вложенная для петель, которые вы использовали, я не могу сказать, но это также делает трюк:

import numpy 
from numpy import array 
x = array(wavelengths) 
v = array(list(vectors)) 

mp.plot(x, numpy.rot90(v)) 
Смежные вопросы