Mercurial > projects > aid
view trunk/backprop_test.d @ 7:b9fe92a2d8ad default tip
Removed old code.
author | revcompgeek |
---|---|
date | Tue, 06 May 2008 22:20:26 -0600 |
parents | ff92c77006c7 |
children |
line wrap: on
line source
module backprop_test;

import aid.nn.multilayer.backprop;
import aid.nn.outputFunctions;
import aid.misc;
import std.stdio;
import std.random;
import std.conv;
import std.string;
import std.stream;

uint[] outputsArray;

/// A named training set together with the network layout requested by its file.
struct TrainingExample {
    char[] name;                // text following "Name:" on the first line
    uint numInputs,numOutputs;  // number of values on each side of the ';' in a data row
    uint[] numHidden;           // one entry per hidden layer; empty when no hidden line was given
    double[][] inputs,outputs;  // one row per data line; rows with the same index belong together
}

/// Parses the count that follows a header keyword such as "inputs:".
///
/// fields   = the whitespace-separated tokens of the header line
/// keyword  = the keyword being parsed, used in the "missing value" message
/// tooSmall = message thrown when the parsed value is below 1
///
/// The value is checked while still signed, so a negative number is rejected
/// instead of wrapping around when stored in a uint.
private uint readCount(char[][] fields, char[] keyword, char[] tooSmall){
    if(fields.length < 2)
        throw new Error("Missing value after " ~ keyword);
    int count = toInt(fields[1]);
    if(count < 1)
        throw new Error(tooSmall);
    return cast(uint) count;
}

/// Returns args[index], or throws an Error naming the switch at args[index - 1]
/// when the command line ends before the value.
private char[] optionValue(char[][] args, int index){
    if(index >= args.length)
        throw new Error("Missing value for switch: " ~ args[index - 1]);
    return args[index];
}

/// Reads the training example from a file
///
/// Expected layout:
///   Name: <any text>
///   Inputs: <n>            (required, once)
///   Outputs: <n>           (required, once)
///   Hidden: <n>[,<n>...]   (optional)
///   Data:
///   <in>,<in>,... ; <out>,<out>,...    (one sample per line, blank lines ignored)
///
/// Header keywords are matched case-insensitively. Throws an Error describing
/// the problem when the file does not follow this layout.
TrainingExample readTrainingExample(InputStream f){
    TrainingExample e;
    char[][] temp,trnInputs,trnOutputs;
    char[] line;
    bool gotInputs,gotOutputs,gotHidden;
    int i;

    line = f.readLine();
    temp = split(line);
    // An empty first line has no temp[0]; report it as a missing name line.
    if(temp.length == 0 || tolower(temp[0]) != "name:")
        throw new Error("Expecting \"Name:\" on first line.");
    e.name = join(temp[1..$]," ");

    // Header section: runs until the "data:" line is seen.
    bool stop = false;
    while(!stop){
        if(f.eof)
            throw new Error("Unexpected end of file.");
        line = strip(tolower(f.readLine()));
        temp = split(line);
        if(temp.length == 0)
            continue; // blank line

        switch(temp[0]){
            case "inputs:":
                if(gotInputs)
                    throw new Error("Only one inputs line allowed.");
                e.numInputs = readCount(temp, "inputs:", "Inputs cannot be less than 1.");
                gotInputs = true;
                break;
            case "outputs:":
                if(gotOutputs)
                    throw new Error("Only one outputs line allowed.");
                e.numOutputs = readCount(temp, "outputs:", "Outputs cannot be less than 1.");
                gotOutputs = true;
                break;
            case "hidden:":
                if(gotHidden)
                    throw new Error("Only one hidden line allowed.");
                if(temp.length == 1)
                    break; // keyword with no sizes: no hidden layers
                // Re-join the remaining tokens so "3, 4" and "3,4" parse alike.
                temp = split(join(temp[1..$],""),",");
                e.numHidden.length = temp.length;
                for(i = 0; i<temp.length; i++){
                    // Validate as a signed value before storing into the uint array.
                    int layerSize = toInt(temp[i]);
                    if(layerSize < 1)
                        throw new Error("Hidden layers cannot be less than 1.");
                    e.numHidden[i] = cast(uint) layerSize;
                }
                gotHidden = true;
                break;
            case "data:":
                stop = true;
                break;
            default:
                throw new Error(temp[0] ~ " not recognized.");
        }
    }

    if(!gotInputs)
        throw new Error("Missing inputs line.");
    if(!gotOutputs)
        throw new Error("Missing outputs line.");
    // A hidden line is optional, so its absence is not an error.

    // Data reading
    // Growing to 10 and shrinking back to 0 is deliberate (the original author
    // notes it "isn't as redundant as it looks"); presumably it reserves
    // capacity for the appends below - TODO confirm.
    e.inputs.length = 10;
    e.inputs.length = 0;
    e.outputs.length = 10;
    e.outputs.length = 0;

    while(!f.eof){
        line = join(split(strip(f.readLine())),""); // remove whitespace
        if(line.length == 0)
            continue; // blank line
        temp = split(line,";");
        if(temp.length != 2)
            throw new Error("Wrong number of semicolons in data.");
        trnInputs = split(temp[0],",");
        trnOutputs = split(temp[1],",");
        if(trnInputs.length != e.numInputs)
            throw new Error("Expecting " ~ toString(e.numInputs) ~ " inputs, not " ~ toString(trnInputs.length));
        if(trnOutputs.length != e.numOutputs)
            throw new Error("Expecting " ~ toString(e.numOutputs) ~ " outputs, not " ~ toString(trnOutputs.length));

        // Append one row to each table, then fill it in.
        e.inputs.length = e.inputs.length + 1;
        e.inputs[$-1].length = e.numInputs;
        e.outputs.length = e.outputs.length + 1;
        e.outputs[$-1].length = e.numOutputs;
        for(i = 0; i<e.numInputs; i++)
            e.inputs[$-1][i] = toDouble(trnInputs[i]);
        for(i = 0; i<e.numOutputs; i++)
            e.outputs[$-1][i] = toDouble(trnOutputs[i]);
    }
    return e;
}

