diff trunk/backprop_test.d @ 5:810d58835f86

Added momentum and stochastic training to backprop.
author revcompgeek
date Tue, 15 Apr 2008 14:39:49 -0600
parents 73beed484455
children ff92c77006c7
line wrap: on
line diff
--- a/trunk/backprop_test.d	Sat Apr 12 21:55:37 2008 -0600
+++ b/trunk/backprop_test.d	Tue Apr 15 14:39:49 2008 -0600
@@ -1,102 +1,177 @@
+
 module backprop_test;
 
+
 import aid.nn.multilayer.backprop;
 import aid.nn.outputFunctions;
+import aid.misc;
 import std.stdio;
 import std.random;
+import std.conv;
 
-/+double[][] trainingInputs = [
-	[0,0,0],
-	[0,0,1],
-	[0,1,0],
-	[0,1,1],
-	[1,0,0],
-	[1,0,1],
-	[1,1,0],
-	[1,1,1]];
+double[][] trainingInputs, trainingOutputs; // training patterns and their targets, row i of one pairs with row i of the other; filled in by initTrainingExample
+uint numInputs; // handed to the Backprop constructor as its first argument in main
+uint[] outputsArray; // handed to the Backprop constructor as its second argument in main; presumably units per layer -- TODO confirm against aid.nn.multilayer.backprop
 
