comparison trunk/aid/nn/multilayer/backprop.d @ 4:73beed484455

Backprop working correctly.
author revcompgeek
date Sat, 12 Apr 2008 21:55:37 -0600
parents 314d68bafeff
children 810d58835f86
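
For reference, the revised train() implements standard batch backpropagation for sigmoid units. In the notation of the code's own comments (o = unit output, t = target value, eta = learningRate), the quantities it accumulates are:

    \delta_{out}  = o\,(1 - o)\,(t - o)
    \delta_{h}    = o_h\,(1 - o_h) \sum_{k \in downstream(h)} w_{kh}\,\delta_k
    \Delta w_{ji} = \eta \sum_{examples} \delta_j\,x_{ji}

The o(1-o) factor is the derivative of the sigmoid, so these error terms are only correct for sigmoid-style output functions.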
--- a/trunk/aid/nn/multilayer/backprop.d	3:314d68bafeff
+++ b/trunk/aid/nn/multilayer/backprop.d	4:73beed484455
@@ -2,31 +2,34 @@
 
 import aid.nn.outputFunctions;
 import aid.misc;
 import std.random;
 import std.stream;
+import std.stdio;
 
 class Backprop {
     private uint numInputs;
-    private float[][][] units; // Includes the output units. units[layer][unit][inputWeight]
+    private double[][][] units; // Includes the output units. units[layer][unit][inputWeight]
     private OutputFunctionPtr[] functions;
-    public float learningRate;
+    public double learningRate;
 
     ///Constructor
-    public this(uint numInputs,uint[] numUnits,OutputFunctionPtr[] functions,float value=0.05,bool randomize=true){
+    public this(uint numInputs,uint[] numUnits,OutputFunctionPtr[] functions,double learningRate=0.03,double value=0.1,bool randomize=true){
         if(numUnits.length == 0) throw new InputException("numUnits must be greater than 0");
         if(numUnits.length != functions.length) throw new InputException("numUnits and functions must be the same length");
         this.numInputs = numInputs;
         this.functions = functions;
+        this.learningRate = learningRate;
+        units.length = numUnits.length;
         initUnitLayer(0,numUnits[0],numInputs,value,randomize);
         for(int i=1; i<numUnits.length; i++){
             initUnitLayer(i,numUnits[i],numUnits[i-1],value,randomize);
         }
     }
 
     // Helper function to initialize a certain layer.
-    private void initUnitLayer(uint layer,uint num,uint numPrev,float value,bool randomize){
+    private void initUnitLayer(uint layer,uint num,uint numPrev,double value,bool randomize){
         units[layer].length = num;
         for(int i=0; i<num; i++){
             units[layer][i].length = numPrev+1; // include the bias weight
             for(int j=0; j<numPrev+1; j++){
                 if(randomize) units[layer][i][j] = rnd() * value * 2 - value; // between -value and value
@@ -35,136 +38,186 @@
         }
     }
 
     ////////////////////////////////////////////////////// Evaluation //////////////////////////////////////////////////////
     /// Evaluates the neural network.
-    public float[] evaluate(float[] inputs){
-        return evaluateFull(inputs)[$]; // the last item (outputs) of the return value
+    public double[] evaluate(double[] inputs){
+        return evaluateFull(inputs)[$-1]; // the last item (outputs) of the return value
     }
 
     /// Evaluates the neural network and returns the output from all units.
-    public float[][] evaluateFull(float[] inputs){
+    public double[][] evaluateFull(double[] inputs){
         if(inputs.length != numInputs) throw new InputException("Wrong length of inputs.");
-        float[][] outputs;
+        double[][] outputs;
         outputs.length = units.length;
         outputs[0] = evaluateLayer(0,inputs);
-        for(int i=0; i<units.length; i++){
+        for(int i=1; i<units.length; i++){
             outputs[i] = this.evaluateLayer(i,outputs[i-1]);
         }
         return outputs;
     }
 
     // Helper function to evaluate the outputs of a single layer.
-    private float[] evaluateLayer(uint layer,float[] layerInputs){
-        float[] output;
-        output.length = layerInputs.length;
-        for(int i=0; i<layerInputs.length; i++){
+    private double[] evaluateLayer(uint layer,double[] layerInputs){
+        double[] output;
+        output.length = units[layer].length;
+        //printArray(layerInputs);
+        for(int i=0; i<units[layer].length; i++){
             output[i] = evaluateUnit(layer,i,layerInputs);
         }
         return output;
     }
 
     // Helper function to evaluate the output of a single unit.
-    private float evaluateUnit(uint layer, uint unit, float[] layerInputs){
-        float total = units[layer][unit][0]; //bias
-        for(int i=1; i<layerInputs.length; i++){
+    private double evaluateUnit(uint layer, uint unit, double[] layerInputs){
+        //writef("(%d,%d)=",layer,unit);
+        //printArray(layerInputs);
+        double total = units[layer][unit][0]; //bias
+        for(int i=1; i<layerInputs.length+1; i++){
             total += layerInputs[i-1] * units[layer][unit][i]; // wi * xi
-        }
+            //writef("@");
+        }
+        //writefln(" ! %f",total);
         if(functions[layer] != null) return functions[layer](total); // apply the function (if there is one)
-        else return total; // just return the result instead
+        writefln("no function");
+        return total; // just return the result instead
     }
 
 
     ////////////////////////////////////////////////////// Training //////////////////////////////////////////////////////
     /// Trains the neural network.
     /// TODO:
     /// Pull error calculation into a separate function.
-    public void train(float[][] allInputs, float[][] allOutputs){
-        if(allInputs.length != allOutputs.length) throw new InputException("allInputs and allOutputs must be the same size");
-        float[][][] weightUpdate;
-        float[][] outputsError;
-        float[][] outputs;
-        float total; //temp variable
+    public void train(double[][] trainingInputs, double[][] trainingOutputs){
+        if(trainingInputs.length != trainingOutputs.length) throw new InputException("trainingInputs and trainingOutputs must be the same size");
+        double[][][] weightUpdate;
+        double[][] outputsError;
+        double[][] outputs;
+        double total; //temp variable
 
         // Initialize the weightUpdate and outputsError variables
         weightUpdate.length = units.length;
         outputsError.length = units.length;
-        for(int i=0; i<weightUpdate.length; i++){
+        //writefln("#%d,%d",weightUpdate.length,outputsError.length);
+        for(int i=0; i<units.length; i++){
             weightUpdate[i].length = units[i].length;
             outputsError[i].length = units[i].length;
-            for(int j=0; j<weightUpdate[i].length; i++){
+            //writefln("##(%d)%d,%d",i,weightUpdate[i].length,outputsError[i].length);
+            for(int j=0; j<weightUpdate[i].length; j++){
                 weightUpdate[i][j].length = units[i][j].length;
+                for(int k=0; k<weightUpdate[i][j].length; k++) weightUpdate[i][j][k] = 0.0f;
+                //writefln("###(%d)%d",j,weightUpdate[i][j].length);
             }
         }
 
 
         // Loop through each of the training examples
-        for(int example=0; example < allInputs.length; example++){
-            outputs = evaluateFull(allInputs[example]);
+        for(int example=0; example < trainingInputs.length; example++){
+            outputs = evaluateFull(trainingInputs[example]);
 
             // Computing error of output layer
-            for(int i=0; i<outputs[$].length; i++)
-                outputsError[$][i] = outputs[$][i] * (1 - outputs[$][i]) * (allOutputs[example][i] - outputs[$][i]); // o(1-o)(t-o)
-
+            for(int i=0; i<outputs[$-1].length; i++){ // units of last layer
+                //writefln("{%d,%d,%d,%d}",example,i,outputs.length,outputsError[$-1].length);
+                outputsError[$-1][i] = outputs[$-1][i] * (1 - outputs[$-1][i]) * (trainingOutputs[example][i] - outputs[$-1][i]);
+            } // o(1-o)(t-o)
+
+            //printArray(outputsError[$-1]);
+            //printArray(units[length-1]);
+
+            //*
             // Loop through each of the hidden layers (backwards - BACKpropagation!)
-            for(int i=units.length-2; i >= 0; i--){ // -2 to skip the output layer
+            for(int layer=units.length-2; layer >= 0; layer--){ // -2 to skip the output layer
+                //writef("|");
                 // loop through the units in each hidden layer
-                for(int j=0; j<units[i].length; j++){
+                for(int unit=0; unit<units[layer].length; unit++){
+                    //writef("*");
                     total=0;
                     // total up w * e for the units the output of this unit goes into
-                    for(int k=0; k<units[i+1].length; k++){
-                        total += units[i+1][k][j+1] * outputsError[i+1][k];
+                    for(int k=0; k<units[layer+1].length; k++){
+                        //writef("{weight=%f,error=%f}", units[layer+1][k][unit+1/* +1 for bias*/], outputsError[layer+1][k]);
+                        total += units[layer+1][k][unit+1/* +1 for bias*/] * outputsError[layer+1][k];
                     }
+                    //writefln("=%f(total)",total);
                     // multiply total by o(1-o), store in outputsError
-                    outputsError[i][j] = outputs[i][j] * (1 - outputs[i][j]) * total;
+                    outputsError[layer][unit] = outputs[layer][unit] * (1 - outputs[layer][unit]) * total;
                 }
-            }
+            } //writefln();
+
+            //writef("outputError="); printArray(outputsError);
 
             // special case for the units that receive the input values
-            for(int j=0; j<units[0].length; j++){ // unit
-                weightUpdate[0][j][0] += outputsError[0][j]; //bias
-                for(int k=1; k<units[0][j].length; k++){ // input
-                    weightUpdate[0][j][k] += outputsError[0][j] * allInputs[example][k-1];
+            for(int unit=0; unit<units[0].length; unit++){ // unit
+                //writefln(":%d,%d,%d,%d",j,weightUpdate.length,weightUpdate[0].length,weightUpdate[0][j].length);
+                weightUpdate[0][unit][0] += outputsError[0][unit]; //bias
+                for(int input=1; input<units[0][unit].length; input++){ // input
+                    weightUpdate[0][unit][input] += outputsError[0][unit] * trainingInputs[example][input-1]; // account for bias
                 }
             }
 
             // Update the weightUpdate array
             for(int i=1; i<units.length; i++){ // layer
                 for(int j=0; j<units[i].length; j++){ // unit
                     weightUpdate[i][j][0] += outputsError[i][j]; //bias
                     for(int k=1; k<units[i][j].length; k++){ // input
+                        //writefln("[%d,%d,%d]=%f; %f; %f",i,j,k,weightUpdate[i][j][k],outputsError[i][j],outputs[i-1][k-1]);
                         weightUpdate[i][j][k] += outputsError[i][j] * outputs[i-1][k-1]; // previous layer, account for bias
                     }
                 }
             }
         }
 
         // Apply the weightUpdate array to the weights
         for(int i=0; i<units.length; i++){ // layer
             for(int j=0; j<units[i].length; j++){ // unit
                 for(int k=0; k<units[i][j].length; k++){ // input
+                    //writefln("[%d,%d,%d]=%f; %f",i,j,k,units[i][j][k],weightUpdate[i][j][k]);
                     units[i][j][k] += this.learningRate * weightUpdate[i][j][k];
                 }
             }
         }
     }
 
     /// Calculate the output error
-    float calculateError(float[][] allInputs, float[][] allOutputs){
-        if(allInputs.length != allOutputs.length) throw new InputException("allInputs and allOutputs must be the same size");
-        float[] outputs;
-        float total,temp;
-        for(int i=0; i<allInputs.length; i++){
-            outputs = evaluate(allInputs[i]);
-            if(outputs.length != allOutputs[i].length) throw new InputException("Wrong output length");
+    double calculateError(double[][] trainingInputs, double[][] trainingOutputs){
+        if(trainingInputs.length != trainingOutputs.length) throw new InputException("trainingInputs and trainingOutputs must be the same size");
+        double[] outputs;
+        double total=0,temp;
+        for(int i=0; i<trainingInputs.length; i++){
+            outputs = evaluate(trainingInputs[i]);
+            if(outputs.length != trainingOutputs[i].length) throw new InputException("Wrong output length");
             for(int j=0; j<outputs.length; j++){
-                temp = allOutputs[i][j] - outputs[j];
+                temp = trainingOutputs[i][j] - outputs[j];
+                //writefln("&%f,%f",temp*temp,total);
                 total += temp * temp;
             }
         }
         return 0.5 * total;
     }
-}
-
-
-
-
+
+    double[][][] getWeights(){
+        return units.dup;
+    }
+}
+
+void printArray(double[] array){
+    writef("[");
+    for(int i=0; i<array.length-1; i++){
+        writef("%f, ",array[i]);
+    }
+    writefln("%f]",array[$-1]);
+}
+
+void printArray(double[][] array){
+    writef("[");
+    for(int i=0; i<array.length; i++){
+        printArray(array[i]);
+    }
+    writefln("]");
+}
+
+void printArray(double[][][] array){
+    writef("[");
+    for(int i=0; i<array.length; i++){
+        printArray(array[i]);
+    }
+    writefln("]");
+}
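
A minimal usage sketch (not part of the changeset): training the revised class on XOR. It assumes aid.nn.outputFunctions exposes a sigmoid function usable as an OutputFunctionPtr; the name sigmoid is illustrative.

    import aid.nn.multilayer.backprop;
    import aid.nn.outputFunctions;
    import std.stdio;

    void main(){
        // 2 inputs -> 3 hidden sigmoid units -> 1 sigmoid output.
        // `sigmoid` is an assumed OutputFunctionPtr from aid.nn.outputFunctions.
        OutputFunctionPtr[] funcs = [&sigmoid, &sigmoid];
        auto net = new Backprop(2, [3u, 1u], funcs, 0.3); // learningRate = 0.3

        double[][] inputs  = [[0.0,0.0], [0.0,1.0], [1.0,0.0], [1.0,1.0]];
        double[][] targets = [[0.1], [0.9], [0.9], [0.1]]; // 0.1/0.9: a sigmoid only approaches 0 and 1

        for(int i=0; i<5000; i++) net.train(inputs, targets); // one batch update per call

        writefln("error = %f", net.calculateError(inputs, targets));
        printArray(net.evaluate([1.0, 0.0])); // should print a value near 0.9
    }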
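
Note the batch semantics the sketch relies on: each call to train() zeroes weightUpdate, accumulates delta * x over every example, and only then moves the weights by learningRate times the accumulated sum, so one call is one gradient-descent step over the whole training set.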
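
One caveat on the new getWeights(): .dup copies only the outermost array, so the returned slices still alias the live weight arrays and will change as training continues. If an independent snapshot is wanted, a deep copy along these lines (illustrative, not part of the changeset) is needed:

    double[][][] snapshot;
    snapshot.length = units.length;
    foreach(i, layer; units){
        snapshot[i].length = layer.length;
        foreach(j, unit; layer)
            snapshot[i][j] = unit.dup; // copy the innermost weight arrays too
    }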