annotate trunk/aid/nn/multilayer/backprop.d @ 3:314d68bafeff

Backprop and backprop_test added (no testing).
author revcompgeek
date Fri, 11 Apr 2008 18:12:55 -0600
parents
children 73beed484455
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
3
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
1 module aid.nn.multilevel.backprop;
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
2
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
3 import aid.nn.outputFunctions;
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
4 import aid.misc;
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
5 import std.random;
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
6 import std.stream;
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
7
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
/// A fully-connected feed-forward neural network trained with batch backpropagation.
///
/// Weights are stored as units[layer][unit][weight]. Weight 0 of every unit is its
/// bias; weight k (k >= 1) multiplies value k-1 of the layer's input (the network
/// inputs for layer 0, the previous layer's outputs otherwise). The last layer is
/// the output layer.
///
/// NOTE(review): train() computes error terms with o(1-o), the derivative of the
/// logistic sigmoid, regardless of functions[layer]. For other output functions
/// (or a null function, which passes the raw weighted sum through) the gradient is
/// not correct — confirm every layer uses a sigmoid.
class Backprop {
	private uint numInputs;
	private float[][][] units; // Includes the output units. units[layer][unit][inputWeight]
	private OutputFunctionPtr[] functions;

	/// Step size applied to the accumulated gradient at the end of train().
	/// Floats default to NaN in D, so an explicit starting value is required;
	/// otherwise a train() call made before assigning this field would turn every
	/// weight into NaN. 0.1 is only a starting point — callers may overwrite it.
	public float learningRate = 0.1f;

	/// Constructor.
	/// Params:
	///   numInputs = number of values each input vector must contain.
	///   numUnits  = number of units in each layer; the last entry is the output layer.
	///   functions = output function for each layer (null means the raw weighted sum is returned).
	///   value     = initial weight scale: with randomize, each weight is rnd() * value * 2 - value
	///               (within [-value, value] assuming rnd() returns a value in [0, 1) — TODO confirm);
	///               without randomize, every weight is set to value.
	///   randomize = whether to randomize the initial weights.
	/// Throws: InputException if numUnits is empty or its length differs from functions.length.
	public this(uint numInputs,uint[] numUnits,OutputFunctionPtr[] functions,float value=0.05,bool randomize=true){
		if(numUnits.length == 0) throw new InputException("numUnits must be greater than 0");
		if(numUnits.length != functions.length) throw new InputException("numUnits and functions must be the same length");
		this.numInputs = numInputs;
		this.functions = functions;
		// The outer layer array must exist before initUnitLayer indexes into it.
		units.length = numUnits.length;
		initUnitLayer(0,numUnits[0],numInputs,value,randomize);
		for(uint i=1; i<numUnits.length; i++){
			initUnitLayer(i,numUnits[i],numUnits[i-1],value,randomize);
		}
	}

	// Helper function to initialize a certain layer: `num` units, each with
	// numPrev input weights plus one bias weight at index 0.
	private void initUnitLayer(uint layer,uint num,uint numPrev,float value,bool randomize){
		units[layer].length = num;
		for(uint i=0; i<num; i++){
			units[layer][i].length = numPrev+1; // include the bias weight
			for(uint j=0; j<numPrev+1; j++){
				if(randomize) units[layer][i][j] = rnd() * value * 2 - value;
				else units[layer][i][j] = value;
			}
		}
	}

	////////////////////////////////////////////////////// Evaluation //////////////////////////////////////////////////////
	/// Evaluates the neural network.
	/// Returns: the outputs of the last (output) layer.
	/// Throws: InputException if inputs.length != numInputs.
	public float[] evaluate(float[] inputs){
		float[][] allLayers = evaluateFull(inputs);
		return allLayers[$-1]; // the last layer holds the network outputs
	}

	/// Evaluates the neural network and returns the output from all units.
	/// Returns: one array per layer; element [layer][unit] is that unit's output.
	/// Throws: InputException if inputs.length != numInputs.
	public float[][] evaluateFull(float[] inputs){
		if(inputs.length != numInputs) throw new InputException("Wrong length of inputs.");
		float[][] outputs;
		outputs.length = units.length;
		outputs[0] = evaluateLayer(0,inputs);
		// Every later layer is fed the outputs of the layer before it.
		for(uint i=1; i<units.length; i++){
			outputs[i] = evaluateLayer(i,outputs[i-1]);
		}
		return outputs;
	}

	// Helper function to evaluate the outputs of a single layer (one value per unit).
	private float[] evaluateLayer(uint layer,float[] layerInputs){
		float[] output;
		output.length = units[layer].length; // one output per unit, not per input
		for(uint i=0; i<output.length; i++){
			output[i] = evaluateUnit(layer,i,layerInputs);
		}
		return output;
	}

