Mercurial > projects > aid
comparison trunk/backprop_test.d @ 5:810d58835f86
Added momentum and stochastic training to backprop.
author | revcompgeek |
---|---|
date | Tue, 15 Apr 2008 14:39:49 -0600 |
parents | 73beed484455 |
children | ff92c77006c7 |
comparison
equal
deleted
inserted
replaced
4:73beed484455 | 5:810d58835f86 |
---|---|
1 | |
1 module backprop_test; | 2 module backprop_test; |
3 | |
2 | 4 |
3 import aid.nn.multilayer.backprop; | 5 import aid.nn.multilayer.backprop; |
4 import aid.nn.outputFunctions; | 6 import aid.nn.outputFunctions; |
7 import aid.misc; | |
5 import std.stdio; | 8 import std.stdio; |
6 import std.random; | 9 import std.random; |
10 import std.conv; | |
7 | 11 |
8 /+double[][] trainingInputs = [ | 12 double[][] trainingInputs, trainingOutputs; |
9 [0,0,0], | 13 uint numInputs; |
10 [0,0,1], | 14 uint[] outputsArray; |
11 [0,1,0], | |
12 [0,1,1], | |
13 [1,0,0], | |
14 [1,0,1], | |
15 [1,1,0], | |
16 [1,1,1]]; | |
17 | 15 |
18 double[][] trainingOutputs = [ | 16 void initTrainingExample(int example) { |
19 [0.1], | 17 if(example == 0) { |
20 [0.9], | 18 numInputs = 3; |
21 [0.9], | 19 outputsArray = [2,1]; |
22 [0.1], | 20 trainingInputs = [[0, 0, 0], |
23 [0.9], | 21 [0, 0, 1], |
24 [0.1], | 22 [0, 1, 0], |
25 [0.1], | 23 [0, 1, 1], |
26 [0.9]];+/ | 24 [1, 0, 0], |
25 [1, 0, 1], | |
26 [1, 1, 0], | |
27 [1, 1, 1]]; | |
28 | |
29 trainingOutputs = [[0.1], | |
30 [0.9], | |
31 [0.9], | |
32 [0.1], | |
33 [0.9], | |
34 [0.1], | |
35 [0.1], | |
36 [0.9]]; | |
37 } else if(example == 1) { | |
38 numInputs = 2; | |
39 outputsArray = [2,1]; | |
40 trainingInputs = [[0, 0], | |
41 [1, 0], | |
42 [0, 1], | |
43 [1, 1]]; | |
44 | |
45 trainingOutputs = [[0.9], | |
46 [0.1], | |
47 [0.1], | |
48 [0.9]]; | |
49 } else if(example == 2) { | |
50 numInputs = 8; | |
51 outputsArray = [3,8]; | |
52 trainingInputs = [ | |
53 [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |
54 [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |
55 [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1], | |
56 [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1], | |
57 [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1], | |
58 [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1], | |
59 [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1], | |
60 [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]]; | |
61 | |
62 trainingOutputs = trainingInputs; | |
63 } | |
64 } | |
27 | 65 |
28 /+double[][] trainingInputs = [ | 66 void main(char[][] args) { |
29 [0,0], | 67 double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05; |
30 [1,0], | 68 int trainingExample = 0, maxIters = 10000; // 0 to 2 |
31 [0,1], | 69 bool quiet = false; // don't print output each time |
32 [1,1]]; | 70 int printEvery = 500; // output every ~ times |
33 | |
34 double[][] trainingOutputs = [ | |
35 [0.9], | |
36 [0.1], | |
37 [0.1], | |
38 [0.9]];+/ | |
39 | |
40 double[][] trainingInputs = [ | |
41 [0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1], | |
42 [0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1], | |
43 [0.1,0.1,0.9,0.1,0.1,0.1,0.1,0.1], | |
44 [0.1,0.1,0.1,0.9,0.1,0.1,0.1,0.1], | |
45 [0.1,0.1,0.1,0.1,0.9,0.1,0.1,0.1], | |
46 [0.1,0.1,0.1,0.1,0.1,0.9,0.1,0.1], | |
47 [0.1,0.1,0.1,0.1,0.1,0.1,0.9,0.1], | |
48 [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.9]]; | |
49 | |
50 void main(){ | |
51 //rand_seed(0,0); | |
52 Backprop nn = new Backprop(8,[3,8],[&sigmoid,&sigmoid],.1); | |
53 | 71 |
54 double error = nn.calculateError(trainingInputs,trainingInputs); | 72 //try { |
73 for(int i = 1; i < args.length; i++) { | |
74 switch(args[i]) { | |
75 case "-s": | |
76 case "--seed": | |
77 rand_seed(123, 0); | |
78 break; | |
79 case "-l": | |
80 case "--learning-rate": | |
81 //if(args.length = i + 1) | |
82 //throw new Error("Wrong number of paramaters"); | |
83 learningRate = toDouble(args[++i]); | |
84 break; | |
85 case "-m": | |
86 case "--momentum": | |
87 momentum = toDouble(args[++i]); | |
88 break; | |
89 case "-r": | |
90 case "--random-size": | |
91 randomSize = toDouble(args[++i]); | |
92 break; | |
93 case "-e": | |
94 case "--error-min": | |
95 errorMin = toDouble(args[++i]); | |
96 break; | |
97 case "-n": | |
98 case "--example-number": | |
99 trainingExample = toInt(args[++i]); | |
100 if(trainingExample > 2 || trainingExample < 0) | |
101 throw new Error("example number must be between 0 and 2"); | |
102 case "-x": | |
103 case "--example": | |
104 switch(args[++i]) { | |
105 case "parity": | |
106 trainingExample = 0; | |
107 break; | |
108 case "xor": | |
109 trainingExample = 1; | |
110 break; | |
111 case "identity": | |
112 trainingExample = 2; | |
113 break; | |
114 default: | |
115 throw new Error("Wrong example name. Must be parity, xor or identity"); | |
116 } | |
117 break; | |
118 case "-q": | |
119 case "--quiet": | |
120 quiet = true; | |
121 break; | |
122 case "-p": | |
123 case "--print-every": | |
124 printEvery = toInt(args[++i]); | |
125 break; | |
126 case "-i": | |
127 case "--min-iters": | |
128 case "--min-iterations": | |
129 maxIters = toInt(args[++i]); | |
130 break; | |
131 default: | |
132 throw new Error("Unknown switch: " ~ args[i]); | |
133 } | |
134 } | |
135 //} catch(ArrayBoundsError) { | |
136 // throw new Error("Wrong number of paramaters"); | |
137 //} | |
138 | |
139 initTrainingExample(trainingExample); | |
140 | |
141 Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid], learningRate, momentum, randomSize, true); | |
142 | |
143 double error = nn.calculateError(trainingInputs, trainingOutputs); | |
55 double[] output; | 144 double[] output; |
56 int iter = 0; | 145 int iter = 0; |
57 writef("weights="); printArray(nn.getWeights()); | 146 //writef("weights="); |
58 writef("outputs="); printArray(nn.evaluateFull(trainingInputs[$-1])); | 147 //printArray(nn.getWeights()); |
59 while(error >= 0.01 && iter < 50000){ | 148 //writef("outputs="); |
60 if(iter % 500 == 0){ | 149 //printArray(nn.evaluateFull(trainingInputs[$ - 1])); |
61 writefln("Iter: %d",iter); | 150 while (error >= errorMin && iter < maxIters) { |
62 for(int i=0; i<trainingInputs.length; i++){ | 151 if(iter % printEvery == 0) { |
63 output = nn.evaluate(trainingInputs[i]); | 152 writefln("Iter: %d", iter); |
64 writef(" %d:", i); printArray(output); | 153 if(!quiet) { |
154 for(int i = 0; i < trainingInputs.length; i++) { | |
155 output = nn.evaluate(trainingInputs[i]); | |
156 writef(" %d:", i); | |
157 printArray(output); | |
158 } | |
65 } | 159 } |
66 writefln(" Error: %f", error); | 160 writefln(" Error: %f", error); |
67 } | 161 } |
68 nn.train(trainingInputs,trainingInputs); | 162 nn.train(trainingInputs, trainingOutputs, true); |
69 error = nn.calculateError(trainingInputs,trainingInputs); | 163 error = nn.calculateError(trainingInputs, trainingOutputs); |
70 iter++; | 164 iter++; |
71 } | 165 } |
72 writefln("Total Iters: %d",iter); | 166 writefln("Total Iters: %d", iter); |
73 for(int i=0; i<trainingInputs.length; i++){ | 167 for(int i = 0; i < trainingInputs.length; i++) { |
74 writef(" %d:", i); printArray(nn.evaluateFull(trainingInputs[i])[0]); | 168 writef(" %d:", i); |
169 if(trainingExample == 2) | |
170 printArray(nn.evaluateFull(trainingInputs[i])[0]); | |
171 else | |
172 printArray(nn.evaluate(trainingInputs[i])); | |
75 } | 173 } |
76 writefln(" Error: %f", error); | 174 writefln(" Error: %f", error); |
77 writef("weights="); printArray(nn.getWeights()); | 175 writef("weights="); |
176 printArray(nn.getWeights()); | |
78 } | 177 } |
79 | |
80 void printArray(double[] array){ | |
81 writef("["); | |
82 for(int i=0; i<array.length-1; i++){ | |
83 writef("%f, ",array[i]); | |
84 } | |
85 writefln("%f]",array[$-1]); | |
86 } | |
87 | |
88 void printArray(double[][] array){ | |
89 writef("["); | |
90 for(int i=0; i<array.length; i++){ | |
91 printArray(array[i]); | |
92 } | |
93 writefln("]"); | |
94 } | |
95 | |
96 void printArray(double[][][] array){ | |
97 writef("["); | |
98 for(int i=0; i<array.length; i++){ | |
99 printArray(array[i]); | |
100 } | |
101 writefln("]"); | |
102 } |