-double[][] trainingOutputs = [
-	[0.1],
-	[0.9],
-	[0.9],
-	[0.1],
-	[0.9],
-	[0.1],
-	[0.1],
-	[0.9]];+/
+void initTrainingExample(int example) { // Fills numInputs, outputsArray, trainingInputs and trainingOutputs for example 0, 1 or 2; any other value leaves them untouched.
+	if(example == 0) { // "parity" in main's -x switch: three binary inputs, target 0.9 when an odd number of them are 1, otherwise 0.1
+		numInputs = 3;
+		outputsArray = [2,1];
+		trainingInputs = [[0, 0, 0],
+		                  [0, 0, 1],
+		                  [0, 1, 0],
+		                  [0, 1, 1],
+		                  [1, 0, 0],
+		                  [1, 0, 1],
+		                  [1, 1, 0],
+		                  [1, 1, 1]];
+		
+		trainingOutputs = [[0.1],
+		                   [0.9],
+		                   [0.9],
+		                   [0.1],
+		                   [0.9],
+		                   [0.1],
+		                   [0.1],
+		                   [0.9]];
+	} else if(example == 1) { // "xor" in main's -x switch; NOTE(review): as written the target is 0.9 when both inputs are EQUAL and 0.1 when they differ (XNOR) -- confirm this is intended
+		numInputs = 2;
+		outputsArray = [2,1];
+		trainingInputs = [[0, 0],
+		                  [1, 0],
+		                  [0, 1],
+		                  [1, 1]];
+		
+		trainingOutputs = [[0.9],
+		                   [0.1],
+		                   [0.1],
+		                   [0.9]];
+	} else if(example == 2) { // "identity" in main's -x switch: eight 8-wide patterns, 0.9 in one position and 0.1 elsewhere, used as both input and target
+		numInputs = 8;
+		outputsArray = [3,8];
+		trainingInputs = [
+			[0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+			[0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+			[0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1],
+			[0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1],
+			[0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1],
+			[0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1],
+			[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1],
+			[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]];
+		
+		trainingOutputs = trainingInputs; // array assignment shares the data rather than copying it: both names refer to the same rows
+	}
+}
 
-/+double[][] trainingInputs = [
-	[0,0],
-	[1,0],
-	[0,1],
-	[1,1]];
-
-double[][] trainingOutputs = [
-	[0.9],
-	[0.1],
-	[0.1],
-	[0.9]];+/
-
-double[][] trainingInputs = [
-	[0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1],
-	[0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1],
-	[0.1,0.1,0.9,0.1,0.1,0.1,0.1,0.1],
-	[0.1,0.1,0.1,0.9,0.1,0.1,0.1,0.1],
-	[0.1,0.1,0.1,0.1,0.9,0.1,0.1,0.1],
-	[0.1,0.1,0.1,0.1,0.1,0.9,0.1,0.1],
-	[0.1,0.1,0.1,0.1,0.1,0.1,0.9,0.1],
-	[0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.9]];
-
-void main(){
-	//rand_seed(0,0);
-	Backprop nn = new Backprop(8,[3,8],[&sigmoid,&sigmoid],.1);
+void main(char[][] args) { // Test driver: parses command-line switches, builds a Backprop net for the chosen example, trains until error < errorMin or maxIters passes, then prints outputs, error and weights.
+	double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
+	int trainingExample = 0, maxIters = 10000; // trainingExample must be 0 to 2
+	bool quiet = false; // don't print output each time
+	int printEvery = 500; // output every ~ times
 	
-	double error = nn.calculateError(trainingInputs,trainingInputs);
+	//try {
+	for(int i = 1; i < args.length; i++) {
+		switch(args[i]) {
+			case "-s":
+			case "--seed":
+				rand_seed(123, 0); // seed std.random with fixed constants
+				break;
+			case "-l":
+			case "--learning-rate":
+				// NOTE: nothing checks that a value follows the switch; args[++i]
+				// indexes past the end of args when it is missing (true of every value-taking switch).
+				learningRate = toDouble(args[++i]);
+				break;
+			case "-m":
+			case "--momentum":
+				momentum = toDouble(args[++i]);
+				break;
+			case "-r":
+			case "--random-size":
+				randomSize = toDouble(args[++i]);
+				break;
+			case "-e":
+			case "--error-min":
+				errorMin = toDouble(args[++i]);
+				break;
+			case "-n":
+			case "--example-number":
+				trainingExample = toInt(args[++i]);
+				if(trainingExample > 2 || trainingExample < 0) throw new Error("example number must be between 0 and 2");
+				break; // without this, "-n" fell through into "-x" and consumed the next argument as an example name
+			case "-x":
+			case "--example":
+				switch(args[++i]) {
+					case "parity":
+						trainingExample = 0;
+						break;
+					case "xor":
+						trainingExample = 1;
+						break;
+					case "identity":
+						trainingExample = 2;
+						break;
+					default:
+						throw new Error("Wrong example name. Must be parity, xor or identity");
+				}
+				break;
+			case "-q":
+			case "--quiet":
+				quiet = true;
+				break;
+			case "-p":
+			case "--print-every":
+				printEvery = toInt(args[++i]);
+				break;
+			case "-i":
+			case "--min-iters": // NOTE: despite the "min" in the name, this sets the iteration cap (maxIters)
+			case "--min-iterations":
+				maxIters = toInt(args[++i]);
+				break;
+			default:
+				throw new Error("Unknown switch: " ~ args[i]);
+		}
+	}
+	//} catch(ArrayBoundsError) {
+	//	throw new Error("Wrong number of parameters");
+	//}
+	
+	initTrainingExample(trainingExample);
+	
+	Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid], learningRate, momentum, randomSize, true);
+	
+	double error = nn.calculateError(trainingInputs, trainingOutputs);
 	double[] output;
 	int iter = 0;
-	writef("weights="); printArray(nn.getWeights());
-	writef("outputs="); printArray(nn.evaluateFull(trainingInputs[$-1]));
-	while(error >= 0.01 && iter < 50000){
-		if(iter % 500 == 0){
-			writefln("Iter: %d",iter);
-			for(int i=0; i<trainingInputs.length; i++){
-				output = nn.evaluate(trainingInputs[i]);
-				writef("  %d:", i); printArray(output);
+	//writef("weights=");
+	//printArray(nn.getWeights());
+	//writef("outputs=");
+	//printArray(nn.evaluateFull(trainingInputs[$ - 1]));
+	while (error >= errorMin && iter < maxIters) {
+		if(printEvery != 0 && iter % printEvery == 0) { // the != 0 guard keeps "-p 0" from dividing by zero; it now just disables progress output
+			writefln("Iter: %d", iter);
+			if(!quiet) {
+				for(int i = 0; i < trainingInputs.length; i++) {
+					output = nn.evaluate(trainingInputs[i]);
+					writef("  %d:", i);
+					printArray(output);
+				}
 			}
 			writefln("  Error: %f", error);
 		}
-		nn.train(trainingInputs,trainingInputs);
-		error = nn.calculateError(trainingInputs,trainingInputs);
+		nn.train(trainingInputs, trainingOutputs, true);
+		error = nn.calculateError(trainingInputs, trainingOutputs);
 		iter++;
 	}
-	writefln("Total Iters: %d",iter);
-	for(int i=0; i<trainingInputs.length; i++){
-		writef("  %d:", i); printArray(nn.evaluateFull(trainingInputs[i])[0]);
+	writefln("Total Iters: %d", iter);
+	for(int i = 0; i < trainingInputs.length; i++) {
+		writef("  %d:", i);
+		if(trainingExample == 2)
+			printArray(nn.evaluateFull(trainingInputs[i])[0]); // element [0] of evaluateFull's result; presumably the hidden-layer activations -- TODO confirm against Backprop
+		else
+			printArray(nn.evaluate(trainingInputs[i]));
 	}
 	writefln("  Error: %f", error);
-	writef("weights="); printArray(nn.getWeights());
-}
-
-void printArray(double[] array){
-	writef("[");
-	for(int i=0; i<array.length-1; i++){
-		writef("%f, ",array[i]);
-	}
-	writefln("%f]",array[$-1]);
+	writef("weights=");
+	printArray(nn.getWeights());
 }
-
-void printArray(double[][] array){
-	writef("[");
-	for(int i=0; i<array.length; i++){
-		printArray(array[i]);
-	}
-	writefln("]");
-}
-
-void printArray(double[][][] array){
-	writef("[");
-	for(int i=0; i<array.length; i++){
-		printArray(array[i]);
-	}
-	writefln("]");
-}
\ No newline at end of file