19 #ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_ 20 #define TESSERACT_LSTM_WEIGHTMATRIX_H_ 23 #include "genericvector.h" 24 #include "intsimdmatrix.h" 42 for (
int i = 0; i < size1; ++i)
put(i, t, data[i]);
46 for (
int i = 0; i < size1; ++i)
put(i, t, data[i]);
50 int num_features =
dim1();
52 for (
int y = 0; y < num_features; ++y) {
53 for (
int t = 0; t < width; ++t) {
54 if (num == 0 || t < num || t + num >= width) {
55 tprintf(
" %g", (*
this)(y, t));
74 int InitWeightsFloat(
int no,
int ni,
bool use_adam,
float weight_range,
81 int RemapOutputs(
const std::vector<int>& code_map);
94 if (multiplier_ ==
nullptr)
return size;
95 return multiplier_->RoundInputs(size);
102 int NumOutputs()
const {
return int_mode_ ? wi_.dim1() : wf_.dim1(); }
106 double GetDW(
int i,
int j)
const {
return dw_(i, j); }
118 bool DeSerializeOld(
bool training,
TFile* fp);
125 void MatrixDotVector(
const double* u,
double* v)
const;
126 void MatrixDotVector(
const int8_t* u,
double* v)
const;
134 void VectorDotMatrix(
const double* u,
double* v)
const;
143 void Update(
double learning_rate,
double momentum,
double adam_beta,
150 void CountAlternators(
const WeightMatrix& other,
double* same,
151 double* changed)
const;
153 void Debug2D(
const char* msg);
156 static double DotProduct(
const double* u,
const double* v,
int n);
170 bool add_bias_fwd,
bool skip_bias_back,
171 const double* u,
double* v);
199 #endif // TESSERACT_LSTM_WEIGHTMATRIX_H_ void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
Definition: functions.h:201
GENERIC_2D_ARRAY< double > dw_
Definition: weightmatrix.h:188
GENERIC_2D_ARRAY< double > updates_
Definition: weightmatrix.h:189
int RoundInputs(int size) const
Definition: weightmatrix.h:93
std::unique_ptr< IntSimdMatrix > multiplier_
Definition: weightmatrix.h:194
void put(ICOORD pos, const double &thing)
Definition: matrix.h:220
GENERIC_2D_ARRAY< int8_t > wi_
Definition: weightmatrix.h:176
GENERIC_2D_ARRAY< double > wf_
Definition: weightmatrix.h:175
Definition: intsimdmatrix.h:25
TransposedArray wf_t_
Definition: weightmatrix.h:178
GenericVector< double > scales_
Definition: weightmatrix.h:185
Definition: serialis.h:77
int dim1() const
Definition: matrix.h:206
bool int_mode_
Definition: weightmatrix.h:180
Definition: baseapi.cpp:94
WeightMatrix()
Definition: weightmatrix.h:68
int dim2() const
Definition: matrix.h:207
bool Serialize(FILE *fp) const
Definition: matrix.h:144
Definition: weightmatrix.h:33
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:40
Definition: weightmatrix.h:66
bool is_int_mode() const
Definition: weightmatrix.h:99
double GetDW(int i, int j) const
Definition: weightmatrix.h:106
virtual ~TransposedArray()
bool DeSerialize(bool swap, FILE *fp)
Definition: matrix.h:161
void Transpose(const GENERIC_2D_ARRAY< double > &input)
Definition: weightmatrix.cpp:42
GENERIC_2D_ARRAY< double > dw_sq_sum_
Definition: weightmatrix.h:192
int NumOutputs() const
Definition: weightmatrix.h:102
void WriteStrided(int t, const double *data)
Definition: weightmatrix.h:44
virtual int index(int column, int row) const
Definition: matrix.h:215
bool use_adam_
Definition: weightmatrix.h:182
void PrintUnTransposed(int num)
Definition: weightmatrix.h:49
const double * GetWeights(int index) const
Definition: weightmatrix.h:104