00001 #ifndef HOPFIELDNETWORK_H
00002 #define HOPFIELDNETWORK_H
00003
00010 #include "Network.h"
00011 #include "Matrix.h"
00012 #include <vector>
00013
00014 namespace annie {
00015
00020 real isPositive(real x);
00021
00041 class HopfieldBase : public Network {
00042 protected:
00044 bool _bipolar;
00045
00047 int _nPatterns;
00048 int _time;
00049
00050 HopfieldBase() : Network(0,0), _neurons(0) {}
00051
00052 struct NData {
00053 Vector _biases;
00054
00055 Vector _outputs;
00056 ActivationFunction _function;
00057 void setSize(uint size) {
00058 _biases.resize(size); _outputs.resize(size);
00059 }
00060 NData(uint size, ActivationFunction f=signum) : _biases(size), _outputs(size), _function(f) {}
00061 } _neurons;
00062
00064 int real2int(real r) const;
00066 typedef uint Nid;
00067 real _getBias(Nid i);
00068 void _setBias(Nid i, real bias);
00069 void _setNeuronOutput(Nid n, real a) { _neurons._outputs[n] = a; }
00070 real _getNeuronOutput(Nid n) const { return _neurons._outputs[n]; }
00071 public:
00072
00073
00074
00075 virtual real getWeight(uint from, uint to) const = 0;
00076 virtual void setWeight(uint from, uint to, real weight) = 0;
00077
00085 HopfieldBase(uint size, bool bias=false, bool bipolar=true);
00086
00087 virtual ~HopfieldBase() {}
00088
00090 virtual void save(const std::string &filename) { throw Exception("Not implemented"); }
00091
00093 virtual real getEnergy();
00094
00098 virtual real getEnergy(int pattern[]);
00099
00101 virtual uint getSize() const;
00102
00104 virtual void step();
00105
00110 virtual int getTime();
00111
00113 void randomize();
00114
00116 virtual const char* getClassName() const;
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00134 virtual real getBias(uint i);
00135
00140 virtual void setBias(uint i, real bias);
00141
00147 virtual void setInput(int pattern[]);
00148
00155 virtual void setInput(const std::vector<int> &pattern);
00156
00160 virtual void setInput(const Vector &pattern);
00161
00165 virtual Vector getOutput() const;
00166
00168 std::vector<int> getOutputInt() const;
00169
00174 virtual Vector getNextOutput();
00175
00183 virtual Vector getOutput(const Vector &input);
00184
00192 virtual bool propagate(int pattern[], uint timeout);
00193
00201 virtual bool propagate(const std::vector<int> &pattern, uint timeout);
00202
00204 virtual bool HopfieldBase::propagate(const Vector &pattern, uint timeout);
00205 };
00206
00211 class HopfieldNetwork : public HopfieldBase {
00212 protected:
00214 Matrix *_weightMatrix;
00215
00216 virtual real getWeight(uint from, uint to) const;
00217 public:
00218 HopfieldNetwork(uint size, bool bias=false, bool bipolar=true);
00219
00225 HopfieldNetwork(const char *filename);
00226
00228 ~HopfieldNetwork();
00229
00235 virtual void save(const std::string &filename);
00236
00246 virtual void setWeight(uint i, uint j, real weight);
00247
00249 virtual Matrix getWeightMatrix();
00250
00257 virtual void addPattern(int pattern[]);
00258
00263 virtual uint getPatternCount();
00264 };
00265
00266
00267 }
00268 ;
00269 #endif // define _HOPFIELDNETWORK_H
00270