2015-10-25 3 views
2

Я адаптируют нейронную сеть для классификации изображений из учебника PyBrain по:PyBrain - из = fnn.activateOnDataset (GridData)

http://pybrain.org/docs/tutorial/fnn.html

Он кормятся в данных изображения в формате PNG форме, каждое изображение присваивается определенного класса.

Он хорошо работает до тех пор:

out = fnn.activateOnDataset(griddata) 

В сообщении он возвращает это: AssertionError: (3, 2)

уверен, что это проблема с тем, как я объявил GridData набор данных, но я не знаю точно, что?

В учебной версии все работает нормально.

Мой код:

from pybrain.datasets   import ClassificationDataSet 
from pybrain.utilities   import percentError 
from pybrain.tools.shortcuts  import buildNetwork 
from pybrain.supervised.trainers import BackpropTrainer 
from pybrain.structure.modules import SoftmaxLayer 

from pylab import ion, ioff, figure, draw, contourf, clf, show, hold, plot 
from scipy import diag, arange, meshgrid, where 
from numpy.random import multivariate_normal 
import cv2 

from pyroc import * 

#Creates cover type array based on color of pixels in roadmap 

coverType = [(255,225,104,3), #Road 
(254,254,253,0), #Other 
(254,254,254,3), #Road 
(253,254,253,0),#Other 
(253,225,158,0),#Other 
] # have other cover type but sample amount included 

coverTypes = len(coverType) 

print coverTypes #to count 

#Creates dataset 

alldata = ClassificationDataSet(3,1,nb_classes=10) 

"""Classifies Roadmap Sub-Images by type and loads matching Satellite Sub-Image 
with classification into dataset.""" 

for eachFile in glob.glob('Roadmap Sub-Images/*'): 
    image = Image.open(eachFile) 
    fileName = eachFile 
    newFileName = fileName.replace("Roadmap Sub-Images", "Satellite Sub-Images")  

    colors = image.convert('RGB').getcolors() #Finds all colors in image and their frequency 
    colors.sort() #Sorts colors in image by their frequency 
    colorMostFrequent = colors[-1][1] #Finds last element in array, the most frequent color 

    for eachColor in range(1,151): #151 number of element in CoverType array 
     if colorMostFrequent[0] == coverType[eachColor][0] and colorMostFrequent[1] == coverType[eachColor][1] and colorMostFrequent[2] == coverType[eachColor][2]: 

     print newFileName #Check new route 
     image = cv2.imread(newFileName)  
     meanImage = cv2.mean(image) #Take average color   
     meanImageRGB = meanImage[:3] #Converts to RGB scale, excluding "alpha"   
     print meanImageRGB #Check RGB average colors   
     alldata.addSample(meanImageRGB,coverType[eachColor][3]) 



tstdata, trndata = alldata.splitWithProportion(0.25) 

trndata._convertToOneOfMany() 
tstdata._convertToOneOfMany() 

fnn = buildNetwork(trndata.indim, 5, trndata.outdim, outclass=SoftmaxLayer) 

trainer = BackpropTrainer(fnn, dataset=trndata, momentum=0.1, verbose=True, weightdecay=0.01) 

ticks = arange(-3.,6.,0.2) 

X, Y = meshgrid(ticks, ticks) 

#I think every thing is good to here problem with the griddata dataset I think? 

# need column vectors in dataset, not arrays 

griddata = ClassificationDataSet(2,1, nb_classes=4) 

for i in xrange(X.size): 
    griddata.addSample([X.ravel()[i],Y.ravel()[i]], [0]) 

griddata._convertToOneOfMany() # this is still needed to make the fnn feel comfy 

for i in range(20): 
    trainer.trainEpochs(1) 

    trnresult = percentError(trainer.testOnClassData(), 
          trndata['class']) 
    tstresult = percentError(trainer.testOnClassData(
     dataset=tstdata), tstdata['class']) 

    print "epoch: %4d" % trainer.totalepochs, \ 
      " train error: %5.2f%%" % trnresult, \ 
      " test error: %5.2f%%" % tstresult 


    out = fnn.activateOnDataset(alldata) 

    out = out.argmax(axis=1) # the highest output activation gives the class 
    out = out.reshape(X.shape) 

    figure(1) 
    ioff() # interactive graphics off 
    clf() # clear the plot 
    hold(True) # overplot on 
    for c in [0,1,2]: 
     here, _ = where(tstdata['class']==c) 
     plot(tstdata['input'][here,0],tstdata['input'][here,1],'o') 
    if out.max()!=out.min(): # safety check against flat field 
     contourf(X, Y, out) # plot the contour 
    ion() # interactive graphics on 
    draw() # update the plot 

ioff() 
show() 

ответ

0

Я считаю, что он должен делать с размерами ваших исходных данных набор не совместив с размерами вашего GridData.

alldata = ClassificationDataSet(3,1,nb_classes=10) griddata = ClassificationDataSet(2,1, nb_classes=4)

Они оба должны быть 3, 1. Однако, когда я изменить это мой код не на более позднем этапе, поэтому я тоже любопытно об этом.

Смежные вопросы