2016-12-04 3 views
0

Я пытаюсь написать алгоритм для обучения персептрона, но, кажется, значения превышают максимальное значение для double. Я пытаюсь выяснить со вчерашнего дня, но не могу.Алгоритм данных обучения не работает. (Perceptron)

Значение весов, как представляется, превышает, а также значение переменной мощности.

Текстовый файл, который читается, имеет вид: Input variables and the output

/** 
* Created by yafael on 12/3/16. 
*/ 
import java.io.*; 
import java.util.*; 

public class Perceptron { 

static double[] weights; 
static ArrayList<Integer> inputValues; 
static ArrayList<Integer> outputValues; 
static int[] inpArray; 
static int[] outArray; 

public static int numberOfInputValues(String filePath)throws IOException 
{ 
    Scanner valueScanner = new Scanner(new File(filePath)); 
    int num = valueScanner.nextInt(); 
    return num; 
} 

public static void inputs(String filePath)throws IOException 
{ 
    inputValues = new ArrayList<Integer>(); 
    outputValues = new ArrayList<Integer>(); 
    Scanner valueScanner = new Scanner(new File(filePath)); 
    int num = valueScanner.nextInt(); 

    while (valueScanner.hasNext()) 
    { 
     String temp = valueScanner.next(); 
     String[] values = temp.split(","); 
     for(int i = 0; i < values.length; i++) 
     { 
      if(i+1 != values.length) 
      { 
       inputValues.add(Integer.parseInt(values[i])); 
      }else 
      { 
       outputValues.add(Integer.parseInt(values[i])); 
      } 
     } 

    } 
    valueScanner.close(); 
} 

public static void trainData(int[] inp, int[] out, int num,int epoch) 
{ 
    weights = new double[num]; 
    Random r = new Random(); 
    int i,ep; 
    int error = 0; 
    /* 
    * Initialize weights 
    */ 
    for(i = 0; i < num; i++) 
    { 
     weights[i] = r.nextDouble(); 
    } 

    for(ep = 1; ep<= epoch; ep++) 
    { 
     double totalError = 0; 

     for(i = 0; i < inp.length/(num); i++) 
     { 
      double output = calculateOutput(inp, i, weights); 
      System.out.println("Output " + (i + 1) + ": " + output); 
      //System.out.println("Output: " + output); 
      if(output > 0) 
      { 
       error = out[i] - 1; 
      }else 
      { 
       error = out[i] - 0; 
      } 

      for(int temp = 0; temp < num; temp++) 
      { 
       double epCalc = (1000/(double)(1000+ep)); 
       weights[temp] += epCalc*error*inp[((i*weights.length)+temp)]; 
       //System.out.println("Epoch calculation: " + epCalc); 
       //System.out.println("Output: " + output); 
       //System.out.println("error: " + error); 
       //System.out.println("input " + ((i*weights.length)+temp) + ": " + inp[(i*weights.length)+temp]); 
      } 
      totalError += (error*error); 
     } 
     //System.out.println("Total Error: " + totalError); 

     if(totalError == 0) 
     { 
      System.out.println("In total error"); 
      for(int temp = 0; temp < num; temp++) 
      { 
       System.out.println("Weight " +(temp)+ ": " + weights[temp]); 
      } 

      double x = 0.0; 
      for(i = 0; i < inp.length/(num); i++) 
      { 
       for(int j = 0; j < weights.length; j++) 
       { 
        x = inp[((i*num) + j)] * weights[j]; 
       } 
       System.out.println("Output " + (i+1) + ": " + x); 
      } 
      break; 
     } 

    } 
    if(ep >= 10000) 
    { 
     System.out.println("Solution not found"); 
    } 
} 

public static double calculateOutput(int[] input, int start, double[] weights) 
{ 
    start = start * weights.length; 
    double sum = 0.0; 
    for(int i = 0; i < weights.length; i++) 
    { 
     //System.out.println("input[" + (start + i) + "]: " + input[(start+i)]); 
     //System.out.println("weights[i]" + weights[i]); 
     sum += (double)input[(start + i)] * weights[i]; 
    } 
    return sum - 1.0 ; 
} 
public static void main(String args[])throws IOException 
{ 
    BufferedReader obj = new BufferedReader(new InputStreamReader(System.in)); 

    //Read the file path from the user 
    String fileName; 
    System.out.println("Please enter file path for Execution: "); 
    fileName = obj.readLine(); 

    int numInputValues = numberOfInputValues(fileName); 

    //Call the function to store values in the ArrayList<> 
    inputs(fileName); 
    inpArray = inputValues.stream().mapToInt(i->i).toArray(); 
    outArray = outputValues.stream().mapToInt(i->i).toArray(); 

    trainData(inpArray, outArray, numInputValues, 10000); 
} 
} 
+0

Вы получаете ошибку переполнения? с какой конкретной проблемой вы столкнулись и в каком сегменте вашей программы вы столкнулись с проблемой? –

+0

@WasiAhmad благодарит за ваш комментарий. Проблема в том, что мои значения весов должны оставаться в диапазоне от 0,0 до 1,0. Но они экспоненциально возрастают. После нескольких итераций метод calculateOutput возвращает NaN – Yafael

+0

Можете ли вы поделиться своим полным кодом, чтобы я мог запускать и проверять? Кстати, какова точка вычитания '0 из вне [i]' внутри блока else. Кроме того, вы можете написать '1000.0/(1000.0 + ep)' вместо '(1000/(double) (1000 + ep))'. –

ответ

0

Я считаю, что ваш код является проблематичным, поэтому я даю вам простой пример, но я уверен, что вы получите помощь от этого кода для решения вашей проблемы.

import java.util.Random; 

public class Perceptron { 
    double[] weights; 
    double threshold; 
    public void Train(double[][] inputs, int[] outputs, double threshold, double lrate, int epoch) { 
     this.threshold = threshold; 
     int n = inputs[0].length; 
     int p = outputs.length; 
     weights = new double[n]; 
     Random r = new Random(); 

     //initialize weights 
     for(int i=0;i<n;i++) { 
      weights[i] = r.nextDouble(); 
     } 

     for(int i=0;i<epoch;i++) { 
      int totalError = 0; 
      for(int j =0;j<p;j++) { 
       int output = Output(inputs[j]); 
       int error = outputs[j] - output; 

       totalError +=error; 

       for(int k=0;k<n;k++) { 
        double delta = lrate * inputs[j][k] * error; 
        weights[k] += delta; 
       } 
      } 
      if(totalError == 0) 
       break; 
     } 
    } 

    public int Output(double[] input) { 
     double sum = 0.0; 
     for(int i=0;i<input.length;i++) { 
      sum += weights[i]*input[i]; 
     } 

     if(sum>threshold) 
      return 1; 
     else 
      return 0; 
    } 

    public static void main(String[] args) { 
     Perceptron p = new Perceptron(); 
     double inputs[][] = {{0,0},{0,1},{1,0},{1,1}}; 
     int outputs[] = {0,0,0,1}; 

     p.Train(inputs, outputs, 0.2, 0.1, 200); 
     System.out.println(p.Output(new double[]{0,0})); // prints 0 
     System.out.println(p.Output(new double[]{1,0})); // prints 0 
     System.out.println(p.Output(new double[]{0,1})); // prints 0 
     System.out.println(p.Output(new double[]{1,1})); // prints 1 
    } 
} 
Смежные вопросы