view trunk/backprop_test.d @ 6:ff92c77006c7

Added support for reading training examples from files.
author revcompgeek
date Tue, 06 May 2008 21:43:55 -0600
parents 810d58835f86
children b9fe92a2d8ad
line wrap: on
line source

module backprop_test;

import aid.nn.multilayer.backprop;
import aid.nn.outputFunctions;
import aid.misc;
import std.stdio;
import std.random;
import std.conv;
import std.string;
import std.stream;

//double[][] trainingInputs, trainingOutputs;
//uint numInputs;
// NOTE(review): within this file outputsArray is only referenced by the
// commented-out initTrainingExample below; it looks like a leftover from the
// hard-coded examples - confirm no other module uses it before removing.
uint[] outputsArray;

/// One training set as loaded by readTrainingExample.
struct TrainingExample {
    /// Text following "Name:" on the first line of the file.
    char[] name;

    /// Number of input values expected on every data row.
    uint numInputs;

    /// Number of output values expected on every data row.
    uint numOutputs;

    /// Values from the optional "Hidden:" line, one entry per hidden layer;
    /// empty when the line is absent or has no values.
    uint[] numHidden;

    /// Input values, one row per data line; each row holds numInputs values.
    double[][] inputs;

    /// Target output values, parallel to inputs; each row holds numOutputs values.
    double[][] outputs;
}

/**
 * Reads a training example from a text stream.
 *
 * Expected layout (header keywords are matched case-insensitively):
 * ---
 * Name: free text
 * Inputs: count
 * Outputs: count
 * Hidden: n1, n2, ...        (optional, may be empty)
 * Data:
 * in1, in2, ... ; out1, out2, ...
 * ---
 * Blank lines are skipped in both the header and the data section, and
 * whitespace inside a data row is ignored.
 *
 * Params:
 *     f = stream positioned at the "Name:" line.
 *
 * Returns: the populated TrainingExample.
 *
 * Throws: Error when the structure of the file is wrong (missing or
 *     duplicated header lines, counts below 1, rows whose size does not match
 *     the declared counts, premature end of file). Failures raised by
 *     toInt/toDouble on malformed numbers are not caught here.
 */
TrainingExample readTrainingExample(InputStream f){
    TrainingExample e;
    char[][] temp,trnInputs,trnOutputs;
    char[] line;
    bool gotInputs,gotOutputs,gotHidden;

    // Counts are parsed into a signed int first: storing toInt's result
    // directly into a uint field would wrap negative values into huge
    // positive ones and defeat the "< 1" checks.
    int count;

    line = f.readLine();
    temp = split(line);
    // Check the length first so an empty first line is reported as a format
    // error instead of an out-of-bounds index.
    if(temp.length == 0 || tolower(temp[0]) != "name:") throw new Error("Expecting \"Name:\" on first line.");
    e.name = join(temp[1..$]," ");

    // Header section: runs until the "data:" marker is seen.
    bool stop = false;
    while(!stop){
        if(f.eof) throw new Error("Unexpected end of file.");
        line = strip(tolower(f.readLine()));
        temp = split(line);
        if(temp.length == 0) continue; // blank line
        switch(temp[0]){
        case "inputs:":
            if(gotInputs) throw new Error("Only one inputs line allowed.");
            if(temp.length < 2) throw new Error("Missing value on inputs line.");
            count = toInt(temp[1]);
            if(count < 1) throw new Error("Inputs cannot be less than 1.");
            e.numInputs = count;
            gotInputs = true;
            break;
        case "outputs:":
            if(gotOutputs) throw new Error("Only one outputs line allowed.");
            if(temp.length < 2) throw new Error("Missing value on outputs line.");
            count = toInt(temp[1]);
            if(count < 1) throw new Error("Outputs cannot be less than 1.");
            e.numOutputs = count;
            gotOutputs = true;
            break;
        case "hidden:":
            if(gotHidden) throw new Error("Only one hidden line allowed.");
            // An empty hidden line still counts as the one allowed hidden line.
            gotHidden = true;
            if(temp.length == 1) break;
            // Glue the remaining tokens together so "2, 3" and "2,3" parse alike.
            temp = split(join(temp[1..$],""),",");
            e.numHidden.length = temp.length;
            foreach(idx, text; temp){
                count = toInt(text);
                if(count < 1) throw new Error("Hidden layers cannot be less than 1.");
                e.numHidden[idx] = count;
            }
            break;
        case "data:":
            stop = true;
            break;
        default:
            throw new Error(temp[0] ~ " not recognized.");
        }
    }

    if(!gotInputs) throw new Error("Missing inputs line.");
    if(!gotOutputs) throw new Error("Missing outputs line.");
    //if(!gotHidden) throw new Error("Missing hidden line."); // Hidden line not required

    //Data reading
    e.inputs.length = 10; // This isn't as redundant as it looks.
    e.inputs.length = 0;  // NOTE(review): presumably keeps the allocation as spare capacity - confirm
    e.outputs.length = 10;
    e.outputs.length = 0;
    while(!f.eof){
        line = join(split(strip(f.readLine())),""); // remove whitespace
        if(line.length == 0) continue; // blank line
        temp = split(line,";");
        if(temp.length != 2) throw new Error("Wrong number of semicolons in data.");

        trnInputs = split(temp[0],",");
        trnOutputs = split(temp[1],",");
        if(trnInputs.length != e.numInputs) throw new Error("Expecting " ~ toString(e.numInputs) ~ " inputs, not " ~ toString(trnInputs.length));
        if(trnOutputs.length != e.numOutputs) throw new Error("Expecting " ~ toString(e.numOutputs) ~ " outputs, not " ~ toString(trnOutputs.length));

        // Convert the whole row before appending, so a conversion failure
        // never leaves a half-filled row in the result.
        double[] inRow = new double[trnInputs.length];
        double[] outRow = new double[trnOutputs.length];
        foreach(idx, text; trnInputs)
            inRow[idx] = toDouble(text);
        foreach(idx, text; trnOutputs)
            outRow[idx] = toDouble(text);
        e.inputs ~= inRow;
        e.outputs ~= outRow;
    }

    return e;
}

