00001 #ifndef _RADIALBASISNETWORK_H 00002 #define _RADIALBASISNETWORK_H 00003 00004 #include "Network.h" 00005 #include "SimpleNeuron.h" 00006 #include "CenterNeuron.h" 00007 #include "InputLayer.h" 00008 #include "TrainingSet.h" 00009 00010 namespace annie 00011 { 00012 00034 class RadialBasisNetwork : public Network 00035 { 00036 protected: 00041 int _nCenters; 00042 00044 InputLayer *_inputLayer; 00045 00047 Layer *_centerLayer; 00048 00050 Layer *_outputLayer; 00051 00052 public: 00060 RadialBasisNetwork(int inputs, int centers, int outputs); 00061 00067 RadialBasisNetwork(const char *filename); 00068 00070 RadialBasisNetwork(RadialBasisNetwork &srcNet); 00071 00072 virtual ~RadialBasisNetwork(); 00073 00078 virtual void setCenter(uint i, Vector ¢er); 00079 00084 virtual void setCenter(uint i, real *center); 00085 00091 virtual Vector getCenter(int i) const; 00092 00094 virtual uint getCenterCount() const; 00095 00101 virtual void setBias(uint i, real bias); 00102 00107 virtual real getBias(uint i) const; 00108 00113 virtual void removeBias(uint i); 00114 00121 virtual void setWeight(int center, int output, real weight); 00122 00128 virtual real getWeight(int center, int output) const; 00129 00134 virtual Vector getOutput(Vector &input); 00135 00140 virtual Vector getOutput(real *input); 00141 00147 void setCenterActivationFunction(ActivationFunction f, ActivationFunction df); 00148 00152 void trainWeights(TrainingSet &T); 00153 //void trainCentersAndWeights(TrainingSet &T, int epochs, real learningRate); 00154 00156 virtual const char *getClassName() const; 00157 00159 virtual void save(const char *filename); 00160 }; 00161 }; //namespace annie 00162 #endif // define _RADIALBASISNETWORK_H 00163