Mercurial > projects > aid
changeset 6:ff92c77006c7
Added support for reading training examples from files.
author | revcompgeek |
---|---|
date | Tue, 06 May 2008 21:43:55 -0600 |
parents | 810d58835f86 |
children | b9fe92a2d8ad |
files | trunk/aid/astar.d trunk/aid/nn/multilayer/backprop.d trunk/aid/nn/outputFunctions.d trunk/aid/nn/perceptron.d trunk/backprop_test.d trunk/dsss.conf trunk/trainingexamples/binaryadd4.trnex trunk/trainingexamples/identity.trnex trunk/trainingexamples/parody2.trnex trunk/trainingexamples/parody3.trnex |
diffstat | 10 files changed, 216 insertions(+), 27 deletions(-) [+] |
line wrap: on
line diff
--- a/trunk/aid/astar.d Tue Apr 15 14:39:49 2008 -0600 +++ b/trunk/aid/astar.d Tue May 06 21:43:55 2008 -0600 @@ -7,6 +7,7 @@ import mintl.arrayheap; import mintl.arraylist; +import xenon.font; class Node(DATA) { int xloc;
--- a/trunk/aid/nn/multilayer/backprop.d Tue Apr 15 14:39:49 2008 -0600 +++ b/trunk/aid/nn/multilayer/backprop.d Tue May 06 21:43:55 2008 -0600 @@ -1,3 +1,7 @@ +/** + * backprop.d + * Holds the backpropagation neural network. + */ module aid.nn.multilevel.backprop;
--- a/trunk/aid/nn/outputFunctions.d Tue Apr 15 14:39:49 2008 -0600 +++ b/trunk/aid/nn/outputFunctions.d Tue May 06 21:43:55 2008 -0600 @@ -1,3 +1,8 @@ +/** + * outputFunctions.d + * Holds all of the output functions used by the neural networks. + */ + module aid.nn.outputFunctions; import aid.nn.outputFunctions;
--- a/trunk/aid/nn/perceptron.d Tue Apr 15 14:39:49 2008 -0600 +++ b/trunk/aid/nn/perceptron.d Tue May 06 21:43:55 2008 -0600 @@ -1,3 +1,8 @@ +/** + * perceptron.d + * Holds the simple perceptron class. + */ + module aid.nn.perceptron; import aid.nn.outputFunctions;
--- a/trunk/backprop_test.d Tue Apr 15 14:39:49 2008 -0600 +++ b/trunk/backprop_test.d Tue May 06 21:43:55 2008 -0600 @@ -1,19 +1,110 @@ - module backprop_test; - import aid.nn.multilayer.backprop; import aid.nn.outputFunctions; import aid.misc; import std.stdio; import std.random; import std.conv; +import std.string; +import std.stream; -double[][] trainingInputs, trainingOutputs; -uint numInputs; +//double[][] trainingInputs, trainingOutputs; +//uint numInputs; uint[] outputsArray; -void initTrainingExample(int example) { +struct TrainingExample { + char[] name; + uint numInputs,numOutputs; + uint[] numHidden; + double[][] inputs,outputs; +} + +/// Reads the training example from a file +TrainingExample readTrainingExample(InputStream f){ + TrainingExample e; + char[][] temp,trnInputs,trnOutputs; + char[] line; + bool gotInputs,gotOutputs,gotHidden; + + int i; + + line = f.readLine(); + temp = split(line); + if(tolower(temp[0]) != "name:") throw new Error("Expecting \"Name:\" on first line."); + e.name = join(temp[1..$]," "); + + bool stop = false; + while(!stop){ + if(f.eof) throw new Error("Unexpected end of file."); + line = strip(tolower(f.readLine())); + temp = split(line); + if(temp.length == 0) continue; + switch(temp[0]){ + case "inputs:": + if(gotInputs) throw new Error("Only one inputs line allowed."); + e.numInputs = toInt(temp[1]); + if(e.numInputs < 1) throw new Error("Inputs cannot be less than 1."); + gotInputs = true; + break; + case "outputs:": + if(gotOutputs) throw new Error("Only one outputs line allowed."); + e.numOutputs = toInt(temp[1]); + if(e.numOutputs < 1) throw new Error("Outputs cannot be less than 1."); + gotOutputs = true; + break; + case "hidden:": + if(gotHidden) throw new Error("Only one hidden line allowed."); + if(temp.length == 1) break; + temp = split(join(temp[1..$],""),","); + e.numHidden.length = temp.length; + for(i = 0; i<temp.length; i++){ + e.numHidden[i] = toInt(temp[i]); + if(e.numHidden[i] < 1) throw new Error("Hidden layers cannot be less than 1."); + } + gotHidden = true; + break; + case "data:": + stop = true; + break; + default: + throw new Error(temp[0] ~ " not recognized."); + } + } + + if(!gotInputs) throw new Error("Missing inputs line."); + if(!gotOutputs) throw new Error("Missing outputs line."); + //if(!gotHidden) throw new Error("Missing hidden line."); // Hidden line not required + + //Data reading + e.inputs.length = 10; // This isn't as redundant as it looks. + e.inputs.length = 0; + e.outputs.length = 10; + e.outputs.length = 0; + while(!f.eof){ + line = join(split(strip(f.readLine())),""); // remove whitespace + temp = split(line,";"); + if(line.length == 0) continue; // blank line + if(temp.length != 2) throw new Error("Wrong number of semicolons in data."); + + trnInputs = split(temp[0],","); + trnOutputs = split(temp[1],","); + if(trnInputs.length != e.numInputs) throw new Error("Expecting " ~ toString(e.numInputs) ~ " inputs, not " ~ toString(trnInputs.length)); + if(trnOutputs.length != e.numOutputs) throw new Error("Expecting " ~ toString(e.numOutputs) ~ " outputs, not " ~ toString(trnOutputs.length)); + e.inputs.length = e.inputs.length + 1; + e.inputs[$-1].length = e.numInputs; + e.outputs.length = e.outputs.length + 1; + e.outputs[$-1].length = e.numOutputs; + for(i = 0; i<e.numInputs; i++) + e.inputs[$-1][i] = toDouble(trnInputs[i]); + for(i = 0; i<e.numOutputs; i++) + e.outputs[$-1][i] = toDouble(trnOutputs[i]); + } + + return e; +} + +/*void initTrainingExample(int example) { if(example == 0) { numInputs = 3; outputsArray = [2,1]; @@ -61,16 +152,18 @@ trainingOutputs = trainingInputs; } -} +}*/ void main(char[][] args) { double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05; - int trainingExample = 0, maxIters = 10000; // 0 to 2 - bool quiet = false; // don't print output each time + int /*trainingExample = 0,*/ maxIters = 10000; // 0 to 2 + bool quiet,printFull,printWeights; int printEvery = 500; // output every ~ times + TrainingExample example; - //try { - for(int i = 1; i < args.length; i++) { + char[] filename = args[1]; + example = readTrainingExample(new File(filename)); + for(int i = 2; i < args.length; i++) { switch(args[i]) { case "-s": case "--seed": @@ -78,8 +171,6 @@ break; case "-l": case "--learning-rate": - //if(args.length = i + 1) - //throw new Error("Wrong number of paramaters"); learningRate = toDouble(args[++i]); break; case "-m": @@ -94,7 +185,7 @@ case "--error-min": errorMin = toDouble(args[++i]); break; - case "-n": + /*case "-n": case "--example-number": trainingExample = toInt(args[++i]); if(trainingExample > 2 || trainingExample < 0) @@ -114,7 +205,7 @@ default: throw new Error("Wrong example name. Must be parity, xor or identity"); } - break; + break;*/ case "-q": case "--quiet": quiet = true; @@ -128,6 +219,15 @@ case "--min-iterations": maxIters = toInt(args[++i]); break; + case "-f": + case "--full": + case "--print full": + printFull = true; + break; + case "-w": + case "--print-weights": + printWeights = true; + break; default: throw new Error("Unknown switch: " ~ args[i]); } @@ -136,11 +236,19 @@ // throw new Error("Wrong number of paramaters"); //} - initTrainingExample(trainingExample); + //initTrainingExample(trainingExample); + + writefln("Starting training with: " ~ example.name); - Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid], learningRate, momentum, randomSize, true); + OutputFunctionPtr[] functions; + functions.length = example.numHidden.length + 1; + for(int i = 0; i<functions.length; i++){ + functions[i] = &sigmoid; + } - double error = nn.calculateError(trainingInputs, trainingOutputs); + Backprop nn = new Backprop(example.numInputs, example.numHidden ~ example.numOutputs, functions, learningRate, momentum, randomSize, true); + + double error = nn.calculateError(example.inputs, example.outputs); double[] output; int iter = 0; //writef("weights="); @@ -151,27 +259,29 @@ if(iter % printEvery == 0) { writefln("Iter: %d", iter); if(!quiet) { - for(int i = 0; i < trainingInputs.length; i++) { - output = nn.evaluate(trainingInputs[i]); + for(int i = 0; i < example.inputs.length; i++) { + output = nn.evaluate(example.inputs[i]); writef(" %d:", i); printArray(output); } } writefln(" Error: %f", error); } - nn.train(trainingInputs, trainingOutputs, true); - error = nn.calculateError(trainingInputs, trainingOutputs); + nn.train(example.inputs, example.outputs, true); + error = nn.calculateError(example.inputs, example.outputs); iter++; } writefln("Total Iters: %d", iter); - for(int i = 0; i < trainingInputs.length; i++) { + for(int i = 0; i < example.inputs.length; i++) { writef(" %d:", i); - if(trainingExample == 2) - printArray(nn.evaluateFull(trainingInputs[i])[0]); + if(printFull) + printArray(nn.evaluateFull(example.inputs[i])[0]); else - printArray(nn.evaluate(trainingInputs[i])); + printArray(nn.evaluate(example.inputs[i])); } writefln(" Error: %f", error); - writef("weights="); - printArray(nn.getWeights()); + if(printWeights){ + writef("weights="); + printArray(nn.getWeights()); + } }
--- a/trunk/dsss.conf Tue Apr 15 14:39:49 2008 -0600 +++ b/trunk/dsss.conf Tue May 06 21:43:55 2008 -0600 @@ -3,7 +3,12 @@ exclude=aid/maze/ exclude+=aid/containers/ [ga_code.d] +noinstall [ga_maze.d] +noinstall [mazegen.d] +noinstall [perceptron_test.d] +noinstall [backprop_test.d] +noinstall
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/trunk/trainingexamples/binaryadd4.trnex Tue May 06 21:43:55 2008 -0600 @@ -0,0 +1,24 @@ +Name: Binary Addition 2,2 +Inputs: 4 +Outputs: 3 +Hidden: 8 +Data: +0,0, 0,0; 0.1,0.1,0.1 +0,1, 0,0; 0.1,0.1,0.9 +1,0, 0,0; 0.1,0.9,0.1 +1,1, 0,0; 0.1,0.9,0.9 + +0,0, 0,1; 0.1,0.1,0.9 +0,1, 0,1; 0.1,0.9,0.1 +1,0, 0,1; 0.1,0.9,0.9 +1,1, 0,1; 0.9,0.1,0.1 + +0,0, 1,0; 0.1,0.9,0.1 +0,1, 1,0; 0.1,0.9,0.9 +1,0, 1,0; 0.9,0.1,0.9 +1,1, 1,0; 0.9,0.9,0.1 + +0,0, 1,1; 0.1,0.9,0.9 +0,1, 1,1; 0.9,0.1,0.1 +1,0, 1,1; 0.9,0.9,0.1 +1,1, 1,1; 0.9,0.9,0.9 \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/trunk/trainingexamples/identity.trnex Tue May 06 21:43:55 2008 -0600 @@ -0,0 +1,13 @@ +Name: Identity-Hidden representation +Inputs: 8 +Outputs: 8 +Hidden: 3 +Data: +0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1; 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1 +0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1; 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1 +0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1; 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1 +0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1; 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1 +0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1; 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1 +0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1; 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1 +0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1; 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1 +0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9; 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9 \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/trunk/trainingexamples/parody2.trnex Tue May 06 21:43:55 2008 -0600 @@ -0,0 +1,9 @@ +Name: Parody 2 +Inputs: 2 +Outputs: 1 +Hidden: 2 +Data: +0, 0; 0.1 +0, 1; 0.9 +1, 0; 0.9 +1, 1; 0.1 \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/trunk/trainingexamples/parody3.trnex Tue May 06 21:43:55 2008 -0600 @@ -0,0 +1,13 @@ +Name: Parody 3 +Inputs: 3 +Outputs: 1 +Hidden: 2 +Data: +0, 0, 0; 0.1 +0, 0, 1; 0.9 +0, 1, 0; 0.9 +0, 1, 1; 0.1 +1, 0, 0; 0.9 +1, 0, 1; 0.1 +1, 1, 0; 0.1 +1, 1, 1; 0.9 \ No newline at end of file