/*void initTrainingExample(int example) {
	if(example == 0) {
		numInputs = 3;
		outputsArray = [2,1];
		trainingInputs = [[0, 0, 0],
		                  [0, 0, 1],
		                  [0, 1, 0],
		                  [0, 1, 1],
		                  [1, 0, 0],
		                  [1, 0, 1],
		                  [1, 1, 0],
		                  [1, 1, 1]];
		
		trainingOutputs = [[0.1],
		                   [0.9],
		                   [0.9],
		                   [0.1],
		                   [0.9],
		                   [0.1],
		                   [0.1],
		                   [0.9]];
	} else if(example == 1) {
		numInputs = 2;
		outputsArray = [2,1];
		trainingInputs = [[0, 0],
		                  [1, 0],
		                  [0, 1],
		                  [1, 1]];
		
		trainingOutputs = [[0.9],
		                   [0.1],
		                   [0.1],
		                   [0.9]];
	} else if(example == 2) {
		numInputs = 8;
		outputsArray = [3,8];
		trainingInputs = [
			[0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1],
			[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]];
		
		trainingOutputs = trainingInputs;
	}
}*/

/**
 * Command-line driver: loads a training example from the file named by the
 * first argument, trains a Backprop network on it and prints the results.
 *
 * Usage: backprop_test <training file> [options]
 *
 * Options:
 *     -s, --seed                 seed the random generator with a fixed value
 *     -l, --learning-rate <x>    learning rate (default 0.2)
 *     -m, --momentum <x>         momentum (default 0.3)
 *     -r, --random-size <x>      passed to Backprop as randomSize (default 0.1)
 *     -e, --error-min <x>        stop once the error drops below x (default 0.05)
 *     -q, --quiet                do not print per-example outputs while training
 *     -p, --print-every <n>      report every n iterations (default 500, n >= 1)
 *     -i, --max-iters <n>        iteration limit (default 10000); the older
 *                                spellings --min-iters / --min-iterations are
 *                                still accepted
 *     -f, --full, --print-full   print evaluateFull output in the final report
 *     -w, --print-weights        print the final weights
 *
 * Throws: Error for a missing file argument, an unknown switch, a switch
 *     without its value, or a print interval below 1. Errors from
 *     readTrainingExample and from number conversion are not caught here.
 */
