Mercurial > projects > aid
view trunk/aid/nn/perceptron.d @ 6:ff92c77006c7
Added support for reading training examples from files.
author | revcompgeek |
---|---|
date | Tue, 06 May 2008 21:43:55 -0600 |
parents | 314d68bafeff |
children |
line wrap: on
line source
/**
 * perceptron.d
 * Holds the simple perceptron class.
 */
module aid.nn.perceptron;

import aid.nn.outputFunctions;
import aid.misc;
import std.math;
import std.string;

/**
 * A single perceptron with a bias weight.
 *
 * weights[0] is the bias (its input is implicitly 1); weights[1 .. $] are
 * paired, in order, with the inputs passed to evaluate() and train().
 */
class perceptron {
	// Number of weights: the caller's input count plus one for the bias.
	private int numInputs;
	private double[] weights;
	// Optional function applied to the weighted sum by evaluate(); null means
	// the raw sum is returned.
	private OutputFunction func;
	// Step size used by train().
	public double learningRate;

	/**
	 * Constructor for loading the neural network from a string.
	 *
	 * Loading is not implemented yet, so this constructor always throws.
	 *
	 * Params:
	 *   savedString = The string that was output from the save function.
	 *
	 * Throws:
	 *   Exception, always, because loading is not implemented. (The intended
	 *   contract is an InputException when the string is in the wrong format.)
	 */
	public this(char[] savedString){
		//TODO: Implement loading!
		throw new Exception("Not implimented.");
	}

	/**
	 * Creates a perceptron with the given number of inputs.
	 *
	 * NOTE: this constructor has default (public) protection, although it was
	 * once meant to be private because one type of perceptron training
	 * requires the sign function as the output function.
	 *
	 * Params:
	 *   numInputs    = Number of inputs evaluate() and train() expect; must
	 *                  not be negative.
	 *   learningRate = Step size used by train().
	 *   randomize    = If true, every weight (bias included) is set to
	 *                  rnd() * 2 - 1; otherwise every weight starts at 0.
	 *   f            = Optional output function applied by evaluate(); null
	 *                  means evaluate() returns the raw weighted sum.
	 *
	 * Throws:
	 *   InputException if numInputs is negative.
	 */
	this(int numInputs, double learningRate=0.3, bool randomize=true, OutputFunction f=null){
		// -1 would leave no bias slot and smaller values would make the
		// array length wrap around, so reject negatives up front.
		if(numInputs < 0)
			throw new InputException("numInputs must not be negative");
		this.numInputs = numInputs + 1; // one extra slot for the bias
		weights.length = this.numInputs;
		func = f;
		this.learningRate = learningRate;
		if(randomize){
			// NOTE(review): assumes rnd() from aid.misc returns values in [0, 1),
			// which would put the initial weights in [-1, 1) -- confirm.
			for(size_t i = 0; i < weights.length; i++)
				weights[i] = rnd() * 2 - 1;
		} else {
			// Explicit fill is required: D default-initializes doubles to NaN.
			weights[] = 0;
		}
	}

	/**
	 * Evaluates the neural network.
	 *
	 * Params:
	 *   inputs = The set of inputs to evaluate; must contain exactly as many
	 *            elements as the numInputs given to the constructor.
	 *
	 * Returns:
	 *   func(sum) if an output function was supplied, otherwise sum itself,
	 *   where sum = weights[0] + inputs[0]*weights[1] + ... .
	 *
	 * Throws:
	 *   InputException if inputs has the wrong length.
	 */
	double evaluate(double[] inputs){
		if(inputs.length != numInputs - 1)
			throw new InputException(std.string.format(
				"Wrong input length. Expected %d, got %d.", numInputs - 1, inputs.length));
		double total = weights[0]; // bias term
		foreach(i, x; inputs)
			total += x * weights[i + 1];
		// Identity test: valid for function pointers, delegates and class
		// references alike, and never calls an overloaded opEquals.
		if(func !is null)
			return func(total);
		return total;
	}

	/**
	 * Returns:
	 *   A copy of the weight vector (bias first); modifying it does not
	 *   affect this perceptron.
	 */
	public double[] getWeights(){
		return weights.dup;
	}

	/**
	 * Performs one training step on a single example.
	 *
	 * Each weight is moved by learningRate * (targetOutput - evaluate(inputs))
	 * times its input (1 for the bias). With no output function this is the
	 * delta (LMS) rule on the linear output; with a thresholding output
	 * function such as sign it becomes the classic perceptron rule.
	 * Subclasses may override it.
	 *
	 * Params:
	 *   inputs       = The array of inputs to the neural network.
	 *   targetOutput = The output that the neural network should give.
	 *
	 * Throws:
	 *   InputException if inputs has the wrong length.
	 */
	void train(double[] inputs, double targetOutput){
		if(inputs.length != numInputs - 1)
			throw new InputException("Wrong input length.");
		double output = evaluate(inputs);
		double error = this.learningRate * (targetOutput - output);
		weights[0] += error; // bias input is implicitly 1
		foreach(i, x; inputs)
			weights[i + 1] += error * x;
	}

	/**
	 * Calculates half the sum of squared errors over a set of examples:
	 * 0.5 * sum over i of (outputsArray[i] - evaluate(inputsArray[i]))^2.
	 *
	 * Params:
	 *   inputsArray  = One input array per testing example.
	 *   outputsArray = The output each corresponding input should produce.
	 *
	 * Returns:
	 *   The error value.
	 *
	 * Throws:
	 *   InputException if the two arrays differ in length, are empty, or an
	 *   example has the wrong number of inputs.
	 */
	double getErrorValue(double[][] inputsArray, double[] outputsArray){
		if(inputsArray.length != outputsArray.length)
			throw new InputException("inputsArray and outputsArray must be the same length");
		if(inputsArray.length < 1)
			throw new InputException("Must have at least 1 training example");
		double total = 0;
		foreach(i, inputs; inputsArray){
			double diff = outputsArray[i] - evaluate(inputs);
			total += diff * diff;
		}
		return total * 0.5;
	}
}

// TODO: Implement loading and saving of files