comparison trunk/aid/nn/perceptron.d @ 2:9655c8362b25

Added the Perceptron class and the perceptron_test testing program.
author revcompgeek
date Sat, 05 Apr 2008 23:41:30 -0600
parents
children 314d68bafeff
comparison
equal deleted inserted replaced
1:5dd9f598bcd8 2:9655c8362b25
1 module aid.nn.perceptron;
2
3 import std.random;
4 import std.math;
5 import std.string;
6
/**
 * Returns a pseudo-random double in the closed range [0, 1], obtained by
 * scaling std.random.rand() by uint.max.
 * (The function that should be included in every math library!)
 */
double rnd(){
	auto raw = rand();
	return raw / cast(double)uint.max;
}
10
/**
 * Thrown when a perceptron receives input of the wrong shape
 * (e.g. an input array of the wrong length).
 */
class InputException : Exception
{
	/// Creates the exception with the given message.
	this(char[] message)
	{
		super(message);
	}
}
16
17 // The output functions
18
/// Signature of an output (activation) function applied to a perceptron's weighted sum.
alias double function(double x) OutputFunction;
20
/**
 * Step output function: 1 for strictly positive input, otherwise -1
 * (so 0 and NaN both map to -1).
 */
double sign(double y){
	return (y > 0) ? 1.0 : -1.0;
}
25
/**
 * Logistic output function: 1 / (1 + e^-x).
 */
double sigmoid(double x){
	auto decay = exp(-x);
	return 1 / (1 + decay);
}
29
/**
 * Hyperbolic tangent output function.
 *
 * Delegates to std.math.tanh. The call must be fully qualified: an
 * unqualified `tanh` here resolves to this module-level function itself
 * (module scope is searched before imported modules, and real converts
 * implicitly to double), which previously caused unbounded recursion.
 */
double tanh(double x){
	return cast(double)std.math.tanh(cast(real)x);
}
33
34 // End output functions
35
36
/**
 * A single perceptron: a bias weight plus one weight per input, with the
 * weighted sum optionally passed through an output function.
 *
 * Weight layout: weights[0] is the bias (its input is implicitly 1), and
 * weights[i] for i >= 1 multiplies inputs[i-1].
 */
class perceptron {
	// Number of weights, i.e. the number of inputs plus one for the bias.
	private int numInputs;
	// weights[0] is the bias; weights[1 .. numInputs] are the input weights.
	private double[] weights;
	// Applied to the weighted sum in evaluate(); null means the raw sum is returned.
	private OutputFunction func;
	// Scale factor applied to the error in train().
	public double learningRate;

	/**
	 * This is the constructor for loading the neural network from a string.
	 *
	 * Loading is not implemented yet; this constructor currently always throws.
	 *
	 * Params:
	 *  savedString = The string that was output from the save function.
	 *
	 * Throws:
	 *  Currently always throws an Exception. Once loading is implemented it is
	 *  intended to throw an InputException when the string is in the wrong format.
	 */
	public this(char[] savedString){
		//TODO: Implement loading!
		throw new Exception("Not implimented.");
	}

	/**
	 * Creates a perceptron with the given number of inputs.
	 *
	 * Params:
	 *  numInputs = Number of inputs, not counting the bias.
	 *  learningRate = Scale factor applied to the error in train().
	 *  randomize = If true, every weight starts at a random value in [-1, 1];
	 *              otherwise every weight starts at 0.
	 *  f = Output function applied by evaluate(); null returns the raw weighted sum.
	 */
	this(int numInputs, double learningRate=0.3, bool randomize=true, OutputFunction f=null){
		this.numInputs = numInputs + 1; // one extra weight for the bias
		weights.length = this.numInputs;
		func = f;
		this.learningRate = learningRate;
		// New double elements default to NaN in D, so every weight is set explicitly.
		foreach(ref w; weights){
			w = randomize ? rnd() * 2 - 1 : 0;
		}
	}

	// Throws an InputException unless inputs has one element per (non-bias) weight.
	private void checkInputLength(double[] inputs){
		if(inputs.length != numInputs-1)
			throw new InputException(format("Wrong input length. Expected %d, got %d",
				numInputs-1, inputs.length));
	}

	/**
	 * Evaluates the neural network.
	 *
	 * Computes weights[0] + sum over i of inputs[i] * weights[i+1], then applies
	 * the output function if one was given.
	 *
	 * Params:
	 *  inputs = The set of inputs to evaluate; must contain exactly as many
	 *           elements as the perceptron was constructed with.
	 *
	 * Returns: The output function applied to the weighted sum, or the raw sum
	 *          when no output function was set. With the sign function this is
	 *          1 to indicate true and -1 for false.
	 *
	 * Throws: InputException if inputs has the wrong length.
	 */
	double evaluate(double[] inputs){
		checkInputLength(inputs);
		double total = weights[0]; // bias term
		foreach(i, x; inputs){
			total += x * weights[i+1];
		}
		if(func !is null) return func(total);
		return total;
	}

	/**
	 * Returns: A copy of the weights (index 0 is the bias), so callers cannot
	 *          modify the perceptron's internal state.
	 */
	public double[] getWeights(){
		return weights.dup;
	}

	/**
	 * Trains the neural network on a single example using the delta rule:
	 * each weight moves by learningRate * (targetOutput - output) * input,
	 * where the bias weight's input is taken to be 1. Subclasses may override
	 * this to use a different training rule.
	 *
	 * Params:
	 *  inputs = The array of inputs to the neural network.
	 *  targetOutput = The output that the neural network should give.
	 *
	 * Throws: InputException if inputs has the wrong length.
	 */
	void train(double[] inputs, double targetOutput){
		checkInputLength(inputs);
		double output = evaluate(inputs);
		double delta = this.learningRate * (targetOutput - output);
		weights[0] += delta; // bias input is 1
		foreach(i, x; inputs){
			weights[i+1] += delta * x;
		}
	}

	/**
	 * Calculates the error based on the sum squared error function:
	 * 0.5 * sum over i of (outputsArray[i] - evaluate(inputsArray[i]))^2.
	 *
	 * Params:
	 *  inputsArray = An array of arrays of all testing inputs.
	 *  outputsArray = An array of all the outputs that the corresponding inputs should have.
	 *
	 * Returns:
	 *  The error value.
	 *
	 * Throws:
	 *  InputException if the two arrays differ in length, are empty, or if any
	 *  input array has the wrong length.
	 */
	double getErrorValue(double[][] inputsArray, double[] outputsArray){
		if(inputsArray.length != outputsArray.length) throw new InputException("inputsArray and outputsArray must be the same length");
		if(inputsArray.length < 1) throw new InputException("Must have at least 1 training example");
		double total = 0;
		foreach(i, inputs; inputsArray){
			double diff = outputsArray[i] - evaluate(inputs);
			total += diff * diff;
		}
		return total * 0.5;
	}
}
142