Mercurial > projects > aid
view trunk/backprop_test.d @ 6:ff92c77006c7
Added support for reading training examples from files.
author | revcompgeek |
---|---|
date | Tue, 06 May 2008 21:43:55 -0600 |
parents | 810d58835f86 |
children | b9fe92a2d8ad |
line wrap: on
line source
module backprop_test;

import aid.nn.multilayer.backprop;
import aid.nn.outputFunctions;
import aid.misc;
import std.stdio;
import std.random;
import std.conv;
import std.string;
import std.stream;

/// A training set read from a file, plus the network shape it describes.
struct TrainingExample {
    char[] name;                /// Free-text name from the "Name:" line.
    uint numInputs, numOutputs; /// Values per input row / per output row.
    uint[] numHidden;           /// Size of each hidden layer; empty if none given.
    double[][] inputs, outputs; /// Row k of inputs is paired with row k of outputs.
}

/// Parses a count that must be at least 1.
/// The text is parsed as a signed int first: assigning toInt's result
/// straight to a uint would turn "-1" into a huge positive number that a
/// "< 1" check can never catch.
private uint parseCount(char[] text, char[] errorMessage) {
    int value = toInt(text);
    if (value < 1) throw new Error(errorMessage);
    return cast(uint) value;
}

/// Returns the token after a header keyword (e.g. "3" in "inputs: 3"),
/// failing with a readable message instead of an array-bounds error.
private char[] headerValue(char[][] tokens) {
    if (tokens.length < 2)
        throw new Error("Missing value after \"" ~ tokens[0] ~ "\".");
    return tokens[1];
}

/// Converts a row of numeric tokens into doubles.
private double[] toDoubles(char[][] tokens) {
    double[] row = new double[tokens.length];
    foreach (k, token; tokens)
        row[k] = toDouble(token);
    return row;
}

/// Reads a training example description from a stream.
///
/// Expected layout (keywords are case-insensitive and must be followed by
/// whitespace before their value):
/// ---
/// Name: <free text>
/// Inputs: <count>
/// Outputs: <count>
/// Hidden: <n1>, <n2>, ...     (optional; omitted means no hidden layers)
/// Data:
/// <in1>, <in2>, ... ; <out1>, <out2>, ...
/// ...
/// ---
/// Blank lines are ignored; whitespace inside data lines is ignored.
///
/// Params:
///   f = stream positioned at the first line of the description.
/// Returns: the parsed example.
/// Throws: Error on malformed structure (missing/duplicate header lines,
///   counts below 1, wrong number of values, no data rows); conversion
///   errors from std.conv for non-numeric values.
TrainingExample readTrainingExample(InputStream f) {
    TrainingExample e;
    char[] line;
    char[][] temp, trnInputs, trnOutputs;
    bool gotInputs, gotOutputs, gotHidden;

    // The first line must be "Name: <text>".
    line = f.readLine();
    temp = split(line);
    if (temp.length == 0 || tolower(temp[0]) != "name:")
        throw new Error("Expecting \"Name:\" on first line.");
    e.name = join(temp[1 .. $], " ");

    // Header keyword lines, up to and including "data:".
    bool stop = false;
    while (!stop) {
        if (f.eof) throw new Error("Unexpected end of file.");
        line = strip(tolower(f.readLine()));
        temp = split(line);
        if (temp.length == 0) continue;

        switch (temp[0]) {
            case "inputs:":
                if (gotInputs) throw new Error("Only one inputs line allowed.");
                e.numInputs = parseCount(headerValue(temp), "Inputs cannot be less than 1.");
                gotInputs = true;
                break;
            case "outputs:":
                if (gotOutputs) throw new Error("Only one outputs line allowed.");
                e.numOutputs = parseCount(headerValue(temp), "Outputs cannot be less than 1.");
                gotOutputs = true;
                break;
            case "hidden:":
                if (gotHidden) throw new Error("Only one hidden line allowed.");
                gotHidden = true; // an empty "hidden:" line still counts as the one allowed
                if (temp.length == 1) break; // no hidden layers listed
                // "hidden: 3, 2" and "hidden: 3,2" both become ["3", "2"].
                temp = split(join(temp[1 .. $], ""), ",");
                e.numHidden.length = temp.length;
                foreach (k, token; temp)
                    e.numHidden[k] = parseCount(token, "Hidden layers cannot be less than 1.");
                break;
            case "data:":
                stop = true;
                break;
            default:
                throw new Error(temp[0] ~ " not recognized.");
        }
    }
    if (!gotInputs) throw new Error("Missing inputs line.");
    if (!gotOutputs) throw new Error("Missing outputs line.");
    // A hidden line is optional: without one the network has no hidden layers.

    // Data lines: "<inputs separated by ','>;<outputs separated by ','>".
    while (!f.eof) {
        line = join(split(strip(f.readLine())), ""); // remove all whitespace
        if (line.length == 0) continue;              // blank line
        temp = split(line, ";");
        if (temp.length != 2) throw new Error("Wrong number of semicolons in data.");

        trnInputs = split(temp[0], ",");
        trnOutputs = split(temp[1], ",");
        if (trnInputs.length != e.numInputs)
            throw new Error("Expecting " ~ toString(e.numInputs) ~ " inputs, not " ~ toString(trnInputs.length));
        if (trnOutputs.length != e.numOutputs)
            throw new Error("Expecting " ~ toString(e.numOutputs) ~ " outputs, not " ~ toString(trnOutputs.length));

        e.inputs ~= toDoubles(trnInputs);
        e.outputs ~= toDoubles(trnOutputs);
    }
    if (e.inputs.length == 0) throw new Error("No training data found.");
    return e;
}

