00001 #ifndef _TRAININGSET_H
00002 #define _TRAININGSET_H
00003
00004 #include "Neuron.h"
00005 #include "Network.h"
00006
00007 namespace annie
00008 {
00009
00012 typedef Vector (*XformFunction)(const Vector &in);
00013
00015 struct TSTransformer {
00016 TSTransformer(uint isize, uint osize) : _isize(isize), _osize(osize) {}
00017 virtual ~TSTransformer() {}
00018
00022 virtual void xform(const Vector &in1, const Vector &out1, Vector &in2, Vector &out2) const = 0;
00023 uint getISize() const { return _isize; }
00024 uint getOSize() const { return _osize; }
00025 private:
00026 uint _isize, _osize;
00027 };
00028
00042 class TrainingSet
00043 {
00044 protected:
00046 std::vector< Vector > _inputs;
00047
00049 std::vector< Vector > _outputs;
00050
00052 std::vector< Vector >::iterator _inputIter;
00053
00055 std::vector< Vector >::iterator _outputIter;
00056
00058 uint _nInputs;
00059
00061 uint _nOutputs;
00062
00064 void save_binary(const std::string &filename);
00065
00067 void save_text(const std::string &filename);
00068
00070 void load_binary(const std::string &filename);
00071
00073 void load_text(const std::string &filename);
00074 public:
00080 TrainingSet(uint in, uint out);
00081
00088 TrainingSet(const std::string &filename,int file_type = TEXT_FILE);
00089
00090 virtual ~TrainingSet();
00091
00093 void addIOpair(real *input,real *output);
00094
00099 void addIOpair(const Vector &input, const Vector &output);
00100
00102 void addIOpair(const Vector &input);
00103
00108 virtual void initialize();
00109
00115 virtual bool epochOver() const;
00116
00118 virtual uint getSize() const;
00119
00121 virtual uint getInputSize() const;
00122
00124 virtual uint getOutputSize() const;
00125
00141 virtual void getNextPair(Vector &input, Vector &desired);
00142
00144 friend std::ostream& operator << (std::ostream& s, TrainingSet &T);
00145
00150 virtual void save(const std::string &filename, int file_type = TEXT_FILE);
00151
00153 operator std::string() const;
00154
00156 virtual const char *getClassName() const;
00157
00159 void shuffle();
00160
00166 TrainingSet operator+ (const TrainingSet &ts) const;
00167 TrainingSet &operator+= (const TrainingSet &ts);
00168 TrainingSet xform(XformFunction ix, XformFunction ox, uint resI, uint resO);
00169 TrainingSet xform(XformFunction ix, XformFunction ox) { return xform(ox, ix, getInputSize(), getOutputSize()); }
00170
00172 TrainingSet xform(XformFunction ix, uint resI);
00173
00175 TrainingSet xform(XformFunction ix);
00176
00180 TrainingSet mixedXform(const TSTransformer &xf);
00181 };
00182
00183 };
00184 #endif // define _TRAININGSET_H