/// Trains a Backprop network on the training file named by args[1] and prints
/// its progress.
///
/// Switches (all optional, given after the file name):
///   -s, --seed                     seed the random generator with a fixed value
///   -l, --learning-rate <x>        default 0.2
///   -m, --momentum <x>             default 0.3
///   -r, --random-size <x>          default 0.1
///   -e, --error-min <x>            stop once the error drops below this; default 0.05
///   -q, --quiet                    do not print per-sample outputs in progress reports
///   -p, --print-every <n>          iterations between progress reports; default 500
///   -i, --max-iters <n>            iteration limit; default 10000
///                                  (--max-iterations, --min-iters and
///                                  --min-iterations are accepted as synonyms)
///   -f, --full, --print-full       print the result of evaluateFull in the final report
///   -w, --print-weights            print the network weights at the end
void main(char[][] args)
{
    double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
    int maxIters = 10000;
    bool quiet,printFull,printWeights;
    int printEvery = 500; // output every ~ times
    TrainingExample example;

    if(args.length < 2)
        throw new Error("Usage: backprop_test <training file> [switches]");

    char[] filename = args[1];
    File source = new File(filename);
    try {
        example = readTrainingExample(source);
    } finally {
        // Release the file handle whether or not parsing succeeded.
        source.close();
    }

    for(int i = 2; i < args.length; i++) {
        switch(args[i]) {
            case "-s":
            case "--seed":
                rand_seed(123, 0);
                break;
            case "-l":
            case "--learning-rate":
                learningRate = toDouble(optionValue(args, ++i));
                break;
            case "-m":
            case "--momentum":
                momentum = toDouble(optionValue(args, ++i));
                break;
            case "-r":
            case "--random-size":
                randomSize = toDouble(optionValue(args, ++i));
                break;
            case "-e":
            case "--error-min":
                errorMin = toDouble(optionValue(args, ++i));
                break;
            case "-q":
            case "--quiet":
                quiet = true;
                break;
            case "-p":
            case "--print-every":
                printEvery = toInt(optionValue(args, ++i));
                break;
            case "-i":
            case "--max-iters":
            case "--max-iterations":
            case "--min-iters":      // historical spelling, kept for compatibility
            case "--min-iterations": // historical spelling, kept for compatibility
                maxIters = toInt(optionValue(args, ++i));
                break;
            case "-f":
            case "--full":
            case "--print-full":
            case "--print full":     // historical spelling, kept for compatibility
                printFull = true;
                break;
            case "-w":
            case "--print-weights":
                printWeights = true;
                break;
            default:
                throw new Error("Unknown switch: " ~ args[i]);
        }
    }

    // printEvery is used as a modulus below; zero would divide by zero.
    if(printEvery < 1)
        throw new Error("Print interval cannot be less than 1.");

    // The name comes from the file, so it is passed as an argument rather than
    // used as the format string.
    writefln("Starting training with: %s", example.name);

    // One output function per layer: every hidden layer plus the output layer.
    OutputFunctionPtr[] functions;
    functions.length = example.numHidden.length + 1;
    for(int i = 0; i<functions.length; i++){
        functions[i] = &sigmoid;
    }

    Backprop nn = new Backprop(example.numInputs, example.numHidden ~ example.numOutputs, functions, learningRate, momentum, randomSize, true);

    double error = nn.calculateError(example.inputs, example.outputs);
    double[] output;
    int iter = 0;

    // Train until the error is small enough or the iteration limit is reached.
    while (error >= errorMin && iter < maxIters) {
        if(iter % printEvery == 0) {
            writefln("Iter: %d", iter);
            if(!quiet) {
                for(int i = 0; i < example.inputs.length; i++) {
                    output = nn.evaluate(example.inputs[i]);
                    writef(" %d:", i);
                    printArray(output);
                }
            }
            writefln(" Error: %f", error);
        }
        nn.train(example.inputs, example.outputs, true);
        error = nn.calculateError(example.inputs, example.outputs);
        iter++;
    }

    // Final report.
    writefln("Total Iters: %d", iter);
    for(int i = 0; i < example.inputs.length; i++) {
        writef(" %d:", i);
        if(printFull)
            printArray(nn.evaluateFull(example.inputs[i])[0]);
        else
            printArray(nn.evaluate(example.inputs[i]));
    }
    writefln(" Error: %f", error);
    if(printWeights){
        writef("weights=");
        printArray(nn.getWeights());
    }
}