Mercurial > projects > aid
comparison trunk/aid/nn/perceptron.d @ 2:9655c8362b25
Added the Perceptron class and the perceptron_test testing program.
author | revcompgeek |
---|---|
date | Sat, 05 Apr 2008 23:41:30 -0600 |
parents | |
children | 314d68bafeff |
comparison
equal
deleted
inserted
replaced
1:5dd9f598bcd8 | 2:9655c8362b25 |
---|---|
1 module aid.nn.perceptron; | |
2 | |
3 import std.random; | |
4 import std.math; | |
5 import std.string; | |
6 | |
/**
 * Returns a pseudo-random double obtained by scaling rand() by uint.max,
 * giving a value in the closed range [0, 1].
 */
double rnd(){
    double raw = rand();
    return raw / uint.max;
}
10 | |
/**
 * Thrown when a caller hands the network data of the wrong shape
 * (e.g. an input array of the wrong length).
 */
class InputException : Exception
{
    /// Params: message = Human-readable description of the problem.
    this(char[] message)
    {
        super(message);
    }
}
16 | |
17 // The output functions | |
18 | |
/// Signature of an output (activation) function: maps the weighted input sum to the network's output.
alias double function(double) OutputFunction;

/**
 * Step output function.
 *
 * Returns: 1 when y is strictly positive; -1 otherwise (including 0 and NaN).
 */
double sign(double y){
    return (y > 0) ? 1.0 : -1.0;
}
25 | |
/**
 * Logistic output function: 1 / (1 + e^-x), mapping any input into (0, 1).
 */
double sigmoid(double x){
    // Kept at real precision (exp returns real) so rounding matches a single-expression form.
    real denominator = 1 + exp(-x);
    return 1 / denominator;
}
29 | |
/**
 * Hyperbolic tangent output function; maps any input into [-1, 1].
 *
 * The library call must be fully qualified. An unqualified tanh(...) here
 * resolves to this module-level function itself, because a symbol in the
 * module scope hides the imported std.math.tanh and real converts implicitly
 * to double. That made the original version recurse until the stack overflowed.
 */
double tanh(double x){
    return cast(double)std.math.tanh(cast(real)x);
}
33 | |
34 // End output functions | |
35 | |
36 | |
/**
 * A single perceptron.
 *
 * It holds one weight per input plus a bias weight stored at weights[0].
 * It evaluates the weighted sum of its inputs, optionally passed through an
 * output function, and trains by nudging each weight toward a target output.
 */
class perceptron {
    // Number of weights: the number of network inputs plus one for the bias.
    private int numInputs;
    // weights[0] is the bias; weights[i] pairs with inputs[i-1].
    private double[] weights;
    // Output function applied to the weighted sum; null means the raw sum is returned.
    private OutputFunction func;
    // Step size used by train().
    public double learningRate;

    /**
     * This is the constructor for loading the neural network from a string.
     *
     * Loading is not implemented yet, so this constructor always throws.
     *
     * Params:
     *     savedString = The string produced by a save function (not yet present in this class).
     *
     * Throws:
     *     Exception, unconditionally, until loading is implemented.
     */
    public this(char[] savedString){
        //TODO: Implement loading!
        throw new Exception("Not implimented.");
    }

    /**
     * Creates a perceptron with the given number of inputs.
     *
     * Params:
     *     numInputs = Number of inputs the network accepts; a bias weight is added internally.
     *     learningRate = Step size used by train().
     *     randomize = If true, each weight starts at rnd() * 2 - 1 (in [-1, 1]); otherwise at 0.
     *     f = Output function applied by evaluate(); null returns the raw weighted sum.
     *
     * Throws:
     *     InputException if numInputs is negative.
     */
    this(int numInputs, double learningRate=0.3, bool randomize=true, OutputFunction f=null){
        if(numInputs < 0) throw new InputException("numInputs must not be negative");
        this.numInputs = numInputs + 1;
        weights.length = this.numInputs;
        func = f;
        this.learningRate = learningRate;
        // Doubles default to NaN in D, so every weight must be assigned explicitly.
        for(int i = 0; i < this.numInputs; i++){
            weights[i] = randomize ? rnd() * 2 - 1 : 0;
        }
    }

    // Throws an InputException unless inputs holds exactly one value per network input.
    private void checkInputLength(double[] inputs){
        if(inputs.length != numInputs - 1){
            throw new InputException(format("Wrong input length. Expected %d, got %d.",
                numInputs - 1, inputs.length));
        }
    }

    /**
     * Evaluates the neural network.
     *
     * Computes weights[0] + sum(inputs[i-1] * weights[i]) and passes the result
     * through the output function, if one was supplied.
     *
     * Params:
     *     inputs = The set of inputs to evaluate; must contain one value per network input.
     *
     * Returns:
     *     The output function applied to the weighted sum, or the raw sum when no
     *     output function was set. For example, it returns 1 or -1 when the output
     *     function is sign.
     *
     * Throws:
     *     InputException when inputs has the wrong length.
     */
    double evaluate(double[] inputs){
        checkInputLength(inputs);
        double total = weights[0];
        for(int i = 1; i < numInputs; i++){
            total += inputs[i-1] * weights[i];
        }
        if(func !is null) return func(total);
        return total;
    }

    /**
     * Returns: A copy of the weights; index 0 is the bias, index i pairs with input i-1.
     */
    public double[] getWeights(){
        return weights.dup;
    }

    /**
     * Trains the neural network on a single example.
     *
     * Each weight is moved by learningRate * (targetOutput - evaluate(inputs)) times
     * its input. The bias is treated as having a constant input of 1.
     *
     * Params:
     *     inputs = The array of inputs to the neural network.
     *     targetOutput = The output that the neural network should give.
     *
     * Throws:
     *     InputException when inputs has the wrong length.
     */
    void train(double[] inputs, double targetOutput){
        checkInputLength(inputs);
        double output = evaluate(inputs);
        double error = this.learningRate * (targetOutput - output);
        weights[0] += error;
        for(int i = 1; i < numInputs; i++){
            weights[i] += error * inputs[i-1];
        }
    }

    /**
     * Calculates the error based on the sum squared error function:
     * 0.5 * sum over i of (outputsArray[i] - evaluate(inputsArray[i]))^2.
     *
     * Params:
     *     inputsArray = An array of input arrays, one per example.
     *     outputsArray = The expected output for the corresponding entry of inputsArray.
     *
     * Returns:
     *     The error value.
     *
     * Throws:
     *     InputException if the arrays differ in length, are empty, or any example
     *     has the wrong number of inputs.
     */
    double getErrorValue(double[][] inputsArray, double[] outputsArray){
        if(inputsArray.length != outputsArray.length) throw new InputException("inputsArray and outputsArray must be the same length");
        if(inputsArray.length < 1) throw new InputException("Must have at least 1 training example");
        double total = 0;
        foreach(i, example; inputsArray){
            double diff = outputsArray[i] - evaluate(example);
            total += diff * diff;
        }
        return total * 0.5;
    }
}
142 |