comparison trunk/backprop_test.d @ 5:810d58835f86

Added momentum and stochastic training to backprop.
author revcompgeek
date Tue, 15 Apr 2008 14:39:49 -0600
parents 73beed484455
children ff92c77006c7
comparison
equal deleted inserted replaced
4:73beed484455 5:810d58835f86
// Test driver for the aid.nn multilayer backpropagation network.
module backprop_test;

import aid.nn.multilayer.backprop;
import aid.nn.outputFunctions;
import aid.misc;
import std.stdio;
import std.random;
import std.conv;
// Training data globals, populated by initTrainingExample().
double[][] trainingInputs, trainingOutputs; // parallel arrays: inputs[i] maps to outputs[i]
uint numInputs;      // width of each input vector, passed to the Backprop constructor
uint[] outputsArray; // layer sizes (hidden..output), passed to the Backprop constructor

/**
 * Fills the training-data globals for one of three built-in examples:
 *   0 - 3-bit parity: output is 0.9 when the input has an odd number of 1s, else 0.1.
 *   1 - 2-input example the CLI calls "xor" (note: the targets are 0.9 for
 *       [0,0] and [1,1], i.e. equality/XNOR rather than XOR as commonly defined).
 *   2 - 8-wide identity (auto-encoder): targets are the inputs themselves.
 * Throws: Error when example is not 0, 1 or 2.
 */
void initTrainingExample(int example) {
	if(example == 0) {
		numInputs = 3;
		outputsArray = [2,1];
		trainingInputs = [[0, 0, 0],
		                  [0, 0, 1],
		                  [0, 1, 0],
		                  [0, 1, 1],
		                  [1, 0, 0],
		                  [1, 0, 1],
		                  [1, 1, 0],
		                  [1, 1, 1]];

		trainingOutputs = [[0.1],
		                   [0.9],
		                   [0.9],
		                   [0.1],
		                   [0.9],
		                   [0.1],
		                   [0.1],
		                   [0.9]];
	} else if(example == 1) {
		numInputs = 2;
		outputsArray = [2,1];
		trainingInputs = [[0, 0],
		                  [1, 0],
		                  [0, 1],
		                  [1, 1]];

		trainingOutputs = [[0.9],
		                   [0.1],
		                   [0.1],
		                   [0.9]];
	} else if(example == 2) {
		numInputs = 8;
		outputsArray = [3,8];
		trainingInputs = [
		    [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
		    [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
		    [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1],
		    [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1],
		    [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1],
		    [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1],
		    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1],
		    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]];

		// Identity task: the network must reproduce its input.
		trainingOutputs = trainingInputs;
	} else {
		// Previously an out-of-range value silently left the globals unset;
		// fail loudly instead.
		throw new Error("example number must be between 0 and 2");
	}
}
27 65
/**
 * Command-line test driver. Parses switches, builds a Backprop network for the
 * selected example, trains until the error drops below errorMin or the
 * iteration cap is reached, then prints the final outputs and weights.
 *
 * Switches:
 *   -s/--seed            use a fixed RNG seed (123) for reproducible runs
 *   -l/--learning-rate V, -m/--momentum V, -r/--random-size V, -e/--error-min V
 *   -n/--example-number N   select example 0..2
 *   -x/--example NAME       select by name: parity, xor or identity
 *   -q/--quiet           suppress per-pattern output during training
 *   -p/--print-every N   progress report interval (iterations)
 *   -i/--min-iters N     iteration cap (despite the name, this is a maximum)
 *
 * NOTE(review): value-taking switches read args[++i] without an explicit
 * length check; a trailing switch with no value relies on the runtime's
 * array-bounds error. Confirm whether a friendlier message is wanted.
 */
