// trunk/backprop_test.d @ 5:810d58835f86
// Added momentum and stochastic training to backprop.
// Author:   revcompgeek
// Date:     Tue, 15 Apr 2008 14:39:49 -0600
// Parents:  73beed484455
// Children: ff92c77006c7

module backprop_test;


import aid.nn.multilayer.backprop;
import aid.nn.outputFunctions;
import aid.misc;
import std.stdio;
import std.random;
import std.conv;

double[][] trainingInputs, trainingOutputs;
uint numInputs;
uint[] outputsArray;

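// Fill the global training arrays with one of three example problems.
// Targets use 0.1/0.9 rather than 0/1 so the sigmoid outputs can actually reach them.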
void initTrainingExample(int example) {
	if(example == 0) {
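		// Example 0: 3-bit odd parity; the target is 0.9 when an odd number of inputs are 1.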
		numInputs = 3;
		outputsArray = [2,1];
		trainingInputs = [[0, 0, 0],
		                  [0, 0, 1],
		                  [0, 1, 0],
		                  [0, 1, 1],
		                  [1, 0, 0],
		                  [1, 0, 1],
		                  [1, 1, 0],
		                  [1, 1, 1]];
		
		trainingOutputs = [[0.1],
		                   [0.9],
		                   [0.9],
		                   [0.1],
		                   [0.9],
		                   [0.1],
		                   [0.1],
		                   [0.9]];
	} else if(example == 1) {
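		// Example 1: the CLI calls this "xor", but the targets encode equality (XNOR): 0.9 when the inputs match.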
		numInputs = 2;
		outputsArray = [2,1];
		trainingInputs = [[0, 0],
		                  [1, 0],
		                  [0, 1],
		                  [1, 1]];
		
		trainingOutputs = [[0.9],
		                   [0.1],
		                   [0.1],
		                   [0.9]];
	} else if(example == 2) {
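		// Example 2: 8-3-8 identity encoder; the inputs double as targets, forcing a 3-unit bottleneck code.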
		numInputs = 8;
		outputsArray = [3,8];
		trainingInputs = [
			[0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]];
		
		trainingOutputs = trainingInputs;
	}
}
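
// Example invocation (assumed; compiler and paths depend on your setup):
//   dmd backprop_test.d aid/nn/multilayer/backprop.d aid/nn/outputFunctions.d aid/misc.d
//   ./backprop_test -x identity -l 0.3 -m 0.9 -p 100
// Switches: -s fixed RNG seed, -l learning rate, -m momentum, -r initial weight
// range, -e target error, -n/-x example selection, -q quiet, -p print interval,
// -i iteration cap.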

void main(char[][] args) {
	double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
	int trainingExample = 0; // which dataset to use: 0 to 2 (parity, xor, identity)
	int maxIters = 10000;    // cap on training iterations
	bool quiet = false;      // suppress per-pattern output in progress reports
	int printEvery = 500;    // print a progress report every N iterations
	
	for(int i = 1; i < args.length; i++) {
		switch(args[i]) {
			case "-s":
			case "--seed":
				rand_seed(123, 0);
				break;
			case "-l":
			case "--learning-rate":
				//if(args.length = i + 1)
					//throw new Error("Wrong number of paramaters");
				learningRate = toDouble(args[++i]);
				break;
			case "-m":
			case "--momentum":
				momentum = toDouble(args[++i]);
				break;
			case "-r":
			case "--random-size":
				randomSize = toDouble(args[++i]);
				break;
			case "-e":
			case "--error-min":
				errorMin = toDouble(args[++i]);
				break;
			case "-n":
			case "--example-number":
				trainingExample = toInt(args[++i]);
				if(trainingExample > 2 || trainingExample < 0)
					throw new Error("example number must be between 0 and 2");
			case "-x":
			case "--example":
				switch(args[++i]) {
					case "parity":
						trainingExample = 0;
						break;
					case "xor":
						trainingExample = 1;
						break;
					case "identity":
						trainingExample = 2;
						break;
					default:
						throw new Error("Wrong example name. Must be parity, xor or identity");
				}
				break;
			case "-q":
			case "--quiet":
				quiet = true;
				break;
			case "-p":
			case "--print-every":
				printEvery = toInt(args[++i]);
				break;
			case "-i":
			case "--min-iters":
			case "--min-iterations":
				maxIters = toInt(args[++i]);
				break;
			default:
				throw new Error("Unknown switch: " ~ args[i]);
		}
	}
	
	initTrainingExample(trainingExample);
	
	// Build the network: numInputs inputs into layers sized by outputsArray
	// (hidden, then output), all sigmoid units; randomSize presumably bounds
	// the initial random weights.
	Backprop nn = new Backprop(numInputs, outputsArray, [&sigmoid, &sigmoid], learningRate, momentum, randomSize, true);
	
	double error = nn.calculateError(trainingInputs, trainingOutputs);
	double[] output;
	int iter = 0;
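	// Uncomment to dump the initial weights and outputs for debugging: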
	//writef("weights=");
	//printArray(nn.getWeights());
	//writef("outputs=");
	//printArray(nn.evaluateFull(trainingInputs[$ - 1]));
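
	// Train until the error drops below errorMin or maxIters is reached. The
	// trailing 'true' passed to train() presumably selects the stochastic
	// (per-example) updates mentioned in this revision's commit message.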
	while (error >= errorMin && iter < maxIters) {
		if(iter % printEvery == 0) {
			writefln("Iter: %d", iter);
			if(!quiet) {
				for(int i = 0; i < trainingInputs.length; i++) {
					output = nn.evaluate(trainingInputs[i]);
					writef("  %d:", i);
					printArray(output);
				}
			}
			writefln("  Error: %f", error);
		}
		nn.train(trainingInputs, trainingOutputs, true);
		error = nn.calculateError(trainingInputs, trainingOutputs);
		iter++;
	}
	writefln("Total Iters: %d", iter);
	for(int i = 0; i < trainingInputs.length; i++) {
		writef("  %d:", i);
		if(trainingExample == 2)
			printArray(nn.evaluateFull(trainingInputs[i])[0]);
		else
			printArray(nn.evaluate(trainingInputs[i]));
	}
	writefln("  Error: %f", error);
	writef("weights=");
	printArray(nn.getWeights());
}