00001
00003 #ifndef KOHONEN_H
00004 #define KOHONEN_H
00005
#include "Exception.h"
#include "Network.h"
#include "SimpleNeuron.h"
#include "InputNeuron.h"
#include "InputLayer.h"
#include <cassert>
#include <string>
00012
00013 namespace annie
00014 {
00015 class TrainingSet;
00016
00017
00018
00019
00020
00021
00022 class KohonenNeuron : public SimpleNeuron {
00023 public:
00024 KohonenNeuron(int label);
00025 static const real RANDOM_WEIGHTS_DEV;
00026 void adaptTowards(const Vector &input, real nbWeight, real learningParam);
00027 void randomizeWeights();
00028 protected:
00029 virtual void _recacheOutput() const;
00030 void setWeights(const Vector &w);
00031 };
00032
00033
00037 #if 0
00038 class Topology {
00039 public:
00043 real getNeighborWeight(uint srcNeuron, destNeuron, real nbSize);
00044
00045 };
00046 #endif
00047
00048 class Topology {
00049 public:
00053 virtual real getNeighborWeight(uint srcNeuron, uint destNeuron, real nbSize) const = 0;
00054
00059 virtual bool isGridNeighbor(uint srcNeuron, uint destNeuron) const = 0;
00060
00062 uint getTotalCount() const { return totalCount; }
00063
00068 real getMaxDim() const { return maxDim; }
00069 protected:
00070 uint totalCount;
00071 real maxDim;
00072 };
00073
00079 class KohonenParameters {
00080 public:
00081 static const char * const NB, *const LP;
00082 KohonenParameters(real initNB, real initLP, PublicValues &ctrl=defaultControl) : control(ctrl) { setNB(initNB); setLP(initLP); }
00083
00085 virtual void nextIteration() = 0;
00086
00088 real getNeighborhoodSize() const;
00089
00091 real getLearningParam() const;
00092 protected:
00093 void setNB(real nb);
00094 void setLP(real lp);
00095
00096 PublicValues &control;
00097 };
00098
00102 class StaticKohonenParameters : public KohonenParameters {
00103 public:
00104 StaticKohonenParameters(real initNB, real initLP) : KohonenParameters(initNB, initLP) {}
00105 virtual void nextIteration() {}
00106 };
00107
00108 class DynamicKohonenParameters : public KohonenParameters {
00109 public:
00110
00111 virtual void nextIteration() { ++_step; recompute(); }
00112 virtual void recompute() = 0;
00113 uint getStep() const { assert(_step); return _step; }
00114 protected:
00115 DynamicKohonenParameters() : KohonenParameters(1., 1.), _step(1) { }
00116 private:
00117 uint _step;
00118 };
00119
00126 class StandardKohonenParameters : public DynamicKohonenParameters {
00127 public:
00128 StandardKohonenParameters(real maxNB, real slope) : _slope(-slope), _maxNB(maxNB) { }
00129 virtual void recompute();
00130 protected:
00131 real _slope, _maxNB;
00132 };
00133
00134
00135
00136
00137
00138
00139
00140
00141
00146 class KohonenNetwork : public Network
00147 {
00148 public:
00149 enum { INPUT_LAYER, OUTPUT_LAYER };
00150 static const real MAX_OUTPUT, MIN_OUTPUT;
00151 static const real TERMINATING_LP;
00152
00153 KohonenNetwork(uint inputs, const Topology &topology, KohonenParameters ¶ms, PublicValues &ctrl=defaultControl);
00164 void train(TrainingSet &T, real terminatingLP=TERMINATING_LP);
00165
00167 void trainExample(const Vector &input);
00168
00169
00170
00175 uint getWinnerOutput(const Vector &input);
00176
00180 uint getWinner() const;
00181
00185 virtual Vector getOutput(const Vector &input);
00186 const Layer &getOutputLayer() const { return _outputLayer; }
00187 virtual const char *getClassName() const { return "KohonenNetwork"; }
00188
00192 virtual Vector getCluster(uint cluserNum) const;
00193
00194 virtual void save(const std::string &filename) { throw Exception("not implemented"); }
00195
00196 const Topology &getTopology() const { return topology; }
00197 protected:
00198 void connectLayers();
00199 void randomizeWeights();
00200 real getNBSize() const { return parameters.getNeighborhoodSize(); }
00201 real getLearningParam() const { return parameters.getLearningParam(); }
00202 void _setInput(const Vector &input);
00203
00204 InputLayer _inputLayer;
00205 Layer _outputLayer;
00206 const Topology &topology;
00207 KohonenParameters ¶meters;
00208 PublicValues &control;
00209 mutable uint _winner;
00210 mutable bool _winnerValid;
00211 };
00212
00213 }
00214 #endif //_H