Mercurial > projects > aid
comparison trunk/backprop_test.d @ 4:73beed484455
Backprop working correctly.
author | revcompgeek |
---|---|
date | Sat, 12 Apr 2008 21:55:37 -0600 |
parents | 314d68bafeff |
children | 810d58835f86 |
comparison
equal
deleted
inserted
replaced
3:314d68bafeff | 4:73beed484455 |
---|---|
1 module backprop_test; | 1 module backprop_test; |
2 | 2 |
3 import aid.nn.multilayer.backprop; | 3 import aid.nn.multilayer.backprop; |
4 import aid.nn.outputFunctions; | 4 import aid.nn.outputFunctions; |
5 import std.stdio; | 5 import std.stdio; |
6 import std.random; | |
6 | 7 |
7 /+float[][] trainingInputs = [ | 8 /+double[][] trainingInputs = [ |
8 [0,0,0], | 9 [0,0,0], |
9 [0,0,1], | 10 [0,0,1], |
10 [0,1,0], | 11 [0,1,0], |
11 [0,1,1], | 12 [0,1,1], |
12 [1,0,0], | 13 [1,0,0], |
13 [1,0,1], | 14 [1,0,1], |
14 [1,1,0], | 15 [1,1,0], |
15 [1,1,1]]; | 16 [1,1,1]]; |
16 | 17 |
17 float[][] trainingOutputs = [ | 18 double[][] trainingOutputs = [ |
18 [0.1], | 19 [0.1], |
19 [0.9], | 20 [0.9], |
20 [0.9], | 21 [0.9], |
21 [0.1], | 22 [0.1], |
22 [0.9], | 23 [0.9], |
23 [0.1], | 24 [0.1], |
24 [0.1], | 25 [0.1], |
25 [0.9]];+/ | 26 [0.9]];+/ |
26 | 27 |
27 float[][] trainingInputs = [ | 28 /+double[][] trainingInputs = [ |
28 [0,0], | 29 [0,0], |
30 [1,0], | |
29 [0,1], | 31 [0,1], |
30 [1,0], | |
31 [1,1]]; | 32 [1,1]]; |
32 | 33 |
33 float[][] trainingOutputs = [ | 34 double[][] trainingOutputs = [ |
35 [0.9], | |
34 [0.1], | 36 [0.1], |
35 [0.9], | 37 [0.1], |
36 [0.9], | 38 [0.9]];+/ |
37 [0.1]]; | 39 |
40 double[][] trainingInputs = [ | |
41 [0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1], | |
42 [0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1], | |
43 [0.1,0.1,0.9,0.1,0.1,0.1,0.1,0.1], | |
44 [0.1,0.1,0.1,0.9,0.1,0.1,0.1,0.1], | |
45 [0.1,0.1,0.1,0.1,0.9,0.1,0.1,0.1], | |
46 [0.1,0.1,0.1,0.1,0.1,0.9,0.1,0.1], | |
47 [0.1,0.1,0.1,0.1,0.1,0.1,0.9,0.1], | |
48 [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.9]]; | |
38 | 49 |
39 void main(){ | 50 void main(){ |
40 Backprop nn = new Backprop(2,[4,1],[&sigmoid,&sigmoid]); | 51 //rand_seed(0,0); |
52 Backprop nn = new Backprop(8,[3,8],[&sigmoid,&sigmoid],.1); | |
41 | 53 |
42 float error = 10.0; | 54 double error = nn.calculateError(trainingInputs,trainingInputs); |
43 float[] output; | 55 double[] output; |
44 int iter = 0; | 56 int iter = 0; |
45 while(error >= 0.5){ | 57 writef("weights="); printArray(nn.getWeights()); |
46 error = nn.calculateError(trainingInputs,trainingOutputs); | 58 writef("outputs="); printArray(nn.evaluateFull(trainingInputs[$-1])); |
47 if(iter % 100 == 0){ | 59 while(error >= 0.01 && iter < 50000){ |
60 if(iter % 500 == 0){ | |
48 writefln("Iter: %d",iter); | 61 writefln("Iter: %d",iter); |
49 for(int i=0; i<trainingInputs.length; i++){ | 62 for(int i=0; i<trainingInputs.length; i++){ |
50 output = nn.evaluate(trainingInputs[i]); | 63 output = nn.evaluate(trainingInputs[i]); |
51 writef(" %d:", i); printArray(output); | 64 writef(" %d:", i); printArray(output); |
52 } | 65 } |
53 writefln(" Error: %f", error); | 66 writefln(" Error: %f", error); |
54 } | 67 } |
55 nn.train(trainingInputs,trainingOutputs); | 68 nn.train(trainingInputs,trainingInputs); |
69 error = nn.calculateError(trainingInputs,trainingInputs); | |
70 iter++; | |
56 } | 71 } |
57 writefln("Total Iters: %d",iter); | 72 writefln("Total Iters: %d",iter); |
58 for(int i=0; i<trainingInputs.length; i++){ | 73 for(int i=0; i<trainingInputs.length; i++){ |
59 output = nn.evaluate(trainingInputs[i]); | 74 writef(" %d:", i); printArray(nn.evaluateFull(trainingInputs[i])[0]); |
60 writef(" %d:", i); printArray(output); | |
61 } | 75 } |
62 writefln(" Error: %f", error); | 76 writefln(" Error: %f", error); |
77 writef("weights="); printArray(nn.getWeights()); | |
63 } | 78 } |
64 | 79 |
65 void printArray(float[] array){ | 80 void printArray(double[] array){ |
66 writef("["); | 81 writef("["); |
67 for(int i=0; i<array.length-1; i++){ | 82 for(int i=0; i<array.length-1; i++){ |
68 writef("%f, ",array[i]); | 83 writef("%f, ",array[i]); |
69 } | 84 } |
70 writefln("%f]",array[$]); | 85 writefln("%f]",array[$-1]); |
71 } | 86 } |
87 | |
88 void printArray(double[][] array){ | |
89 writef("["); | |
90 for(int i=0; i<array.length; i++){ | |
91 printArray(array[i]); | |
92 } | |
93 writefln("]"); | |
94 } | |
95 | |
96 void printArray(double[][][] array){ | |
97 writef("["); | |
98 for(int i=0; i<array.length; i++){ | |
99 printArray(array[i]); | |
100 } | |
101 writefln("]"); | |
102 } |