#include <lstm.h>
Public Types | |
enum | WeightType { CI, GI, GF1, GO, GFS, WT_COUNT } |
Public Member Functions | |
LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type) | |
virtual | ~LSTM () |
StaticShape | OutputShape (const StaticShape &input_shape) const override |
STRING | spec () const override |
void | SetEnableTraining (TrainingState state) override |
int | InitWeights (float range, TRand *randomizer) override |
int | RemapOutputs (int old_no, const std::vector< int > &code_map) override |
void | ConvertToInt () override |
void | DebugWeights () override |
bool | Serialize (TFile *fp) const override |
bool | DeSerialize (TFile *fp) override |
void | Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override |
bool | Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override |
void | Update (float learning_rate, float momentum, float adam_beta, int num_samples) override |
void | CountAlternators (const Network &other, double *same, double *changed) const override |
void | PrintW () |
void | PrintDW () |
bool | Is2D () const |
Public Member Functions inherited from tesseract::Network | |
Network () | |
Network (NetworkType type, const STRING &name, int ni, int no) | |
virtual | ~Network ()=default |
NetworkType | type () const |
bool | IsTraining () const |
bool | needs_to_backprop () const |
int | num_weights () const |
int | NumInputs () const |
int | NumOutputs () const |
virtual StaticShape | InputShape () const |
const STRING & | name () const |
bool | TestFlag (NetworkFlags flag) const |
virtual bool | IsPlumbingType () const |
virtual void | SetNetworkFlags (uint32_t flags) |
virtual void | SetRandomizer (TRand *randomizer) |
virtual bool | SetupNeedsBackprop (bool needs_backprop) |
virtual int | XScaleFactor () const |
virtual void | CacheXScaleFactor (int factor) |
void | DisplayForward (const NetworkIO &matrix) |
void | DisplayBackward (const NetworkIO &matrix) |
Private Member Functions | |
void | ResizeForward (const NetworkIO &input) |
Private Attributes | |
int32_t | na_ |
int32_t | ns_ |
int32_t | nf_ |
bool | is_2d_ |
WeightMatrix | gate_weights_ [WT_COUNT] |
FullyConnected * | softmax_ |
NetworkIO | source_ |
NetworkIO | state_ |
GENERIC_2D_ARRAY< int8_t > | which_fg_ |
NetworkIO | node_values_ [WT_COUNT] |
StrideMap | input_map_ |
int | input_width_ |
Additional Inherited Members | |
Static Public Member Functions inherited from tesseract::Network | |
static Network * | CreateFromFile (TFile *fp) |
static void | ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window) |
static int | DisplayImage (Pix *pix, ScrollView *window) |
Protected Member Functions inherited from tesseract::Network | |
double | Random (double range) |
Protected Attributes inherited from tesseract::Network | |
NetworkType | type_ |
TrainingState | training_ |
bool | needs_to_backprop_ |
int32_t | network_flags_ |
int32_t | ni_ |
int32_t | no_ |
int32_t | num_weights_ |
STRING | name_ |
ScrollView * | forward_win_ |
ScrollView * | backward_win_ |
TRand * | randomizer_ |
Static Protected Attributes inherited from tesseract::Network | |
static char const *const | kTypeNames [NT_COUNT] |
tesseract::LSTM::LSTM(const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
|
virtual |
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
inline |
|
override virtual |
Reimplemented from tesseract::Network.
void tesseract::LSTM::PrintDW()
void tesseract::LSTM::PrintW()
|
override virtual |
Reimplemented from tesseract::Network.
|
private |
|
override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
inline override virtual |
Reimplemented from tesseract::Network.
|
override virtual |
Reimplemented from tesseract::Network.
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |