2016-12-25 6 views
1

Я пытаюсь создать свой собственный итератор данных для использования с mxnet. Когда я запускаю его, я получаю ошибку:Ошибка при использовании моего собственного итератора данных python в mxnet

Traceback (most recent call last): 
File "train.py", line 24, in <module> 
batch_end_callback = mx.callback.Speedometer(batch_size, 1) # output progress for each 200 data batches 
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/model.py", line 811, in fit 
sym_gen=self.sym_gen) 
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/model.py", line 236, in _train_multi_device 
executor_manager.load_data_batch(data_batch) 
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 410, in load_data_batch 
self.curr_execgrp.load_data_batch(data_batch) 
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 257, in load_data_batch 
_load_data(data_batch, self.data_arrays) 
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 93, in _load_data 
_load_general(batch.data, targets) 
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 89, in _load_general 
d_src[slice_idx].copyto(d_dst) 
AttributeError: 'numpy.ndarray' object has no attribute 'copy' 

Я предполагаю, что это имеет какое-то отношение к тому, как я возвращаю данные. См. Мой код итератора данных ниже:

from mxnet.io import DataIter, DataDesc 
import csv 
from random import shuffle 
import numpy as np 
from cv2 import imread, resize 

class MyData(DataIter): 
    def __init__(self, root_dir, flist_name, batch_size, size=(256,256), shuffle=True): 
     super(MyData, self).__init__() 
     self.batch_size = batch_size 
     self.root_dir = root_dir 
     self.flist_name = flist_name 
     self.size = size 
     self.shuffle = shuffle 

     self.data = [] 
     with open(flist_name, 'rb') as csvfile: 
      csvreader = csv.reader(csvfile) 
      for row in csvreader: 
       self.data.append(row) 
     self.num_data = len(self.data) 
     self.provide_data = [DataDesc('data', (self.batch_size, 6, self.size[0], self.size[1]), np.float32)] 
     self.provide_label = [DataDesc('Pa_label', (self.batch_size, 1), np.float32)] 
     self.reset() 

    def reset(self): 
     """Reset the iterator. """ 
     self.cursor = 0 
     if self.shuffle: 
      shuffle(self.data) 

    def iter_next(self): 
     """Iterate to next batch. 
     Returns 
     ------- 
     has_next : boolean 
      Whether the move is successful. 
     """ 
     self.cursor += self.batch_size 
     success = self.cursor < self.num_data 
     return success 

    def getdata(self): 
     """Get data of current batch. 
     Returns 
     ------- 
     data : NDArray 
      The data of current batch. 
     """ 
     datalist = self.data[self.cursor:self.cursor+self.batch_size] 
     ret = np.ndarray(shape=(0,6,self.size[0],self.size[1]), dtype=np.float32) 
     for data_row in datalist: 
      img1 = resize(imread(data_row[0]), self.size) 
      img2 = resize(imread(data_row[1]), self.size) 
      img1 = np.rollaxis(img1, 2) 
      img2 = np.rollaxis(img2, 2) 
      img = np.concatenate((img1, img2), 0) 
      imge = np.expand_dims(img,0) 
      ret = np.append(ret, imge, 0) 

     print ret.shape 
     pad = self.batch_size - ret.shape[0] 
     if pad > 0: 
      ret = np.append(ret, np.zeros((pad, 6, self.size[0], self.size[1])), 0) 
     return ret 

    def getlabel(self): 
     """Get label of current batch. 
     Returns 
     ------- 
     label : NDArray 
      The label of current batch. 
     """ 
     datalist = self.data[self.cursor:self.cursor+self.batch_size] 
     ret = np.ndarray(shape=(0,1,1,1), dtype=np.float32) 
     for data_row in datalist: 
      label = np.ndarray(shape=(1,1,1,1), dtype=np.float32) 
      label[0,0,0,0] = float(data_row[2])/float(data_row[5]) 
      np.append(ret, label, 0) 

     pad = self.batch_size - ret.shape[0] 
     np.append(ret, np.zeros((pad, 1,1,1)), 0) 
     return ret 

    def getindex(self): 
     """Get index of the current batch. 
     Returns 
     ------- 
     index : numpy.array 
      The index of current batch 
     """ 
     return self.cursor 

    def getpad(self): 
     """Get the number of padding examples in current batch. 
     Returns 
     ------- 
     pad : int 
      Number of padding examples in current batch 
     """ 
     if self.cursor + self.batch_size > self.num_data: 
      return self.cursor + self.batch_size - self.num_data 
     else: 
      return 0 

ответ

1

numpy.ndarray не имеет copyto способ. Попробуйте использовать mx.ndarray.

+0

[Он делает.] (Https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.copy.html) –

Смежные вопросы