2015-08-13 6 views
3

Я пытаюсь построить очень простой многослойный персептрон (MLP) в keras:Python: keras ошибка несоответствия формы

model = Sequential() 
model.add(Dense(16, 8, init='uniform', activation='tanh')) 
model.add(Dropout(0.5)) 
model.add(Dense(8, 2, init='uniform', activation='tanh')) 

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) 
model.compile(loss='mean_squared_error', optimizer=sgd) 

model.fit(X_train, y_train, nb_epoch=1000, batch_size=50) 
score = model.evaluate(X_test, y_test, batch_size=50) 

Моя форма обучения данные: X_train.shape дает (34180, 16)

лейблы принадлежат бинарной классу с формой: y_train.shape дает (34180,)

Так что мой keras код должен производить сеть с цепью: 16x8 => 8x2

который производит ошибку несоответствия формы:

ValueError: Input dimension mis-match. (input[0].shape[1] = 2, input[1].shape[1] = 1) 

Apply node that caused the error: Elemwise{sub,no_inplace}(Elemwise{Composite{tanh((i0 + i1))}}[(0, 0)].0, <TensorType(float64, matrix)>) 
Inputs types: [TensorType(float64, matrix), TensorType(float64, matrix)] 
Inputs shapes: [(50, 2), (50, 1)] 
Inputs strides: [(16, 8), (8, 8)] 

В Epoch 0 в строке model.fit(X_train, y_train, nb_epoch=1000, batch_size=50). Я наблюдаю за чем-то очевидным в Keras?

EDIT: Я прошел через вопрос here но не решает мою проблему

ответ

10

У меня была такая же проблема, и затем нашел эту нить;

https://github.com/fchollet/keras/issues/68

Оказывается вам сформулировать окончательный выходной слой 2 или для любого количества категорий метки должна быть категорическим типа, где в основном это двоичный вектор для каждого наблюдения, например, в 3 классе выходной сигнал [0,2,1,0,1,0] становится [[1,0,0], [0,0,1], [0,1,0], [1,0,0], [ 0,1,0], [1,0,0]].

Функция np_utils.to_categorical решила это для меня;

from keras.utils import np_utils, generic_utils 

y_train, y_test = [np_utils.to_categorical(x) for x in (y_train, y_test)] 
+0

Другой вариант, который поможет вам «Unmap» ваш один докрасна векторов, является использование 'sklearn.preprocessing.LabelBinarizer'. http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelBinarizer.html – hlin117

Смежные вопросы