0

Я пытаюсь запустить код примера, предоставленный DL4J для многоуровневого автоопределения шумоподавления. Тем не менее, мои результаты были очень плохо:DL4J - Точность с автоопределением Denoising

Warning: class 0 was never predicted by the model. This class was excluded from the average precision 
Warning: class 1 was never predicted by the model. This class was excluded from the average precision 
Warning: class 2 was never predicted by the model. This class was excluded from the average precision 
Warning: class 3 was never predicted by the model. This class was excluded from the average precision 
Warning: class 4 was never predicted by the model. This class was excluded from the average precision 
Warning: class 5 was never predicted by the model. This class was excluded from the average precision 
Warning: class 6 was never predicted by the model. This class was excluded from the average precision 
Warning: class 7 was never predicted by the model. This class was excluded from the average precision 
Warning: class 9 was never predicted by the model. This class was excluded from the average precision 

==========================Scores======================================== 
Accuracy: 0.0944 
Precision: 0.0944 
Recall: 0.1 
F1 Score: 0.0971 

Я точно не знаю, что может быть причиной, что результаты будут так плохо. Я использую набор данных MNIST, а код для автокодирования с суммированием шумоподавления был предоставлен DL4J. Мой код ниже:

/** 
* Created by chris on 1/31/16. 
* Import statements above 
*/ 
public class stackedDenoisingAutoencoderExample { 

    private static Logger log = LoggerFactory.getLogger(stackedDenoisingAutoencoderExample.class); 

    public static void main(String[] args) throws Exception { 
     final int numRows = 28; 
     final int numColumns = 28; 
     int outputNum = 10; 
     int numSamples = 60000; 
     int batchSize = 100; 
     int iterations = 10; 
     int seed = 123; 
     int listenerFreq = batchSize/5; 

     log.info("Load data...."); 
     DataSetIterator iter = new MnistDataSetIterator(batchSize,numSamples,true); 

     log.info("Build model...."); 

     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 
      .seed(seed) 
    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) 
      .gradientNormalizationThreshold(1.0) 
      .iterations(iterations) 
      .momentum(0.5) 
      .momentumAfter(Collections.singletonMap(3, 0.9)) 
      .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) 
      .list(4) 
      .layer(0, new AutoEncoder.Builder().nIn(numRows * numColumns).nOut(500) 
        .weightInit(WeightInit.XAVIER).lossFunction(LossFunctions.LossFunction.RMSE_XENT) 
        .corruptionLevel(0.3) 
        .build()) 
      .layer(1, new AutoEncoder.Builder().nIn(500).nOut(250) 
        .weightInit(WeightInit.XAVIER).lossFunction(LossFunctions.LossFunction.RMSE_XENT) 
        .corruptionLevel(0.3) 

        .build()) 
      .layer(2, new AutoEncoder.Builder().nIn(250).nOut(200) 
        .weightInit(WeightInit.XAVIER).lossFunction(LossFunctions.LossFunction.RMSE_XENT) 
        .corruptionLevel(0.3) 
        .build()) 
      .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation("softmax") 
        .nIn(200).nOut(outputNum).build()) 
      .pretrain(true).backprop(false) 
      .build(); 

     MultiLayerNetwork model = new MultiLayerNetwork(conf); 
     model.init(); 
     model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq))); 

     log.info("Train model...."); 
     model.fit(iter); // achieves end to end pre-training 

     log.info("Evaluate model...."); 
     Evaluation eval = new Evaluation(outputNum); 

     DataSetIterator testIter = new MnistDataSetIterator(100,10000); 
     while(testIter.hasNext()) { 
      DataSet testMnist = testIter.next(); 
      INDArray predict2 = model.output(testMnist.getFeatureMatrix()); 
      eval.eval(testMnist.getLabels(), predict2); 
     } 

     log.info(eval.stats()); 
     log.info("****************Example finished********************"); 

    } 
} 

Спасибо!

+0

Задайте свой вопрос на GITTER чате на nd4j странице. – 404pio

ответ

Смежные вопросы