void main(char[][] args) {
	double learningRate = 0.2, momentum = 0.3, randomSize = 0.1, errorMin = 0.05;
	int /*trainingExample = 0,*/ maxIters = 10000; // 0 to 2
	bool quiet,printFull,printWeights;
	int printEvery = 500; // output every ~ times
	TrainingExample example;

	if(args.length < 2)
		throw new Error("Usage: " ~ (args.length ? args[0] : "backprop_test") ~ " <training file> [options]");

	char[] filename = args[1];
	// Close the file as soon as parsing is done, even if parsing throws.
	File trainingFile = new File(filename);
	try {
		example = readTrainingExample(trainingFile);
	} finally {
		trainingFile.close();
	}

	int argIdx;
	// Returns the value following the current switch, or reports which switch
	// is missing one instead of indexing past the end of args.
	char[] optionValue() {
		if(argIdx + 1 >= args.length)
			throw new Error("Missing value for switch: " ~ args[argIdx]);
		return args[++argIdx];
	}

	for(argIdx = 2; argIdx < args.length; argIdx++) {
		switch(args[argIdx]) {
			case "-s":
			case "--seed":
				rand_seed(123, 0);
				break;
			case "-l":
			case "--learning-rate":
				learningRate = toDouble(optionValue());
				break;
			case "-m":
			case "--momentum":
				momentum = toDouble(optionValue());
				break;
			case "-r":
			case "--random-size":
				randomSize = toDouble(optionValue());
				break;
			case "-e":
			case "--error-min":
				errorMin = toDouble(optionValue());
				break;
			/*case "-n":
			case "--example-number":
				trainingExample = toInt(args[++i]);
				if(trainingExample > 2 || trainingExample < 0)
					throw new Error("example number must be between 0 and 2");
			case "-x":
			case "--example":
				switch(args[++i]) {
					case "parity":
						trainingExample = 0;
						break;
					case "xor":
						trainingExample = 1;
						break;
					case "identity":
						trainingExample = 2;
						break;
					default:
						throw new Error("Wrong example name. Must be parity, xor or identity");
				}
				break;*/
			case "-q":
			case "--quiet":
				quiet = true;
				break;
			case "-p":
			case "--print-every":
				printEvery = toInt(optionValue());
				// printEvery is used as a modulus below; 0 would divide by zero.
				if(printEvery < 1)
					throw new Error("print-every must be at least 1");
				break;
			case "-i":
			case "--max-iters":
			case "--max-iterations":
			case "--min-iters":       // historical misnomer, kept for compatibility
			case "--min-iterations":  // historical misnomer, kept for compatibility
				maxIters = toInt(optionValue());
				break;
			case "-f":
			case "--full":
			case "--print-full":
			case "--print full":      // original spelling, kept for compatibility
				printFull = true;
				break;
			case "-w":
			case "--print-weights":
				printWeights = true;
				break;
			default:
				throw new Error("Unknown switch: " ~ args[argIdx]);
		}
	}

	//initTrainingExample(trainingExample);

	// The name comes from the file, so pass it as an argument rather than
	// splicing it into the format string (a '%' in it would be misread).
	writefln("Starting training with: %s", example.name);

	// One output function per layer: every hidden layer plus the output layer.
	OutputFunctionPtr[] functions;
	functions.length = example.numHidden.length + 1;
	for(int i = 0; i<functions.length; i++){
		functions[i] = &sigmoid;
	}

	Backprop nn = new Backprop(example.numInputs, example.numHidden ~ example.numOutputs, functions, learningRate, momentum, randomSize, true);

	double error = nn.calculateError(example.inputs, example.outputs);
	double[] output;
	int iter = 0;
	//writef("weights=");
	//printArray(nn.getWeights());
	//writef("outputs=");
	//printArray(nn.evaluateFull(trainingInputs[$ - 1]));

	// Train until the error is small enough or the iteration limit is hit.
	while (error >= errorMin && iter < maxIters) {
		if(iter % printEvery == 0) {
			writefln("Iter: %d", iter);
			if(!quiet) {
				for(int i = 0; i < example.inputs.length; i++) {
					output = nn.evaluate(example.inputs[i]);
					writef("  %d:", i);
					printArray(output);
				}
			}
			writefln("  Error: %f", error);
		}
		nn.train(example.inputs, example.outputs, true);
		error = nn.calculateError(example.inputs, example.outputs);
		iter++;
	}

	// Final report.
	writefln("Total Iters: %d", iter);
	for(int i = 0; i < example.inputs.length; i++) {
		writef("  %d:", i);
		if(printFull)
			printArray(nn.evaluateFull(example.inputs[i])[0]);
		else
			printArray(nn.evaluate(example.inputs[i]));
	}
	writefln("  Error: %f", error);
	if(printWeights){
		writef("weights=");
		printArray(nn.getWeights());
	}
}