view trunk/aid/ga.d @ 0:4b2e8e8a633e

Repository setup.
author revcompgeek
date Mon, 03 Mar 2008 19:28:10 -0700
parents
children
line wrap: on
line source

/***********************************
 *    Provides a set of classes    *
 *  for using Genetic Algorithms.  *
 *                                 *
 * Author: Matt P.                 *
 * Version: 0.1                    *
 ***********************************/

module aid.ga;

import tango.math.Random;
import tango.io.Stdout;

/**
 * Returns `array` without the element at `index`.
 *
 * D passes dynamic arrays by value (pointer + length), so this function
 * cannot shrink the caller's variable; use the return value instead, e.g.
 * `genes = removeIndex(genes, i);`. The result may share storage with
 * `array` (the first/last-element cases return a slice of it).
 *
 * Params:
 *   array = source array; its elements are not modified.
 *   index = zero-based position to drop. An index at or past the end
 *           returns `array` unchanged.
 * Returns: the array with the element at `index` removed.
 */
Gene[] removeIndex(Gene[] array, uint index){
	if (index >= array.length)
		return array;
	if (index == 0)
		return array[1..$];
	if (index == array.length - 1)
		return array[0..$-1];
	// Everything before `index` followed by everything after it.
	return array[0..index] ~ array[index+1..$];
}

/**
 * A candidate solution: a character sequence plus a lazily computed,
 * cached fitness value.
 */
class Gene {
	char[] geneSequence;
	private double fit = -1; // Cached fitness; only valid once fitCalculated is true.
	private bool fitCalculated = false; // Separate flag so any fitness value (including -1) is cached.
	private GeneticAlgorithm* ga;
	
	/// Creates a gene with an empty sequence that belongs to the algorithm `a`.
	this(GeneticAlgorithm* a){
		ga = a;
	}
	
	/**
	 * Clones `g` for the algorithm `a`. The sequence is duplicated so the
	 * clone can be modified independently; the fitness cache is not copied.
	 */
	this(Gene g, GeneticAlgorithm* a){
		scope(failure) Stdout("E: Gene clone").newline;
		geneSequence = g.geneSequence.dup;
		this(a);
	}
	
	/// Returns this gene's fitness, calling ga.calculateFitness only on first use.
	double fitness(){ // Get property
		if (!fitCalculated){
			fit = ga.calculateFitness(geneSequence);
			fitCalculated = true;
		}
		return fit;
	}
	
	/**
	 * Produces two children by crossing this gene with `other`, according to
	 * ga.crossoverType: 2 = two-point, 1 = one-point, anything else = uniform
	 * (len/2 randomly chosen positions swapped; positions may repeat).
	 *
	 * Both parents are left unchanged. Slices use this gene's length, so
	 * `other.geneSequence` must be at least as long (NOTE(review): presumably
	 * all genes share one length — confirm with getRandomGenotype).
	 *
	 * Returns: a two-element array [child of this, child of other].
	 */
	Gene[] crossover(Gene other){
		Gene gene1 = new Gene(this,ga), gene2 = new Gene(other,ga);
		auto r = Random.shared;
		auto len = geneSequence.length;
		char[] temp;
		
		// The point-based modes draw cut points with r.next(len-2)+1, which
		// needs len >= 3 (below that len-2 is zero or wraps around), so
		// shorter sequences fall back to uniform crossover.
		int type = ga.crossoverType;
		if (len < 3 && (type == 1 || type == 2))
			type = 0;
		
		if (type == 2){
			uint point1 = r.next(len-2)+1;
			uint point2 = r.next(len-2)+1;
			// Redraw over the full range; with len == 3 the range above holds a single value.
			while (point1 == point2)
				point2 = r.next(len);
			if (point2 < point1){
				uint t = point1;
				point1 = point2;
				point2 = t;
			}
			scope(failure) Stdout.format("Ed: {},{}",point1,point2).newline;
			// Swap the segment [point1, point2) between the two children.
			temp = gene1.geneSequence[point1..point2].dup;
			gene1.geneSequence[point1..point2] = gene2.geneSequence[point1..point2];
			gene2.geneSequence[point1..point2] = temp;
		} else if (type == 1) {
			// Swap everything from the cut point to the end.
			uint point = r.next(len-2)+1;
			temp = gene1.geneSequence[point..$].dup;
			gene1.geneSequence[point..$] = gene2.geneSequence[point..$];
			gene2.geneSequence[point..$] = temp;
		} else { // Uniform crossover
			for (int i = 0; i < len / 2; i++){
				uint point = r.next(len);
				
				auto t = gene1.geneSequence[point];
				gene1.geneSequence[point] = gene2.geneSequence[point];
				gene2.geneSequence[point] = t;
			}
		}
		return [gene1,gene2];
	}
	
	/// Returns a clone of this gene with one random position replaced by ga.getRandomChar(position).
	Gene mutate(){
		Gene g = new Gene(this,ga);
		// An empty sequence has no position to replace.
		if (g.geneSequence.length == 0)
			return g;
		uint i = Random.shared.next(g.geneSequence.length);
		g.geneSequence[i] = ga.getRandomChar(i);
		return g;
	}
}

/**
 * One population of genes and the operations that breed the next one.
 */
class Generation {
	private Gene[] genes;
	private uint population;
	private GeneticAlgorithm* ga;
	
	/**
	 * Creates a generation sized to g.startPopulation. When `generate` is
	 * true it is filled with random genotypes from g.getRandomGenotype.
	 */
	public this(GeneticAlgorithm* g, bool generate = false){
		ga = g;
		population = ga.startPopulation;
		if (generate) generateGenes();
	}
	