void main(char[][] args) {
	double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
	int trainingExample = 0, maxIters = 10000; // example is 0 to 2
	bool quiet = false;    // don't print per-pattern output each report
	int printEvery = 500;  // report progress every printEvery iterations

	for(int i = 1; i < args.length; i++) {
		switch(args[i]) {
			case "-s":
			case "--seed":
				rand_seed(123, 0); // fixed seed for reproducibility
				break;
			case "-l":
			case "--learning-rate":
				learningRate = toDouble(args[++i]);
				break;
			case "-m":
			case "--momentum":
				momentum = toDouble(args[++i]);
				break;
			case "-r":
			case "--random-size":
				randomSize = toDouble(args[++i]);
				break;
			case "-e":
			case "--error-min":
				errorMin = toDouble(args[++i]);
				break;
			case "-n":
			case "--example-number":
				trainingExample = toInt(args[++i]);
				if(trainingExample > 2 || trainingExample < 0)
					throw new Error("example number must be between 0 and 2");
				// BUG FIX: this case previously fell through into -x/--example,
				// consuming the next argument as an example name.
				break;
			case "-x":
			case "--example":
				switch(args[++i]) {
					case "parity":
						trainingExample = 0;
						break;
					case "xor":
						trainingExample = 1;
						break;
					case "identity":
						trainingExample = 2;
						break;
					default:
						throw new Error("Wrong example name. Must be parity, xor or identity");
				}
				break;
			case "-q":
			case "--quiet":
				quiet = true;
				break;
			case "-p":
			case "--print-every":
				printEvery = toInt(args[++i]);
				break;
			case "-i":
			case "--min-iters":
			case "--min-iterations":
				maxIters = toInt(args[++i]);
				break;
			default:
				throw new Error("Unknown switch: " ~ args[i]);
		}
	}

	initTrainingExample(trainingExample);

	Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid],
	                           learningRate, momentum, randomSize, true);

	double error = nn.calculateError(trainingInputs, trainingOutputs);
	double[] output;
	int iter = 0;
	// Train until converged or the iteration cap is hit, reporting periodically.
	while(error >= errorMin && iter < maxIters) {
		if(iter % printEvery == 0) {
			writefln("Iter: %d", iter);
			if(!quiet) {
				for(int i = 0; i < trainingInputs.length; i++) {
					output = nn.evaluate(trainingInputs[i]);
					writef(" %d:", i);
					printArray(output);
				}
			}
			writefln(" Error: %f", error);
		}
		nn.train(trainingInputs, trainingOutputs, true);
		error = nn.calculateError(trainingInputs, trainingOutputs);
		iter++;
	}

	// Final report: per-pattern outputs, error, and the learned weights.
	writefln("Total Iters: %d", iter);
	for(int i = 0; i < trainingInputs.length; i++) {
		writef(" %d:", i);
		if(trainingExample == 2)
			// Identity example: show the first (hidden) layer's activations.
			printArray(nn.evaluateFull(trainingInputs[i])[0]);
		else
			printArray(nn.evaluate(trainingInputs[i]));
	}
	writefln(" Error: %f", error);
	writef("weights=");
	printArray(nn.getWeights());
}
79
/// Prints a 1-D array as "[a, b, c]" followed by a newline.
/// BUG FIX: the old index loop computed array.length-1 on an unsigned
/// length, so an empty array underflowed the bound and then indexed
/// array[$-1] out of range. An empty array now prints "[]" safely.
void printArray(double[] array){
	writef("[");
	foreach(i, v; array){
		if(i > 0)
			writef(", ");
		writef("%f", v);
	}
	writefln("]");
}
87
/// Prints a 2-D array: an opening "[", then each row formatted on its
/// own line by the 1-D overload, then a closing "]" line.
void printArray(double[][] array){
	writef("[");
	foreach(row; array)
		printArray(row);
	writefln("]");
}
95
/// Prints a 3-D array: an opening "[", each 2-D slice via the 2-D
/// overload, then a closing "]" line.
void printArray(double[][][] array){
	writef("[");
	foreach(slice; array)
		printArray(slice);
	writefln("]");
}