19 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_ 20 #define TESSERACT_LSTM_LSTMTRAINER_H_ 22 #include "imagedata.h" 23 #include "lstmrecognizer.h" 25 #include "tesscallback.h" 94 CheckPointReader checkpoint_reader,
95 CheckPointWriter checkpoint_writer,
96 const char* model_base,
const char* checkpoint_name,
97 int debug_interval, int64_t max_memory);
110 ASSERT_HOST(
mgr_.
Init(traineddata_path.c_str()));
124 bool InitNetwork(
const STRING& network_spec,
int append_index,
int net_flags,
174 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
175 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
176 double cert_offset_step,
double max_cert_offset, STRING* results);
186 bool randomly_rotate);
199 TestCallback tester, STRING* log_msg);
204 void LogIterations(
const char* intro_str, STRING* log_msg)
const;
214 bool Serialize(SerializeAmount serialize_amount,
215 const TessdataManager* mgr,
TFile* fp)
const;
242 LSTMTrainer* samples_trainer);
263 if (image !=
nullptr) {
285 const LSTMTrainer* trainer,
292 LSTMTrainer* trainer)
const {
293 if (data.
empty())
return false;
297 LSTMTrainer* trainer)
const {
393 TestCallback tester);
488 #endif // TESSERACT_LSTM_LSTMTRAINER_H_ bool SimpleTextOutput() const
Definition: lstmrecognizer.h:76
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
Definition: lstmtrainer.cpp:171
const DocumentCache & training_data() const
Definition: lstmtrainer.h:165
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:547
TessResultCallback4< STRING, int, const double *, const TessdataManager &, int > * TestCallback
Definition: lstmtrainer.h:83
int null_char() const
Definition: lstmrecognizer.h:154
void FillErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:949
const double * error_rates() const
Definition: lstmtrainer.h:140
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:1119
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmtrainer.cpp:468
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
Definition: lstmtrainer.cpp:1099
bool IsRecoding() const
Definition: lstmrecognizer.h:79
int learning_iteration_
Definition: lstmtrainer.h:464
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
void InitCharSet()
Definition: lstmtrainer.cpp:992
bool TransitionTrainingStage(float error_threshold)
Definition: lstmtrainer.cpp:421
double ActivationError() const
Definition: lstmtrainer.h:136
Definition: lstmtrainer.h:50
Definition: lstmtrainer.h:57
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259
Definition: lstmtrainer.h:52
bool Init(const char *data_file_name)
Definition: tessdatamanager.cpp:55
int stall_iteration_
Definition: lstmtrainer.h:442
FileReader file_reader_
Definition: lstmtrainer.h:420
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
Definition: lstmtrainer.cpp:300
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
Definition: lstmtrainer.cpp:128
CachingStrategy
Definition: imagedata.h:42
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
Definition: imagedata.h:105
void UpdateErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:1248
int training_iteration() const
Definition: lstmrecognizer.h:61
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
void RollErrorBuffers()
Definition: lstmtrainer.cpp:1261
int checkpoint_iteration_
Definition: lstmtrainer.h:407
SerializeAmount
Definition: lstmtrainer.h:56
int32_t improvement_steps_
Definition: lstmtrainer.h:460
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
Definition: lstmtrainer.h:69
Definition: unicharset.h:146
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
Definition: lstmtrainer.cpp:909
Definition: imagedata.h:314
virtual ~LSTMTrainer()
Definition: lstmtrainer.cpp:116
int learning_iteration() const
Definition: lstmtrainer.h:149
Definition: lstmtrainer.h:58
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
const GenericVector< char > & best_trainer() const
Definition: lstmtrainer.h:152
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
Definition: lstmtrainer.cpp:243
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
Definition: genericvector.h:360
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
Definition: lstmtrainer.cpp:1130
Definition: serialis.h:77
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double ComputeRMSError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1150
Definition: baseapi.cpp:94
void SetNullChar()
Definition: lstmtrainer.cpp:1005
Definition: lstmtrainer.h:48
int debug_interval_
Definition: lstmtrainer.h:405
SubTrainerResult
Definition: lstmtrainer.h:63
int perfect_delay_
Definition: lstmtrainer.h:472
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
Definition: lstmtrainer.cpp:431
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
Definition: lstmtrainer.cpp:609
void StartSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:517
double learning_rate() const
Definition: lstmrecognizer.h:67
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
Definition: lstmtrainer.cpp:957
int32_t null_char_
Definition: lstmrecognizer.h:290
int prev_sample_iteration_
Definition: lstmtrainer.h:466
STRING DumpFilename() const
Definition: lstmtrainer.cpp:940
Definition: unicharcompress.h:128
void SetupCheckpointInfo()
double best_error_rate() const
Definition: lstmtrainer.h:143
int sample_iteration() const
Definition: lstmrecognizer.h:64
void InitIterations()
Definition: lstmtrainer.cpp:218
double ComputeWinnerError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1169
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
Definition: genericvector.h:363
Definition: tesscallback.h:1702
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:296
int worst_iteration_
Definition: lstmtrainer.h:440
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:798
int CurrentTrainingStage() const
Definition: lstmtrainer.h:211
int32_t improvement_steps() const
Definition: lstmtrainer.h:150
FileWriter file_writer_
Definition: lstmtrainer.h:421
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
Definition: lstmtrainer.cpp:1215
int InitTensorFlowNetwork(const std::string &tf_proto)
Definition: lstmtrainer.cpp:198
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
Definition: lstmtrainer.h:78
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:160
double CharError() const
Definition: lstmtrainer.h:139
void DebugNetwork()
Definition: lstmtrainer.cpp:293
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
Definition: lstmtrainer.h:66
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
void PrepareLogMsg(STRING *log_msg) const
Definition: lstmtrainer.cpp:400
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
Definition: lstmtrainer.cpp:1280
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
Definition: lstmtrainer.h:64
Definition: tessdatamanager.h:126
int training_stage_
Definition: lstmtrainer.h:454
Definition: lstmtrainer.h:49
int best_iteration() const
Definition: lstmtrainer.h:146
bool SaveTraineddata(const STRING &filename)
Definition: lstmtrainer.cpp:921
int num_training_stages_
Definition: lstmtrainer.h:418
STRING best_model_name_
Definition: lstmtrainer.h:416
void SaveRecognitionDump(GenericVector< char > *data) const
Definition: lstmtrainer.cpp:930
Definition: lstmtrainer.h:59
Definition: lstmtrainer.h:40
LSTMTrainer()
Definition: lstmtrainer.cpp:73
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
Definition: lstmtrainer.cpp:590
Definition: lstmtrainer.h:41
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Definition: lstmtrainer.cpp:1187
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
Definition: lstmtrainer.h:51
Definition: tesscallback.h:1716
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:168
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
Definition: lstmtrainer.h:65
double worst_error_rate_
Definition: lstmtrainer.h:436
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:139
Definition: lstmtrainer.h:39
STRING checkpoint_name_
Definition: lstmtrainer.h:411
int size() const
Definition: genericvector.h:71
ScrollView * recon_win_
Definition: lstmtrainer.h:403
void InitCharSet(const TessdataManager &mgr)
Definition: lstmtrainer.h:113
Definition: lstmtrainer.h:43
UnicharCompress recoder_
Definition: lstmrecognizer.h:277
int32_t sample_iteration_
Definition: lstmrecognizer.h:287
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
Definition: lstmtrainer.cpp:1062
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:337
Definition: lstmtrainer.h:38
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
Definition: lstmtrainer.cpp:900
Definition: lstmrecognizer.h:53
ErrorTypes
Definition: lstmtrainer.h:37
Definition: networkio.h:39
bool randomly_rotate_
Definition: lstmtrainer.h:413
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444
Trainability
Definition: lstmtrainer.h:47
void LogIterations(const char *intro_str, STRING *log_msg) const
Definition: lstmtrainer.cpp:412
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
ScrollView * align_win_
Definition: lstmtrainer.h:397
STRING model_base_
Definition: lstmtrainer.h:409
TessdataManager mgr_
Definition: lstmtrainer.h:483
Definition: lstmtrainer.h:89
void EmptyConstructor()
Definition: lstmtrainer.cpp:1014
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:452
ScrollView * target_win_
Definition: lstmtrainer.h:399
DocumentCache training_data_
Definition: lstmtrainer.h:414
bool empty() const
Definition: genericvector.h:90
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
Definition: lstmtrainer.cpp:1029
int best_iteration_
Definition: lstmtrainer.h:434
Definition: lstmtrainer.h:42
double best_error_rate_
Definition: lstmtrainer.h:430
int last_perfect_training_iteration_
Definition: lstmtrainer.h:475
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:109
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
Definition: lstmtrainer.cpp:312