view trunk/aid/nn/perceptron.d @ 3:314d68bafeff

Backprop and backprop_test added (no testing).
author revcompgeek
date Fri, 11 Apr 2008 18:12:55 -0600
parents 9655c8362b25
children ff92c77006c7
line wrap: on
line source

module aid.nn.perceptron;

import aid.nn.outputFunctions;
import aid.misc;
import std.math;
import std.string;


/**
 * A single-layer perceptron with one bias weight plus one weight per input.
 *
 * weights[0] is the bias term and weights[i] (i >= 1) multiplies inputs[i-1].
 * If an OutputFunction was supplied it is applied to the weighted sum in
 * evaluate(); otherwise evaluate() returns the raw weighted sum.
 */
class perceptron {
	// Number of weights: the number of inputs plus one for the bias.
	private int numInputs;
	// weights[0] is the bias; weights[1 .. numInputs] pair with inputs[0 .. numInputs-1].
	private double[] weights;
	// Optional activation applied in evaluate(); null means the raw sum is returned.
	private OutputFunction func;
	// Step size used by train().
	public double learningRate;
	
	/**
	 * Constructor intended for loading the neural network from a string.
	 *
	 * Loading is not implemented yet, so this constructor currently always throws.
	 *
	 * Params:
	 *  savedString = The string that was output from the save function.
	 *
	 * Throws:
	 *  Always throws an Exception until loading is implemented.
	 */

	public this(char[] savedString){
		//TODO: Implement loading!
		throw new Exception("Not implimented.");
	}

	/**
	 * Creates a perceptron with numInputs inputs plus a bias weight.
	 *
	 * Params:
	 *  numInputs = Number of inputs the perceptron accepts (the bias is added on top).
	 *  learningRate = Step size used by train().
	 *  randomize = If true each weight is set to rnd() * 2 - 1; otherwise all weights are 0.
	 *  f = Optional output function applied in evaluate(); null for none.
	 */
	this(int numInputs, double learningRate=0.3, bool randomize=true,OutputFunction f=null){
		this.numInputs = numInputs + 1;
		weights.length = this.numInputs;
		func = f;
		this.learningRate = learningRate;
		// Every weight is assigned explicitly: D default-initializes doubles to NaN.
		for(int i = 0; i < this.numInputs; i++){
			weights[i] = randomize ? rnd() * 2 - 1 : 0;
		}
	}
	
	/**
	 * Evaluates the perceptron on one input vector.
	 *
	 * Computes weights[0] + sum(inputs[i-1] * weights[i]) and, if an output
	 * function was supplied, passes that sum through it.
	 *
	 * Params:
	 *  inputs = The set of inputs to evaluate; must have exactly numInputs-1 elements.
	 *
	 * Returns: func(total) when an output function is set, otherwise the raw weighted sum.
	 *
	 * Throws: InputException when inputs has the wrong length.
	 */

	double evaluate(double[] inputs){
		if(inputs.length != numInputs-1)
			throw new InputException(format("Wrong input length. %d %d", inputs.length, numInputs-1));
		double total = weights[0]; // bias
		for(int i = 1; i < numInputs; i++){
			total += inputs[i-1] * weights[i];
		}
		// Identity comparison: avoids invoking opEquals on a null reference.
		if(func !is null) return func(total);
		return total;
	}

	/**
	 * Returns: A copy of the weight array (bias first), so callers cannot
	 * modify the perceptron's internal state.
	 */
	public double[] getWeights(){
		return weights.dup;
	}
	
	/**
	 * Trains the perceptron on a single example using the delta rule:
	 * each weight moves by learningRate * (targetOutput - output) * input,
	 * where the bias uses an implicit input of 1. Subclasses may override this.
	 *
	 * Params:
	 *  inputs = The array of inputs to the neural network.
	 *  targetOutput = The output that the neural network should give.
	 *
	 * Throws: InputException when inputs has the wrong length.
	 */

	void train(double[] inputs,double targetOutput){
		if(inputs.length != numInputs-1) throw new InputException("Wrong input length.");
		double output = evaluate(inputs);
		double error = this.learningRate * (targetOutput - output);
		weights[0] += error;
		for(int i = 1; i < numInputs; i++){
			weights[i] += error * inputs[i-1];
		}
	}

	/**
	 * Calculates the error based on the sum squared error function,
	 * i.e. 0.5 * sum((outputsArray[i] - evaluate(inputsArray[i]))^2).
	 *
	 * Params:
	 *  inputsArray = An array of arrays of all testing inputs.
	 *  outputsArray = An array of all the outputs that the corresponding inputs should have.
	 *
	 * Returns:
	 *  The error value.
	 *
	 * Throws:
	 *  InputException when the arrays differ in length, are empty, or an
	 *  input vector has the wrong length.
	 */

	double getErrorValue(double[][] inputsArray, double[] outputsArray){
		double total = 0;
		if(inputsArray.length != outputsArray.length) throw new InputException("inputsArray and outputsArray must be the same length");
		if(inputsArray.length < 1) throw new InputException("Must have at least 1 training example");
		double output,temp;
		for(int i = 0; i < inputsArray.length; i++){
			output=evaluate(inputsArray[i]);
			temp = outputsArray[i] - output;
			total += temp*temp;
		}
		return total*0.5;
	}
}

// TODO: Implement loading and saving of files