2017-01-06 4 views
0

Я пытаюсь обучить CNN (Sklearn Neural Network). У меня 4 изображения 128 x 128 пикселей. форма -> (4, 128, 128) Я читаю изображения, как -Обучение CNN с изображениями в нейронной сети sklearn

in1 = misc.imread('../data/Train_Data/train-1.jpg', mode='L', flatten=True)/255. in2 = misc.imread('../data/Train_Data/train-2.jpg', mode='L', flatten=True)/255. in3 = misc.imread('../data/Train_Data/train-3.jpg', mode='L', flatten=True)/255. in4 = misc.imread('../data/Train_Data/train-4.jpg', mode='L', flatten=True)/255. 

Затем NumPy массив создан, как это -

X_train = [in1,in2,in3,in4] 
X_train = np.array(X_train) 

То же, что для метки и тестового набора.

Тогда я готовлю мой CNN -

nn = Classifier(
    layers=[ 
     Convolution('Rectifier', channels=12, kernel_shape=(3, 3), border_mode='full'), 
     Convolution('Rectifier', channels=8, kernel_shape=(3, 3), border_mode='valid'), 
     Layer('Rectifier', units=64), 
     Layer('Softmax')], 
    learning_rate=0.002, 
    valid_size=0.2, 
    n_stable=10, 
    verbose=True) 


nn.fit(X_train, y_train) 

Он бросает ошибку как -

Traceback (most recent call last): File "/home/zaverichintan/PycharmProjects/WBC_identification/neural/trial.py", line 91, in nn.fit(X_train, y_train) File "/home/zaverichintan/miniconda2/lib/python2.7/site-packages/sknn/mlp.py", line 383, in fit ys = [lb.fit_transform(y[:,i]) for i, lb in enumerate(self.label_binarizers)] File "/home/zaverichintan/miniconda2/lib/python2.7/site-packages/sklearn/base.py", line 494, in fit_transform return self.fit(X, **fit_params).transform(X) File "/home/zaverichintan/miniconda2/lib/python2.7/site-packages/sklearn/preprocessing/label.py", line 335, in transform sparse_output=self.sparse_output) File "/home/zaverichintan/miniconda2/lib/python2.7/site-packages/sklearn/preprocessing/label.py", line 497, in label_binarize y = column_or_1d(y) File "/home/zaverichintan/miniconda2/lib/python2.7/site-packages/sklearn/utils/validation.py", line 563, in column_or_1d raise ValueError("bad input shape {0}".format(shape)) ValueError: bad input shape (4, 128)

ответ

-1

ваш in1, in2,....inN являются 2D массивы, которые 128x128 вы должны преобразовать их все в 1D массивы 16384. in1.shape должен печатать (16384,) и X_train.shape должен печатать (4,16384). Вы можете использовать массивы numpy и применить функцию [reshape] [1]. https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html

Смежные вопросы