2015-10-07 3 views
3

Я хотел бы создать генератор, который возвращает массив на лету. Например:Генератор Python с массивом numpy

import numpy as np 
def my_gen(): 
    c = np.ones(5) 
    j = 0 
    t = 10 
    while j < t: 
     c[0] = j 
     yield c 
     j += 1 

С простой цикл:

for g in my_gen(): 
    print (g) 

Я получил то, что я хочу. Но с list(my_gen()) у меня есть список, который содержит всегда одно и то же.

я выкопал немного глубже, и я нахожу, когда я yield c.tolist() вместо yield c, все прошло нормально ...

Я просто не могу объяснить себе, почему это странное поведение ...

ответ

6

Это потому, что c всегда указывается на одну и ту же ссылку на массив numpy, вы просто меняете элемент внутри c в функции генератора.

При печати просто печатается массив c в этот конкретный момент, поэтому вы правильно печатаете значения.

Но когда вы используете list(my_gen()), вы продолжаете добавлять в список ту же ссылку на c массив numpy, и, следовательно, любые изменения в этом массиве numpy также отражаются в ранее добавленных элементах в списке.

Он работает для вас, когда вы делаете yield c.tolist(), потому что создает новый список из Numpy массива, поэтому вы продолжаете добавлять новые объекты списков к list и, следовательно, изменениям в будущем к c не отражают в ранее добавленных списках ,

+0

Отлично! Спасибо за ваш ответ :) – XXXXXL

0

Хорошо, я думаю, потому что в этом генераторе, так как я возвращаю ту же ссылку, генератор дает всегда одно и то же. Если я yield np.array(c), это сработает ...

3

Альтернативный генератор возвращает копию списка. Я сохраняю np.ones() как удобный способ создания чисел, но сразу преобразую его в список (только один раз) (array.tolist() относительно дорого).

Я даю c[:], чтобы избежать проблемы с текущей версией.

def gen_c(): 
     c = np.ones(5,dtype=int).tolist() 
     j = 0 
     t = 10 
     while j < t: 
       c[0] = j 
       yield c[:] 
       j += 1 


In [54]: list(gen_c()) 
Out[54]: 
[[0, 1, 1, 1, 1], 
[1, 1, 1, 1, 1], 
[2, 1, 1, 1, 1], 
[3, 1, 1, 1, 1], 
[4, 1, 1, 1, 1], 
[5, 1, 1, 1, 1], 
[6, 1, 1, 1, 1], 
[7, 1, 1, 1, 1], 
[8, 1, 1, 1, 1], 
[9, 1, 1, 1, 1]] 
In [55]: np.array(list(gen_c())) 
Out[55]: 
array([[0, 1, 1, 1, 1], 
     [1, 1, 1, 1, 1], 
     [2, 1, 1, 1, 1], 
     [3, 1, 1, 1, 1], 
     [4, 1, 1, 1, 1], 
     [5, 1, 1, 1, 1], 
     [6, 1, 1, 1, 1], 
     [7, 1, 1, 1, 1], 
     [8, 1, 1, 1, 1], 
     [9, 1, 1, 1, 1]]) 
+0

Спасибо вам большое! Это значительно улучшает производительность !!! Большой! – XXXXXL