comparison trunk/backprop_test.d @ 4:73beed484455

Backprop working correctly.
author revcompgeek
date Sat, 12 Apr 2008 21:55:37 -0600
parents 314d68bafeff
children 810d58835f86
comparison
equal deleted inserted replaced
3:314d68bafeff 4:73beed484455
1 module backprop_test; 1 module backprop_test;
2 2
3 import aid.nn.multilayer.backprop; 3 import aid.nn.multilayer.backprop;
4 import aid.nn.outputFunctions; 4 import aid.nn.outputFunctions;
5 import std.stdio; 5 import std.stdio;
6 import std.random;
6 7
7 /+float[][] trainingInputs = [ 8 /+double[][] trainingInputs = [
8 [0,0,0], 9 [0,0,0],
9 [0,0,1], 10 [0,0,1],
10 [0,1,0], 11 [0,1,0],
11 [0,1,1], 12 [0,1,1],
12 [1,0,0], 13 [1,0,0],
13 [1,0,1], 14 [1,0,1],
14 [1,1,0], 15 [1,1,0],
15 [1,1,1]]; 16 [1,1,1]];
16 17
17 float[][] trainingOutputs = [ 18 double[][] trainingOutputs = [
18 [0.1], 19 [0.1],
19 [0.9], 20 [0.9],
20 [0.9], 21 [0.9],
21 [0.1], 22 [0.1],
22 [0.9], 23 [0.9],
23 [0.1], 24 [0.1],
24 [0.1], 25 [0.1],
25 [0.9]];+/ 26 [0.9]];+/
26 27
27 float[][] trainingInputs = [ 28 /+double[][] trainingInputs = [
28 [0,0], 29 [0,0],
30 [1,0],
29 [0,1], 31 [0,1],
30 [1,0],
31 [1,1]]; 32 [1,1]];
32 33
33 float[][] trainingOutputs = [ 34 double[][] trainingOutputs = [
35 [0.9],
34 [0.1], 36 [0.1],
35 [0.9], 37 [0.1],
36 [0.9], 38 [0.9]];+/
37 [0.1]]; 39
40 double[][] trainingInputs = [
41 [0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1],
42 [0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1],
43 [0.1,0.1,0.9,0.1,0.1,0.1,0.1,0.1],
44 [0.1,0.1,0.1,0.9,0.1,0.1,0.1,0.1],
45 [0.1,0.1,0.1,0.1,0.9,0.1,0.1,0.1],
46 [0.1,0.1,0.1,0.1,0.1,0.9,0.1,0.1],
47 [0.1,0.1,0.1,0.1,0.1,0.1,0.9,0.1],
48 [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.9]];
38 49
39 void main(){ 50 void main(){
40 Backprop nn = new Backprop(2,[4,1],[&sigmoid,&sigmoid]); 51 //rand_seed(0,0);
52 Backprop nn = new Backprop(8,[3,8],[&sigmoid,&sigmoid],.1);
41 53
42 float error = 10.0; 54 double error = nn.calculateError(trainingInputs,trainingInputs);
43 float[] output; 55 double[] output;
44 int iter = 0; 56 int iter = 0;
45 while(error >= 0.5){ 57 writef("weights="); printArray(nn.getWeights());
46 error = nn.calculateError(trainingInputs,trainingOutputs); 58 writef("outputs="); printArray(nn.evaluateFull(trainingInputs[$-1]));
47 if(iter % 100 == 0){ 59 while(error >= 0.01 && iter < 50000){
60 if(iter % 500 == 0){
48 writefln("Iter: %d",iter); 61 writefln("Iter: %d",iter);
49 for(int i=0; i<trainingInputs.length; i++){ 62 for(int i=0; i<trainingInputs.length; i++){
50 output = nn.evaluate(trainingInputs[i]); 63 output = nn.evaluate(trainingInputs[i]);
51 writef(" %d:", i); printArray(output); 64 writef(" %d:", i); printArray(output);
52 } 65 }
53 writefln(" Error: %f", error); 66 writefln(" Error: %f", error);
54 } 67 }
55 nn.train(trainingInputs,trainingOutputs); 68 nn.train(trainingInputs,trainingInputs);
69 error = nn.calculateError(trainingInputs,trainingInputs);
70 iter++;
56 } 71 }
57 writefln("Total Iters: %d",iter); 72 writefln("Total Iters: %d",iter);
58 for(int i=0; i<trainingInputs.length; i++){ 73 for(int i=0; i<trainingInputs.length; i++){
59 output = nn.evaluate(trainingInputs[i]); 74 writef(" %d:", i); printArray(nn.evaluateFull(trainingInputs[i])[0]);
60 writef(" %d:", i); printArray(output);
61 } 75 }
62 writefln(" Error: %f", error); 76 writefln(" Error: %f", error);
77 writef("weights="); printArray(nn.getWeights());
63 } 78 }
64 79
65 void printArray(float[] array){ 80 void printArray(double[] array){
66 writef("["); 81 writef("[");
67 for(int i=0; i<array.length-1; i++){ 82 for(int i=0; i<array.length-1; i++){
68 writef("%f, ",array[i]); 83 writef("%f, ",array[i]);
69 } 84 }
70 writefln("%f]",array[$]); 85 writefln("%f]",array[$-1]);
71 } 86 }
87
88 void printArray(double[][] array){
89 writef("[");
90 for(int i=0; i<array.length; i++){
91 printArray(array[i]);
92 }
93 writefln("]");
94 }
95
96 void printArray(double[][][] array){
97 writef("[");
98 for(int i=0; i<array.length; i++){
99 printArray(array[i]);
100 }
101 writefln("]");
102 }