19 #ifndef TESSERACT_LSTM_NETWORK_H_ 20 #define TESSERACT_LSTM_NETWORK_H_ 25 #include "genericvector.h" 28 #include "networkio.h" 30 #include "static_shape.h" 186 virtual int RemapOutputs(
int old_no,
const std::vector<int>& code_map) {
219 tprintf(
"Must override Network::DebugWeights for type %d\n",
type_);
231 virtual void Update(
float learning_rate,
float momentum,
float adam_beta,
237 double* changed)
const {}
265 tprintf(
"Must override Network::Forward for type %d\n",
type_);
276 tprintf(
"Must override Network::Backward for type %d\n",
type_);
287 static void ClearWindow(
bool tess_coords,
const char* window_name,
296 double Random(
double range);
320 #endif // TESSERACT_LSTM_NETWORK_H_ virtual int XScaleFactor() const
Definition: network.h:209
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:186
double Random(double range)
Definition: network.cpp:275
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:306
static char const *const kTypeNames[NT_COUNT]
Definition: network.h:314
virtual void ConvertToInt()
Definition: network.h:191
virtual void CountAlternators(const Network &other, double *same, double *changed) const
Definition: network.h:236
virtual StaticShape InputShape() const
Definition: network.h:127
int32_t network_flags_
Definition: network.h:302
Definition: static_shape.h:38
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:199
int NumOutputs() const
Definition: network.h:123
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: network.cpp:145
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:329
Definition: networkscratch.h:36
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
virtual void DebugWeights()
Definition: network.h:218
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
Definition: serialis.h:77
NetworkFlags
Definition: network.h:85
Definition: baseapi.cpp:94
bool needs_to_backprop_
Definition: network.h:301
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
const STRING & name() const
Definition: network.h:138
virtual bool DeSerialize(TFile *fp)
Definition: network.cpp:170
virtual STRING spec() const
Definition: network.h:141
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
NetworkType
Definition: network.h:43
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: network.h:273
bool IsTraining() const
Definition: network.h:115
Definition: weightmatrix.h:33
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
Definition: network.h:105
ScrollView * forward_win_
Definition: network.h:309
NetworkType type() const
Definition: network.h:112
int32_t ni_
Definition: network.h:303
STRING name_
Definition: network.h:306
virtual ~Network()=default
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:293
bool needs_to_backprop() const
Definition: network.h:116
TrainingState
Definition: network.h:92
Network()
Definition: network.cpp:76
virtual bool IsPlumbingType() const
Definition: network.h:152
ScrollView * backward_win_
Definition: network.h:310
int32_t num_weights_
Definition: network.h:305
int num_weights() const
Definition: network.h:119
virtual int InitWeights(float range, TRand *randomizer)
Definition: network.cpp:130
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: network.h:262
TRand * randomizer_
Definition: network.h:311
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:282
int32_t no_
Definition: network.h:304
Definition: networkio.h:39
int NumInputs() const
Definition: network.h:120
NetworkType type_
Definition: network.h:299
void set_depth(int value)
Definition: static_shape.h:49
virtual void CacheXScaleFactor(int factor)
Definition: network.h:215
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:231
TrainingState training_
Definition: network.h:300