19 #ifndef TESSERACT_LSTM_LSTM_H_ 20 #define TESSERACT_LSTM_LSTM_H_ 23 #include "fullyconnected.h" 50 LSTM(
const STRING&
name,
int num_inputs,
int num_states,
int num_outputs,
81 int RemapOutputs(
int old_no,
const std::vector<int>& code_map)
override;
106 void Update(
float learning_rate,
float momentum,
float adam_beta,
107 int num_samples)
override;
112 double* changed)
const override;
162 #endif // TESSERACT_LSTM_LSTM_H_ void add_str_int(const char *str, int number)
Definition: strngs.cpp:379
int input_width_
Definition: lstm.h:156
FullyConnected * softmax_
Definition: lstm.h:145
Definition: static_shape.h:38
WeightMatrix gate_weights_[WT_COUNT]
Definition: lstm.h:143
bool Is2D() const
Definition: lstm.h:119
bool is_2d_
Definition: lstm.h:140
Definition: networkscratch.h:36
int InitWeights(float range, TRand *randomizer) override
Definition: lstm.cpp:158
int32_t ns_
Definition: lstm.h:134
void ResizeForward(const NetworkIO &input)
Definition: lstm.cpp:753
Definition: serialis.h:77
void PrintDW()
Definition: lstm.cpp:727
STRING spec() const override
Definition: lstm.h:58
Definition: baseapi.cpp:94
GENERIC_2D_ARRAY< int8_t > which_fg_
Definition: lstm.h:151
const STRING & name() const
Definition: network.h:138
NetworkType
Definition: network.h:43
int32_t nf_
Definition: lstm.h:138
NetworkIO source_
Definition: lstm.h:147
NetworkIO node_values_[WT_COUNT]
Definition: lstm.h:153
Definition: weightmatrix.h:33
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: lstm.cpp:250
Definition: network.h:105
void SetEnableTraining(TrainingState state) override
Definition: lstm.cpp:137
int32_t na_
Definition: lstm.h:131
Definition: weightmatrix.h:66
NetworkType type() const
Definition: network.h:112
bool Serialize(TFile *fp) const override
Definition: lstm.cpp:207
NetworkIO state_
Definition: lstm.h:149
bool DeSerialize(TFile *fp) override
Definition: lstm.cpp:220
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: lstm.cpp:174
Definition: fullyconnected.h:28
void CountAlternators(const Network &other, double *same, double *changed) const override
Definition: lstm.cpp:687
void DebugWeights() override
Definition: lstm.cpp:194
TrainingState
Definition: network.h:92
StrideMap input_map_
Definition: lstm.h:155
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: lstm.cpp:127
LSTM(const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
Definition: lstm.cpp:99
void ConvertToInt() override
Definition: lstm.cpp:183
Definition: stridemap.h:43
Definition: networkio.h:39
STRING spec() const override
Definition: fullyconnected.h:37
NetworkType type_
Definition: network.h:299
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
Definition: lstm.cpp:667
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: lstm.cpp:441
void PrintW()
Definition: lstm.cpp:701
virtual ~LSTM()
Definition: lstm.cpp:123
WeightType
Definition: lstm.h:33