Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members | Related Pages

HopfieldNetwork.h

Go to the documentation of this file.
00001 #ifndef HOPFIELDNETWORK_H
00002 #define HOPFIELDNETWORK_H
00003 
00010 #include "Network.h"
00011 #include "Matrix.h"
00012 #include <vector>
00013 
00014 namespace annie {
00015 
00020 real isPositive(real x);
00021 
00041 /*abstract*/ class HopfieldBase : public Network    {
00042 protected:
00044     bool _bipolar;
00045 
00047     int _nPatterns;
00048     int _time;
00049 
00050     HopfieldBase() : Network(0,0), _neurons(0) {}
00051 
00052     struct  NData {
00053         Vector _biases;
00054         //      Vector _activations;
00055         Vector _outputs;
00056         ActivationFunction _function;   //one for all
00057         void setSize(uint size) {
00058             _biases.resize(size); _outputs.resize(size);
00059         }
00060         NData(uint size, ActivationFunction f=signum) : _biases(size), _outputs(size), _function(f) {}
00061     } _neurons;
00062     
00064     int real2int(real r) const;
00066     typedef uint Nid;
00067     real _getBias(Nid i);
00068     void _setBias(Nid i, real bias);
00069     void _setNeuronOutput(Nid n, real a) { _neurons._outputs[n] = a; }
00070     real _getNeuronOutput(Nid n) const  { return _neurons._outputs[n]; }
00071 public:
00072     /*virtual void setActivation(Nid n, real a) = 0;
00073     virtual real getOutput(Nid n, real a) = 0;*/
00074 
00075     virtual real getWeight(uint from, uint to) const = 0;
00076     virtual void setWeight(uint from, uint to, real weight) = 0;    
00077 
00085     HopfieldBase(uint size, bool bias=false, bool bipolar=true);
00086 
00087     virtual ~HopfieldBase() {}
00088 
00090     virtual void save(const std::string &filename) { throw Exception("Not implemented"); }
00091 
00093     virtual real getEnergy();
00094 
00098     virtual real getEnergy(int pattern[]);
00099 
00101     virtual uint getSize() const;
00102 
00104     virtual void step();
00105     
00110     virtual int getTime();
00111 
00113     void randomize();
00114 
00116     virtual const char* getClassName() const;
00117 
00118 
00119     /* Given an input pattern, keeps iterating through time till the network
00120       * output converges. Ofcourse, it is possible that this never happens
00121       * and hence a timeout has to be specified.
00122       * \todo Implement this!
00123       * @param pattern The initial input pattern given to the network
00124       * @param updateAll Determines type of updating (synchronous, asynchronous)
00125       * @param timeout The maximum number of iteration to try convergence for
00126          * @return false if the network output didn't converge till the timeout, true otherwise
00127       */
00128     //virtual bool converge(int pattern[], bool updateAll, int timeout);
00129 
00134     virtual real getBias(uint i);
00135 
00140     virtual void setBias(uint i, real bias);
00141 
00147     virtual void setInput(int pattern[]);
00148 
00155     virtual void setInput(const std::vector<int> &pattern);
00156 
00160     virtual void setInput(const Vector &pattern);
00161 
00165     virtual Vector getOutput() const;
00166 
00168     std::vector<int> getOutputInt() const;
00169 
00174     virtual Vector getNextOutput();
00175 
00183     virtual Vector getOutput(const Vector &input);
00184 
00192     virtual bool propagate(int pattern[], uint timeout);
00193 
00201     virtual bool propagate(const std::vector<int> &pattern, uint timeout);
00202 
00204     virtual bool HopfieldBase::propagate(const Vector &pattern, uint timeout);
00205 };
00206 
00211 class HopfieldNetwork : public HopfieldBase {
00212 protected:
00214     Matrix *_weightMatrix;
00215 
00216     virtual real getWeight(uint from, uint to) const;
00217 public:
00218     HopfieldNetwork(uint size, bool bias=false, bool bipolar=true);
00219 
00225     HopfieldNetwork(const char *filename);
00226     
00228     ~HopfieldNetwork();
00229 
00235     virtual void save(const std::string &filename);
00236 
00246     virtual void setWeight(uint i, uint j, real weight);
00247 
00249     virtual Matrix getWeightMatrix();
00250 
00257     virtual void addPattern(int pattern[]);
00258 
00263     virtual uint getPatternCount();
00264 };
00265 
00266 
00267 }
00268 ; //namespace annie
00269 #endif // HOPFIELDNETWORK_H
00270 

Generated on Fri Jun 18 13:18:10 2004 for Annie by doxygen 1.3.5