	/// Creates a generation sharing `gen`'s gene array, population size and algorithm.
	public this(Generation gen){
		genes = gen.genes;
		population = gen.population;
		ga = gen.ga;
	}
	
	/**
	 * Breeds the next generation:
	 *  1. population * survivalRate tournament winners are copied over;
	 *  2. the remaining slots are filled with crossover children of
	 *     tournament-selected pairs (parents are dropped from the selection
	 *     pool when ga.deleteOnCrossover is set);
	 *  3. about population * mutationRate random slots are mutated.
	 *
	 * This generation's genes are left as they were. NOTE(review): when
	 * crossover slots remain, at least two genes are needed, otherwise
	 * the distinct-parent loop never ends.
	 *
	 * Returns: the new generation.
	 */
	public Generation evolve(){
		Generation newGen = new Generation(ga);
		newGen.population = population;
		newGen.genes.length = population;
		
		// Parent deletion below shrinks `genes`; restore the full set on exit.
		Gene[] parents = genes;
		scope(exit) genes = parents;
		
		// Use tournament selection to pick the survivors.
		uint livingPop = cast(uint)(population * ga.survivalRate);
		if (livingPop > population)
			livingPop = population;
		for (uint i = 0; i < livingPop; i++)
			newGen.genes[i] = genes[tournamentSelect()];
		
		// Cross over selected pairs to fill the remaining slots.
		for (uint i = livingPop; i < population; i += 2){
			uint i1 = tournamentSelect(), i2 = tournamentSelect();
			while (i1 == i2)
				i2 = tournamentSelect();
			
			Gene[] children = genes[i1].crossover(genes[i2]);
			
			// Crossover yields two children; if only one slot is left keep the first.
			uint count = (population - i < 2) ? population - i : 2;
			newGen.genes[i..i+count] = children[0..count];
			
			// Remove the higher index first so the lower one still points at the
			// right gene, and keep at least two genes so a distinct pair remains.
			if (ga.deleteOnCrossover && genes.length >= 4){
				uint hi = (i1 > i2) ? i1 : i2;
				uint lo = (i1 > i2) ? i2 : i1;
				genes = genes[0..hi] ~ genes[hi+1..$];
				genes = genes[0..lo] ~ genes[lo+1..$];
			}
		}
		
		// Mutate a small number of randomly chosen genes.
		for (uint i = 0; i < population * ga.mutationRate; i++){
			uint index = Random.shared.next(newGen.genes.length);
			newGen.genes[index] = newGen.genes[index].mutate();
		}
		
		return newGen;
	}
	
	/// Fills `genes` with `population` new genes built from ga.getRandomGenotype.
	private void generateGenes(){
		scope(failure) Stdout("E: generateGenes").newline;
		genes.length = population;
		for (uint i = 0; i < population; i++){
			Gene t = new Gene(ga);
			t.geneSequence = ga.getRandomGenotype();
			genes[i] = t;
		}
	}
	
	/**
	 * Draws 4 random indices into `genes` (with replacement) and returns the
	 * index of the fittest; on ties the earliest draw wins.
	 */
	private uint tournamentSelect(){
		Gene[4] gens;
		uint[4] indices;
		scope(failure) Stdout("E: tournamentSelect").newline;
		for (uint i = 0; i < 4; i++){
			indices[i] = Random.shared.next(genes.length);
			gens[i] = genes[indices[i]];
		}
		double max = gens[0].fitness;
		uint index = indices[0];
		foreach (i,g; gens){
			if (g.fitness > max){
				max = g.fitness;
				index = indices[i];
			}
		}
		return index;
	}
}

/**
 * Configuration and driver for a genetic algorithm run. Assign the three
 * delegates and fitnessThreshold before calling run().
 */
class GeneticAlgorithm {
	public double delegate(char[]) calculateFitness; // Set these before you do anything with the library
	public char[] delegate() getRandomGenotype;
	public char delegate(uint index) getRandomChar;
	public double survivalRate = 0.5;
	public double mutationRate = 0.05;
	public bool deleteOnCrossover = true;
	public int crossoverType = 1; // 1 = one-point, 2 = two-point, other = uniform
	public uint startPopulation = 1000;
	public bool verbose = false;
	public uint bailout = 1000;
	
	// Must be set: doubles default to NaN in D and no value compares >= NaN,
	// so an unset threshold means run() only stops at `bailout`.
	public double fitnessThreshold;
	
	/**
	 * Evolves generations until some gene's fitness reaches fitnessThreshold
	 * or `bailout` generations have been evaluated. When `verbose` is set,
	 * prints the average and maximum fitness of each generation.
	 *
	 * Returns: the number of the last generation evaluated (1-based).
	 */
	public uint run(){
		// NOTE(review): Generation/Gene keep &this (the address of this method's
		// hidden `this` reference); it is only valid while run() executes.
		Generation currGen = new Generation(&this,true);
		uint generation = 1;
		while (true){
			bool stop = false;
			double total = 0;
			// Start below any real value so all-negative fitnesses are reported correctly.
			double max = -double.infinity;
			foreach (Gene g; currGen.genes){
				double f = g.fitness;
				total += f;
				if (f > max) max = f;
				if (f >= fitnessThreshold) stop = true;
			}
			if (verbose) Stdout.formatln("Gen: {} Avg: {:.4}  Max: {:.4}",generation, total / currGen.genes.length, max);
			if (stop || generation >= bailout) return generation;
			currGen = currGen.evolve();
			generation++;
		}
	}
}