19 #ifndef TESSERACT_LSTM_FULLYCONNECTED_H_ 20 #define TESSERACT_LSTM_FULLYCONNECTED_H_ 23 #include "networkscratch.h" 73 int RemapOutputs(
int old_no,
const std::vector<int>& code_map)
override;
95 void ForwardTimeStep(
const double* d_input,
int t,
double* output_line);
96 void ForwardTimeStep(
const int8_t* i_input,
int t,
double* output_line);
109 void Update(
float learning_rate,
float momentum,
float adam_beta,
110 int num_samples)
override;
115 double* changed)
const override;
136 #endif // TESSERACT_LSTM_FULLYCONNECTED_H_ const TransposedArray * external_source_
Definition: fullyconnected.h:124
void add_str_int(const char *str, int number)
Definition: strngs.cpp:379
void CountAlternators(const Network &other, double *same, double *changed) const override
Definition: fullyconnected.cpp:306
void SetEnableTraining(TrainingState state) override
Definition: fullyconnected.cpp:61
void ForwardTimeStep(int t, double *output_line)
Definition: fullyconnected.cpp:185
WeightMatrix weights_
Definition: fullyconnected.h:119
Definition: static_shape.h:38
Definition: networkscratch.h:36
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
Definition: fullyconnected.cpp:265
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: fullyconnected.cpp:119
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
Definition: fullyconnected.cpp:39
Definition: serialis.h:77
TransposedArray source_t_
Definition: fullyconnected.h:121
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
Definition: fullyconnected.cpp:298
Definition: baseapi.cpp:94
void ConvertToInt() override
Definition: fullyconnected.cpp:96
const STRING & name() const
Definition: network.h:138
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: fullyconnected.cpp:221
int InitWeights(float range, TRand *randomizer) override
Definition: fullyconnected.cpp:77
NetworkType
Definition: network.h:43
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
Definition: fullyconnected.cpp:173
Definition: weightmatrix.h:33
Definition: network.h:105
Definition: weightmatrix.h:66
NetworkType type() const
Definition: network.h:112
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: fullyconnected.cpp:46
Definition: fullyconnected.h:28
bool DeSerialize(TFile *fp) override
Definition: fullyconnected.cpp:113
bool Serialize(TFile *fp) const override
Definition: fullyconnected.cpp:106
TrainingState
Definition: network.h:92
NetworkIO acts_
Definition: fullyconnected.h:126
int32_t no_
Definition: network.h:304
Definition: networkio.h:39
STRING spec() const override
Definition: fullyconnected.h:37
virtual ~FullyConnected()=default
void ChangeType(NetworkType type)
Definition: fullyconnected.h:60
NetworkType type_
Definition: network.h:299
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: fullyconnected.cpp:87
bool int_mode_
Definition: fullyconnected.h:129
void DebugWeights() override
Definition: fullyconnected.cpp:101
void FinishBackward(const TransposedArray &errors_t)
Definition: fullyconnected.cpp:289