/+ Legacy hard-coded examples, superseded by readTrainingExample().
   Kept for reference; each can be rewritten as a training file.

double[][] trainingInputs, trainingOutputs;
uint numInputs;
uint[] outputsArray;

void initTrainingExample(int example) {
    if (example == 0) { // parity
        numInputs = 3;
        outputsArray = [2, 1];
        trainingInputs = [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                          [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]];
        trainingOutputs = [[0.1], [0.9], [0.9], [0.1], [0.9], [0.1], [0.1], [0.9]];
    } else if (example == 1) { // xor
        numInputs = 2;
        outputsArray = [2, 1];
        trainingInputs = [[0, 0], [1, 0], [0, 1], [1, 1]];
        trainingOutputs = [[0.9], [0.1], [0.1], [0.9]];
    } else if (example == 2) { // identity
        numInputs = 8;
        outputsArray = [3, 8];
        trainingInputs = [
            [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
            [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
            [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1],
            [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1],
            [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1],
            [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1],
            [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1],
            [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]];
        trainingOutputs = trainingInputs;
    }
}
+/

/// Trains a backprop network on the examples in a training file and
/// reports the result.
///
/// Usage: backprop_test <training file> [options]
///
/// Options:
///   -s, --seed                   call rand_seed(123, 0) before the network is built
///   -l, --learning-rate <x>      learning rate (default 0.2)
///   -m, --momentum <x>           momentum (default 0.3)
///   -r, --random-size <x>        randomSize passed to Backprop (default 0.1)
///   -e, --error-min <x>          stop once the error is below this (default 0.05)
///   -q, --quiet                  omit per-example outputs from progress reports
///   -p, --print-every <n>        progress interval in iterations, >= 1 (default 500)
///   -i, --max-iters <n>          iteration limit (default 10000); also accepted as
///                                --max-iterations and the older --min-iters /
///                                --min-iterations
///   -f, --full, --print-full     final report prints evaluateFull(...)[0] instead of
///                                evaluate(...); the old "--print full" label still works
///   -w, --print-weights          print the network weights at the end
void main(char[][] args) {
    double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
    int maxIters = 10000;
    bool quiet, printFull, printWeights;
    int printEvery = 500; // print progress every printEvery iterations

    if (args.length < 2)
        throw new Error("Usage: backprop_test <training file> [options]");

    TrainingExample example;
    {
        File input = new File(args[1]);
        scope(exit) input.close(); // release the handle even if parsing throws
        example = readTrainingExample(input);
    }

    int argIndex = 2;
    // Consumes and returns the value that follows the current switch.
    char[] optionValue() {
        if (argIndex + 1 >= args.length)
            throw new Error("Missing value for switch: " ~ args[argIndex]);
        return args[++argIndex];
    }

    for (; argIndex < args.length; argIndex++) {
        switch (args[argIndex]) {
            case "-s", "--seed":
                rand_seed(123, 0); // fixed seed so runs can be repeated
                break;
            case "-l", "--learning-rate":
                learningRate = toDouble(optionValue());
                break;
            case "-m", "--momentum":
                momentum = toDouble(optionValue());
                break;
            case "-r", "--random-size":
                randomSize = toDouble(optionValue());
                break;
            case "-e", "--error-min":
                errorMin = toDouble(optionValue());
                break;
            case "-q", "--quiet":
                quiet = true;
                break;
            case "-p", "--print-every":
                printEvery = toInt(optionValue());
                break;
            // "--min-*" names are kept for compatibility; the value is a maximum.
            case "-i", "--max-iters", "--max-iterations", "--min-iters", "--min-iterations":
                maxIters = toInt(optionValue());
                break;
            case "-f", "--full", "--print-full", "--print full":
                printFull = true;
                break;
            case "-w", "--print-weights":
                printWeights = true;
                break;
            default:
                throw new Error("Unknown switch: " ~ args[argIndex]);
        }
    }

    // iter % printEvery below would divide by zero.
    if (printEvery < 1) throw new Error("Print interval must be at least 1.");

    // The name comes from the file, so pass it as an argument rather than
    // as part of the format string (a '%' in it would break writefln).
    writefln("Starting training with: %s", example.name);

    // One output function per hidden layer plus one for the output layer.
    OutputFunctionPtr[] functions;
    functions.length = example.numHidden.length + 1;
    for (int i = 0; i < functions.length; i++)
        functions[i] = &sigmoid;

    Backprop nn = new Backprop(example.numInputs, example.numHidden ~ example.numOutputs,
                               functions, learningRate, momentum, randomSize, true);

    double error = nn.calculateError(example.inputs, example.outputs);
    double[] output;
    int iter = 0;
    while (error >= errorMin && iter < maxIters) {
        if (iter % printEvery == 0) {
            writefln("Iter: %d", iter);
            if (!quiet) {
                for (int i = 0; i < example.inputs.length; i++) {
                    output = nn.evaluate(example.inputs[i]);
                    writef("  %d:", i);
                    printArray(output);
                }
            }
            writefln("  Error: %f", error);
        }
        nn.train(example.inputs, example.outputs, true);
        error = nn.calculateError(example.inputs, example.outputs);
        iter++;
    }

    writefln("Total Iters: %d", iter);
    for (int i = 0; i < example.inputs.length; i++) {
        writef("  %d:", i);
        if (printFull)
            printArray(nn.evaluateFull(example.inputs[i])[0]);
        else
            printArray(nn.evaluate(example.inputs[i]));
    }
    writefln("  Error: %f", error);

    if (printWeights) {
        writef("weights=");
        printArray(nn.getWeights());
    }
}