Mercurial > projects > aid
diff trunk/backprop_test.d @ 5:810d58835f86
Added momentum and stochastic training to backprop.
| author | revcompgeek |
|---|---|
| date | Tue, 15 Apr 2008 14:39:49 -0600 |
| parents | 73beed484455 |
| children | ff92c77006c7 |
line wrap: on
line diff
--- a/trunk/backprop_test.d Sat Apr 12 21:55:37 2008 -0600 +++ b/trunk/backprop_test.d Tue Apr 15 14:39:49 2008 -0600 @@ -1,102 +1,177 @@ + module backprop_test; + import aid.nn.multilayer.backprop; import aid.nn.outputFunctions; +import aid.misc; import std.stdio; import std.random; +import std.conv; -/+double[][] trainingInputs = [ - [0,0,0], - [0,0,1], - [0,1,0], - [0,1,1], - [1,0,0], - [1,0,1], - [1,1,0], - [1,1,1]]; +double[][] trainingInputs, trainingOutputs; +uint numInputs; +uint[] outputsArray; -double[][] trainingOutputs = [ - [0.1], - [0.9], - [0.9], - [0.1], - [0.9], - [0.1], - [0.1], - [0.9]];+/ +void initTrainingExample(int example) { + if(example == 0) { + numInputs = 3; + outputsArray = [2,1]; + trainingInputs = [[0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [0, 1, 1], + [1, 0, 0], + [1, 0, 1], + [1, 1, 0], + [1, 1, 1]]; + + trainingOutputs = [[0.1], + [0.9], + [0.9], + [0.1], + [0.9], + [0.1], + [0.1], + [0.9]]; + } else if(example == 1) { + numInputs = 2; + outputsArray = [2,1]; + trainingInputs = [[0, 0], + [1, 0], + [0, 1], + [1, 1]]; + + trainingOutputs = [[0.9], + [0.1], + [0.1], + [0.9]]; + } else if(example == 2) { + numInputs = 8; + outputsArray = [3,8]; + trainingInputs = [ + [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]]; + + trainingOutputs = trainingInputs; + } +} -/+double[][] trainingInputs = [ - [0,0], - [1,0], - [0,1], - [1,1]]; - -double[][] trainingOutputs = [ - [0.9], - [0.1], - [0.1], - [0.9]];+/ - -double[][] trainingInputs = [ - [0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1], - [0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1], - [0.1,0.1,0.9,0.1,0.1,0.1,0.1,0.1], - [0.1,0.1,0.1,0.9,0.1,0.1,0.1,0.1], - [0.1,0.1,0.1,0.1,0.9,0.1,0.1,0.1], - 
[0.1,0.1,0.1,0.1,0.1,0.9,0.1,0.1], - [0.1,0.1,0.1,0.1,0.1,0.1,0.9,0.1], - [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.9]]; - -void main(){ - //rand_seed(0,0); - Backprop nn = new Backprop(8,[3,8],[&sigmoid,&sigmoid],.1); +void main(char[][] args) { + double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05; + int trainingExample = 0, maxIters = 10000; // 0 to 2 + bool quiet = false; // don't print output each time + int printEvery = 500; // output every ~ times - double error = nn.calculateError(trainingInputs,trainingInputs); + //try { + for(int i = 1; i < args.length; i++) { + switch(args[i]) { + case "-s": + case "--seed": + rand_seed(123, 0); + break; + case "-l": + case "--learning-rate": + //if(args.length = i + 1) + //throw new Error("Wrong number of paramaters"); + learningRate = toDouble(args[++i]); + break; + case "-m": + case "--momentum": + momentum = toDouble(args[++i]); + break; + case "-r": + case "--random-size": + randomSize = toDouble(args[++i]); + break; + case "-e": + case "--error-min": + errorMin = toDouble(args[++i]); + break; + case "-n": + case "--example-number": + trainingExample = toInt(args[++i]); + if(trainingExample > 2 || trainingExample < 0) + throw new Error("example number must be between 0 and 2"); + case "-x": + case "--example": + switch(args[++i]) { + case "parity": + trainingExample = 0; + break; + case "xor": + trainingExample = 1; + break; + case "identity": + trainingExample = 2; + break; + default: + throw new Error("Wrong example name. 
Must be parity, xor or identity"); + } + break; + case "-q": + case "--quiet": + quiet = true; + break; + case "-p": + case "--print-every": + printEvery = toInt(args[++i]); + break; + case "-i": + case "--min-iters": + case "--min-iterations": + maxIters = toInt(args[++i]); + break; + default: + throw new Error("Unknown switch: " ~ args[i]); + } + } + //} catch(ArrayBoundsError) { + // throw new Error("Wrong number of paramaters"); + //} + + initTrainingExample(trainingExample); + + Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid], learningRate, momentum, randomSize, true); + + double error = nn.calculateError(trainingInputs, trainingOutputs); double[] output; int iter = 0; - writef("weights="); printArray(nn.getWeights()); - writef("outputs="); printArray(nn.evaluateFull(trainingInputs[$-1])); - while(error >= 0.01 && iter < 50000){ - if(iter % 500 == 0){ - writefln("Iter: %d",iter); - for(int i=0; i<trainingInputs.length; i++){ - output = nn.evaluate(trainingInputs[i]); - writef(" %d:", i); printArray(output); + //writef("weights="); + //printArray(nn.getWeights()); + //writef("outputs="); + //printArray(nn.evaluateFull(trainingInputs[$ - 1])); + while (error >= errorMin && iter < maxIters) { + if(iter % printEvery == 0) { + writefln("Iter: %d", iter); + if(!quiet) { + for(int i = 0; i < trainingInputs.length; i++) { + output = nn.evaluate(trainingInputs[i]); + writef(" %d:", i); + printArray(output); + } } writefln(" Error: %f", error); } - nn.train(trainingInputs,trainingInputs); - error = nn.calculateError(trainingInputs,trainingInputs); + nn.train(trainingInputs, trainingOutputs, true); + error = nn.calculateError(trainingInputs, trainingOutputs); iter++; } - writefln("Total Iters: %d",iter); - for(int i=0; i<trainingInputs.length; i++){ - writef(" %d:", i); printArray(nn.evaluateFull(trainingInputs[i])[0]); + writefln("Total Iters: %d", iter); + for(int i = 0; i < trainingInputs.length; i++) { + writef(" %d:", i); + 
if(trainingExample == 2) + printArray(nn.evaluateFull(trainingInputs[i])[0]); + else + printArray(nn.evaluate(trainingInputs[i])); } writefln(" Error: %f", error); - writef("weights="); printArray(nn.getWeights()); -} - -void printArray(double[] array){ - writef("["); - for(int i=0; i<array.length-1; i++){ - writef("%f, ",array[i]); - } - writefln("%f]",array[$-1]); + writef("weights="); + printArray(nn.getWeights()); } - -void printArray(double[][] array){ - writef("["); - for(int i=0; i<array.length; i++){ - printArray(array[i]); - } - writefln("]"); -} - -void printArray(double[][][] array){ - writef("["); - for(int i=0; i<array.length; i++){ - printArray(array[i]); - } - writefln("]"); -} \ No newline at end of file