Mercurial > projects > aid
annotate trunk/backprop_test.d @ 5:810d58835f86
Added momentum and stochastic training to backprop.
author | revcompgeek |
---|---|
date | Tue, 15 Apr 2008 14:39:49 -0600 |
parents | 73beed484455 |
children | ff92c77006c7 |
rev | line source |
---|---|
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
1 |
3 | 2 module backprop_test; |
3 | |
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
4 |
3 | 5 import aid.nn.multilayer.backprop; |
6 import aid.nn.outputFunctions; | |
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
7 import aid.misc; |
3 | 8 import std.stdio; |
4 | 9 import std.random; |
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
10 import std.conv; |
3 | 11 |
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
// Training set for the currently selected example; filled in by
// initTrainingExample().
double[][] trainingInputs;
double[][] trainingOutputs;

// Network shape for the selected example: the input count and the per-layer
// sizes that main() passes to the Backprop constructor.
uint numInputs;
uint[] outputsArray;
3 | 15 |
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
/**
 * Loads one of the built-in training sets into the module-level globals
 * numInputs, outputsArray, trainingInputs and trainingOutputs.
 *
 * Params:
 *   example = 0 selects 3-bit parity, 1 selects XOR, 2 selects the 8-input
 *             identity mapping (the names used by main's --example switch).
 *
 * Throws: Error if example is not 0, 1 or 2. Previously an unknown number
 *         silently left the globals unchanged, so training ran on empty data.
 */