	// Helper function to evaluate the output of a single unit:
	// function(bias + sum(w[i+1] * x[i])), or the raw sum when the layer has no function.
	private float evaluateUnit(uint layer, uint unit, float[] layerInputs){
		float[] weights = units[layer][unit];
		float total = weights[0]; // bias
		for(size_t i=0; i<layerInputs.length; i++){
			total += layerInputs[i] * weights[i+1]; // wi * xi, shifted by one past the bias
		}
		if(functions[layer] !is null) return functions[layer](total); // apply the function (if there is one)
		else return total; // just return the result instead
	}


	////////////////////////////////////////////////////// Training //////////////////////////////////////////////////////
	/// Trains the neural network with one batch-gradient-descent step: the weight
	/// changes from all examples are summed, then applied once, scaled by learningRate.
	/// Params:
	///   allInputs  = one input vector per training example.
	///   allOutputs = the target output vector for each example.
	/// Throws: InputException if the two arrays differ in length, if an input vector
	///   has the wrong length, or if a target vector does not match the output layer width.
	/// TODO:
	/// Pull error calculation into a separate function.
	public void train(float[][] allInputs, float[][] allOutputs){
		if(allInputs.length != allOutputs.length) throw new InputException("allInputs and allOutputs must be the same size");
		float[][][] weightUpdate;
		float[][] outputsError;
		float[][] outputs;
		float total; //temp variable
		size_t last = units.length - 1; // index of the output layer (constructor guarantees >= 1 layer)

		// Initialize the weightUpdate and outputsError variables
		weightUpdate.length = units.length;
		outputsError.length = units.length;
		for(size_t i=0; i<units.length; i++){
			weightUpdate[i].length = units[i].length;
			outputsError[i].length = units[i].length;
			for(size_t j=0; j<units[i].length; j++){
				weightUpdate[i][j].length = units[i][j].length;
				// New float elements are NaN (float.init); accumulators must start at zero.
				weightUpdate[i][j][] = 0.0f;
			}
		}

		// Loop through each of the training examples
		for(size_t example=0; example < allInputs.length; example++){
			outputs = evaluateFull(allInputs[example]);
			if(allOutputs[example].length != outputs[last].length) throw new InputException("Wrong output length");

			// Computing error of output layer: o(1-o)(t-o)
			for(size_t i=0; i<outputs[last].length; i++){
				float o = outputs[last][i];
				outputsError[last][i] = o * (1 - o) * (allOutputs[example][i] - o);
			}

			// Loop through each of the hidden layers (backwards - BACKpropagation!)
			// Signed index so the loop terminates after layer 0; -2 skips the output layer.
			for(int i=cast(int)units.length - 2; i >= 0; i--){
				// loop through the units in each hidden layer
				for(size_t j=0; j<units[i].length; j++){
					total=0;
					// total up w * e for the units the output of this unit goes into
					// (j+1 skips the bias weight of the next layer's units)
					for(size_t k=0; k<units[i+1].length; k++){
						total += units[i+1][k][j+1] * outputsError[i+1][k];
					}
					// multiply total by o(1-o), store in outputsError
					outputsError[i][j] = outputs[i][j] * (1 - outputs[i][j]) * total;
				}
			}

			// Accumulate this example's contribution into weightUpdate.
			// Layer 0 is fed the network inputs; every other layer the previous layer's outputs.
			for(size_t i=0; i<units.length; i++){ // layer
				float[] layerInputs = (i == 0) ? allInputs[example] : outputs[i-1];
				for(size_t j=0; j<units[i].length; j++){ // unit
					weightUpdate[i][j][0] += outputsError[i][j]; //bias
					for(size_t k=1; k<units[i][j].length; k++){ // input, shifted by one for the bias
						weightUpdate[i][j][k] += outputsError[i][j] * layerInputs[k-1];
					}
				}
			}
		}

		// Apply the weightUpdate array to the weights
		for(size_t i=0; i<units.length; i++){ // layer
			for(size_t j=0; j<units[i].length; j++){ // unit
				for(size_t k=0; k<units[i][j].length; k++){ // input
					units[i][j][k] += this.learningRate * weightUpdate[i][j][k];
				}
			}
		}
	}

	/// Calculate the output error: half the sum, over all examples and all outputs,
	/// of (target - output)^2.
	/// Throws: InputException if the two arrays differ in length, an input vector has
	///   the wrong length, or a target vector does not match the network output length.
	public float calculateError(float[][] allInputs, float[][] allOutputs){
		if(allInputs.length != allOutputs.length) throw new InputException("allInputs and allOutputs must be the same size");
		float[] outputs;
		float total = 0; // must start at 0: float.init is NaN in D
		float temp;
		for(size_t i=0; i<allInputs.length; i++){
			outputs = evaluate(allInputs[i]);
			if(outputs.length != allOutputs[i].length) throw new InputException("Wrong output length");
			for(size_t j=0; j<outputs.length; j++){
				temp = allOutputs[i][j] - outputs[j];
				total += temp * temp;
			}
		}
		return 0.5f * total;
	}
}
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
167
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
168
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
169
314d68bafeff Backprop and backprop_test added (no testing).
revcompgeek
parents:
diff changeset
170