2015-01-24 6 views
1

Я только начал изучать машинное обучение с использованием python, и я создаю нейронную сеть с помощью pybrain для обучения распознаванию цифр. Я, наконец, получил работу над программой, однако точность очень низкая (~ 30%) на тренировках и тестах. Я думаю, что что-то не так, но я не мог найти проблему. Я застрял здесь целыми днями. Может кто-нибудь мне помочь? Большое спасибо.очень низкая точность нейронной сети с pybrain

каждый учебный материал: вход 28 * 28 оттенков серого, перестроенный в 1 * 784, выход 1 * 10 массив с 0 или 1 с указанием местоположения 1, указывающего цифру, например [0,0,0,0 , 0,0,0,1,0,0] означает, что цифра 7.

Вот код:

from pybrain.datasets.supervised import SupervisedDataSet as SDS 
from pybrain.tools.shortcuts import buildNetwork 
from pybrain.supervised.trainers import BackpropTrainer 
from sklearn.metrics import accuracy_score 

#build datasets 
size_sample = 500 
#randomly choose 500 training data  
(sample_X,sample_y) = randomSample(training_X, training_y, size_sample) 
ds = SDS(28*28,10) 
ds.setField('input', sample_X) 
#sample_y = sample_y.reshape(size_sample,10) 
ds.setField('target', sample_y) 
#build network 
Num_Hidden_Layers = 10 
net = buildNetwork(ds.indim,Num_Hidden_Layers,ds.outdim,bias=True,outclass=SoftmaxLayer) 
#train data 
trainer = BackpropTrainer(net,ds) 

# predict using test data 
print "Making predictions..." 
predict_y = [] 
for i in range(test_X.shape[0]): 
    pred = net.activate(test_X[i, :]) 
    print pred 
    print pred.argmax() 
    predict_y = np.append(predict_y, pred.argmax()) 

ответ

0

это потому, что вы на самом деле не с помощью 10 скрытых слоев, вы используя только 10 скрытых единиц. (http://pybrain.org/docs/quickstart/network.html) Попробуйте вместо этого:

buildNetwork(ds.indim,25, 50, 25,ds.outdim,bias=True,outclass=SoftmaxLayer) 

Если это работает, чтобы дать лучшую точность вывода, то вы знаете, что вы на правильном пути, в какой момент вы просто должны играть с параметрами.

0

Попробуйте использовать резистивный тренажер Backpropagation (Rprop). Это должно улучшить вашу сеть.

Смежные вопросы