view trunk/backprop_test.d @ 4:73beed484455

Backprop working correctly.
author revcompgeek
date Sat, 12 Apr 2008 21:55:37 -0600
parents 314d68bafeff
children 810d58835f86
line wrap: on
line source

module backprop_test;

import aid.nn.multilayer.backprop;
import aid.nn.outputFunctions;
import std.stdio;
import std.random;

/+double[][] trainingInputs = [
	[0,0,0],
	[0,0,1],
	[0,1,0],
	[0,1,1],
	[1,0,0],
	[1,0,1],
	[1,1,0],
	[1,1,1]];

double[][] trainingOutputs = [
	[0.1],
	[0.9],
	[0.9],
	[0.1],
	[0.9],
	[0.1],
	[0.1],
	[0.9]];+/

/+double[][] trainingInputs = [
	[0,0],
	[1,0],
	[0,1],
	[1,1]];

double[][] trainingOutputs = [
	[0.9],
	[0.1],
	[0.1],
	[0.9]];+/

// Eight 8-element training patterns: row k has 0.9 at position k and 0.1
// everywhere else. main() passes this table as BOTH the inputs and the
// targets of the network, i.e. the net is trained to reproduce its input.
double[][] trainingInputs = [
	[0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1],
	[0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1],
	[0.1,0.1,0.9,0.1,0.1,0.1,0.1,0.1],
	[0.1,0.1,0.1,0.9,0.1,0.1,0.1,0.1],
	[0.1,0.1,0.1,0.1,0.9,0.1,0.1,0.1],
	[0.1,0.1,0.1,0.1,0.1,0.9,0.1,0.1],
	[0.1,0.1,0.1,0.1,0.1,0.1,0.9,0.1],
	[0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.9]];

// Builds a Backprop network (8 inputs, layers of 3 and 8 sigmoid units,
// learning rate .1) and trains it to reproduce trainingInputs at its output.
// Training stops once the error falls below targetError or after maxIters
// passes. Progress is reported every reportEvery passes, and the weights
// are printed before and after training.
void main(){
	//rand_seed(0,0);
	const double targetError = 0.01; // stop once the error drops below this
	const int maxIters = 50000;      // hard cap on training passes
	const int reportEvery = 500;     // progress-report interval (in passes)

	Backprop nn = new Backprop(8,[3,8],[&sigmoid,&sigmoid],.1);

	// The patterns serve as their own targets.
	double error = nn.calculateError(trainingInputs,trainingInputs);
	int iter = 0;

	writef("weights="); printArray(nn.getWeights());
	writef("outputs="); printArray(nn.evaluateFull(trainingInputs[$-1]));

	for(; error >= targetError && iter < maxIters; iter++){
		if(iter % reportEvery == 0){
			writefln("Iter: %d",iter);
			foreach(idx, pattern; trainingInputs){
				// Evaluate first, then print, to keep the original call order.
				double[] result = nn.evaluate(pattern);
				writef("  %d:", idx);
				printArray(result);
			}
			writefln("  Error: %f", error);
		}
		nn.train(trainingInputs,trainingInputs);
		error = nn.calculateError(trainingInputs,trainingInputs);
	}

	writefln("Total Iters: %d",iter);
	foreach(idx, pattern; trainingInputs){
		writef("  %d:", idx);
		// NOTE(review): prints element [0] of evaluateFull's result; presumably
		// the first layer's activations - confirm against Backprop.evaluateFull.
		printArray(nn.evaluateFull(pattern)[0]);
	}
	writefln("  Error: %f", error);
	writef("weights="); printArray(nn.getWeights());
}

// Prints a 1-D array on one line as "[a, b, c]" (each element via %f),
// terminated by a newline. An empty array prints "[]".
void printArray(double[] array){
	// Guard: array.length-1 would wrap around (length is unsigned) and
	// array[$-1] would be out of bounds for an empty array.
	if(array.length == 0){
		writefln("[]");
		return;
	}
	writef("[");
	// Every element except the last is followed by ", ".
	foreach(value; array[0 .. $-1]){
		writef("%f, ", value);
	}
	writefln("%f]", array[$-1]);
}

// Prints a 2-D array: an opening "[" immediately followed by each row
// rendered by the 1-D overload (one row per line), then a closing "]" line.
void printArray(double[][] array){
	writef("[");
	foreach(row; array){
		printArray(row);
	}
	writefln("]");
}

// Prints a 3-D array: an opening "[" followed by each 2-D slice rendered by
// the 2-D overload, then a closing "]" line.
void printArray(double[][][] array){
	writef("[");
	foreach(matrix; array){
		printArray(matrix);
	}
	writefln("]");
}