void initTrainingExample(int example) {
    switch(example) {
        case 0:
            // 3-bit parity: target 0.9 when an odd number of inputs are 1,
            // otherwise 0.1.
            numInputs = 3;
            outputsArray = [2,1];
            trainingInputs = [[0, 0, 0],
                              [0, 0, 1],
                              [0, 1, 0],
                              [0, 1, 1],
                              [1, 0, 0],
                              [1, 0, 1],
                              [1, 1, 0],
                              [1, 1, 1]];

            trainingOutputs = [[0.1],
                               [0.9],
                               [0.9],
                               [0.1],
                               [0.9],
                               [0.1],
                               [0.1],
                               [0.9]];
            break;

        case 1:
            // XOR: target 0.9 when exactly one input is 1. The targets used to
            // be XNOR (0.9 when the inputs matched), contradicting the name
            // and the parity example above.
            numInputs = 2;
            outputsArray = [2,1];
            trainingInputs = [[0, 0],
                              [1, 0],
                              [0, 1],
                              [1, 1]];

            trainingOutputs = [[0.1],
                               [0.9],
                               [0.9],
                               [0.1]];
            break;

        case 2:
            // Identity: each one-hot-style pattern (0.9 on one position, 0.1
            // elsewhere) must be reproduced at the output.
            numInputs = 8;
            outputsArray = [3,8];
            trainingInputs = [
                [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
                [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1],
                [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1],
                [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1],
                [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1],
                [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1],
                [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]];

            // Outputs share the same arrays as the inputs (no copy is made).
            trainingOutputs = trainingInputs;
            break;

        default:
            throw new Error("example number must be between 0 and 2");
    }
}
3 | 65 |
5
810d58835f86
Added momentum and stochastic training to backprop.
revcompgeek
parents:
4
diff
changeset
|
/**
 * Command-line driver. Trains a Backprop network on one of the built-in
 * examples until the error drops below errorMin or maxIters iterations have
 * run, then prints the final outputs, error and weights.
 *
 * Switches (value switches consume the following argument):
 *   -s, --seed                   seed the RNG with the fixed value 123
 *   -l, --learning-rate <x>      learning rate (default 0.2)
 *   -m, --momentum <x>           momentum (default 0.3)
 *   -r, --random-size <x>        randomSize passed to Backprop (default 0.1)
 *   -e, --error-min <x>          stop once error < x (default 0.05)
 *   -n, --example-number <0..2>  select example by number
 *   -x, --example <name>         select example: parity, xor or identity
 *   -q, --quiet                  don't print per-example outputs at each report
 *   -p, --print-every <n>        report every n iterations (default 500, > 0)
 *   -i, --max-iters <n>          iteration limit (default 10000); also
 *       --max-iterations         accepted as the older, misnamed
 *                                --min-iters / --min-iterations
 *
 * Throws: Error on an unknown switch, a switch missing its value, or an
 *         out-of-range example number / print interval.
 */
void main(char[][] args) {
    double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
    int trainingExample = 0, maxIters = 10000; // example is 0 to 2
    bool quiet = false; // don't print output each time
    int printEvery = 500; // output every ~ times

    int argIdx = 1;

    // Returns the value following the current switch and advances argIdx past
    // it. Replaces the unchecked args[++i], which raised a bare
    // ArrayBoundsError when a value switch was the last argument.
    char[] nextArg() {
        if(argIdx + 1 >= args.length)
            throw new Error("Missing value for switch: " ~ args[argIdx]);
        return args[++argIdx];
    }

    for(; argIdx < args.length; argIdx++) {
        switch(args[argIdx]) {
            case "-s":
            case "--seed":
                rand_seed(123, 0);
                break;
            case "-l":
            case "--learning-rate":
                learningRate = toDouble(nextArg());
                break;
            case "-m":
            case "--momentum":
                momentum = toDouble(nextArg());
                break;
            case "-r":
            case "--random-size":
                randomSize = toDouble(nextArg());
                break;
            case "-e":
            case "--error-min":
                errorMin = toDouble(nextArg());
                break;
            case "-n":
            case "--example-number":
                trainingExample = toInt(nextArg());
                if(trainingExample > 2 || trainingExample < 0)
                    throw new Error("example number must be between 0 and 2");
                // This break was missing: control fell through into "-x" and
                // consumed the next argument as an example name.
                break;
            case "-x":
            case "--example":
                switch(nextArg()) {
                    case "parity":
                        trainingExample = 0;
                        break;
                    case "xor":
                        trainingExample = 1;
                        break;
                    case "identity":
                        trainingExample = 2;
                        break;
                    default:
                        throw new Error("Wrong example name. Must be parity, xor or identity");
                }
                break;
            case "-q":
            case "--quiet":
                quiet = true;
                break;
            case "-p":
            case "--print-every":
                printEvery = toInt(nextArg());
                // Used as a modulus in the training loop; 0 would divide by zero.
                if(printEvery <= 0)
                    throw new Error("print-every must be greater than 0");
                break;
            case "-i":
            case "--max-iters":
            case "--max-iterations":
            case "--min-iters":      // legacy misnomer, kept for compatibility
            case "--min-iterations": // legacy misnomer, kept for compatibility
                maxIters = toInt(nextArg());
                break;
            default:
                throw new Error("Unknown switch: " ~ args[argIdx]);
        }
    }

    initTrainingExample(trainingExample);

    // One sigmoid per entry of outputsArray (every example defines two).
    // NOTE(review): the meaning of the trailing `true` is not visible here;
    // confirm against the Backprop constructor.
    Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid], learningRate, momentum, randomSize, true);

    double error = nn.calculateError(trainingInputs, trainingOutputs);
    double[] output;
    int iter = 0;
    //writef("weights=");
    //printArray(nn.getWeights());
    //writef("outputs=");
    //printArray(nn.evaluateFull(trainingInputs[$ - 1]));
    while (error >= errorMin && iter < maxIters) {
        // Periodic progress report.
        if(iter % printEvery == 0) {
            writefln("Iter: %d", iter);
            if(!quiet) {
                for(int i = 0; i < trainingInputs.length; i++) {
                    output = nn.evaluate(trainingInputs[i]);
                    writef(" %d:", i);
                    printArray(output);
                }
            }
            writefln(" Error: %f", error);
        }
        // NOTE(review): the `true` flag presumably enables the stochastic
        // training mode added in this revision; confirm in Backprop.train.
        nn.train(trainingInputs, trainingOutputs, true);
        error = nn.calculateError(trainingInputs, trainingOutputs);
        iter++;
    }

    writefln("Total Iters: %d", iter);
    for(int i = 0; i < trainingInputs.length; i++) {
        writef(" %d:", i);
        // For the identity example, print element [0] of evaluateFull instead
        // of the final output; presumably the first-layer (3-unit) activations.
        // TODO confirm.
        if(trainingExample == 2)
            printArray(nn.evaluateFull(trainingInputs[i])[0]);
        else
            printArray(nn.evaluate(trainingInputs[i]));
    }
    writefln(" Error: %f", error);
    writef("weights=");
    printArray(nn.getWeights());
}