tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
networkbuilder.h
1 // File: networkbuilder.h
3 // Description: Class to parse the network description language and
4 // build a corresponding network.
5 // Author: Ray Smith
6 // Created: Wed Jul 16 18:35:38 PST 2014
7 //
8 // (C) Copyright 2014, Google Inc.
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 // http://www.apache.org/licenses/LICENSE-2.0
13 // Unless required by applicable law or agreed to in writing, software
14 // distributed under the License is distributed on an "AS IS" BASIS,
15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 // See the License for the specific language governing permissions and
17 // limitations under the License.
19 
20 #ifndef TESSERACT_LSTM_NETWORKBUILDER_H_
21 #define TESSERACT_LSTM_NETWORKBUILDER_H_
22 
23 #include "static_shape.h"
24 #include "stridemap.h"
25 
26 class STRING;
27 class UNICHARSET;
28 
29 namespace tesseract {
30 
31 class Input;
32 class Network;
33 class Parallel;
34 class TRand;
35 
37  public:
38  explicit NetworkBuilder(int num_softmax_outputs)
39  : num_softmax_outputs_(num_softmax_outputs) {}
40 
41  // Builds a network with a network_spec in the network description
42  // language, to recognize a character set of num_outputs size.
43  // If append_index is non-negative, then *network must be non-null and the
44  // given network_spec will be appended to *network AFTER append_index, with
45  // the top of the input *network discarded.
46  // Note that network_spec is call by value to allow a non-const char* pointer
47  // into the string for BuildFromString.
48  // net_flags control network behavior according to the NetworkFlags enum.
49  // The resulting network is returned via **network.
50  // Returns false if something failed.
51  static bool InitNetwork(int num_outputs, STRING network_spec,
52  int append_index, int net_flags, float weight_range,
53  TRand* randomizer, Network** network);
54 
55  // Parses the given string and returns a network according to the following
56  // language:
57  // ============ Syntax of description below: ============
58  // <d> represents a number.
59  // <net> represents any single network element, including (recursively) a
60  // [...] series or (...) parallel construct.
61  // (s|t|r|l|m) (regex notation) represents a single required letter.
62  // NOTE THAT THROUGHOUT, x and y are REVERSED from conventional mathematics,
63  // to use the same convention as Tensor Flow. The reason TF adopts this
64  // convention is to eliminate the need to transpose images on input, since
65  // adjacent memory locations in images increase x and then y, while adjacent
66  // memory locations in tensors in TF, and NetworkIO in tesseract increase the
67  // rightmost index first, then the next-left and so-on, like C arrays.
68  // ============ INPUTS ============
69  // <b>,<h>,<w>,<d> A batch of b images with height h, width w, and depth d.
70  // b, h and/or w may be zero, to indicate variable size. Some network layer
71  // (summarizing LSTM) must be used to make a variable h known.
72  // d may be 1 for greyscale, 3 for color.
73  // NOTE that throughout the constructed network, the inputs/outputs are all of
74  // the same [batch,height,width,depth] dimensions, even if a different size.
75  // ============ PLUMBING ============
76  // [...] Execute ... networks in series (layers).
77  // (...) Execute ... networks in parallel, with their output depths added.
78  // R<d><net> Execute d replicas of net in parallel, with their output depths
79  // added.
80  // Rx<net> Execute <net> with x-dimension reversal.
81  // Ry<net> Execute <net> with y-dimension reversal.
82  // S<y>,<x> Rescale 2-D input by shrink factor x,y, rearranging the data by
83  // increasing the depth of the input by factor xy.
84  // Mp<y>,<x> Maxpool the input, reducing the size by an (x,y) rectangle.
85  // ============ FUNCTIONAL UNITS ============
86  // C(s|t|r|l|m)<y>,<x>,<d> Convolves using a (x,y) window, with no shrinkage,
87  // random infill, producing d outputs, then applies a non-linearity:
88  // s: Sigmoid, t: Tanh, r: Relu, l: Linear, m: Softmax.
89  // F(s|t|r|l|m)<d> Truly fully-connected with s|t|r|l|m non-linearity and d
90  // outputs. Connects to every x,y,depth position of the input, reducing
91  // height, width to 1, producing a single <d> vector as the output.
92  // Input height and width must be constant.
93  // For a sliding-window linear or non-linear map that connects just to the
94  // input depth, and leaves the input image size as-is, use a 1x1 convolution
95  // eg. Cr1,1,64 instead of Fr64.
96  // L(f|r|b)(x|y)[s]<n> LSTM cell with n states/outputs.
97  // The LSTM must have one of:
98  // f runs the LSTM forward only.
99  // r runs the LSTM reversed only.
100  // b runs the LSTM bidirectionally.
101  // It will operate on either the x- or y-dimension, treating the other
102  // dimension independently (as if part of the batch).
103  // s (optional) summarizes the output in the requested dimension,
104  // outputting only the final step, collapsing the dimension to a
105  // single element.
106  // LS<n> Forward-only LSTM cell in the x-direction, with built-in Softmax.
107  // LE<n> Forward-only LSTM cell in the x-direction, with built-in softmax,
108  // with binary Encoding.
109  // L2xy<n> Full 2-d LSTM operating in quad-directions (bidi in x and y) and
110  // all the output depths added.
111  // ============ OUTPUTS ============
112  // The network description must finish with an output specification:
113  // O(2|1|0)(l|s|c)<n> output layer with n classes
114  // 2 (heatmap) Output is a 2-d vector map of the input (possibly at
115  // different scale).
116  // 1 (sequence) Output is a 1-d sequence of vector values.
117  // 0 (category) Output is a 0-d single vector value.
118  // l uses a logistic non-linearity on the output, allowing multiple
119  // hot elements in any output vector value.
120  // s uses a softmax non-linearity, with one-hot output in each value.
121  // c uses a softmax with CTC. Can only be used with s (sequence).
122  // NOTE1: Only O1s and O1c are currently supported.
123  // NOTE2: n is totally ignored, and for compatibility purposes only. The
124  // output number of classes is obtained automatically from the
125  // unicharset.
126  Network* BuildFromString(const StaticShape& input_shape, char** str);
127 
128  private:
129  // Parses an input specification and returns the result, which may include a
130  // series.
131  Network* ParseInput(char** str);
132  // Parses a sequential series of networks, defined by [<net><net>...].
133  Network* ParseSeries(const StaticShape& input_shape, Input* input_layer,
134  char** str);
135  // Parses a parallel set of networks, defined by (<net><net>...).
136  Network* ParseParallel(const StaticShape& input_shape, char** str);
137  // Parses a network that begins with 'R'.
138  Network* ParseR(const StaticShape& input_shape, char** str);
139  // Parses a network that begins with 'S'.
140  Network* ParseS(const StaticShape& input_shape, char** str);
141  // Parses a network that begins with 'C'.
142  Network* ParseC(const StaticShape& input_shape, char** str);
143  // Parses a network that begins with 'M'.
144  Network* ParseM(const StaticShape& input_shape, char** str);
145  // Parses an LSTM network, either individual, bi- or quad-directional.
146  Network* ParseLSTM(const StaticShape& input_shape, char** str);
147  // Builds a set of 4 lstms with t and y reversal, running in true parallel.
148  static Network* BuildLSTMXYQuad(int num_inputs, int num_states);
149  // Parses a Fully connected network.
150  Network* ParseFullyConnected(const StaticShape& input_shape, char** str);
151  // Parses an Output spec.
152  Network* ParseOutput(const StaticShape& input_shape, char** str);
153 
154  private:
156 };
157 
158 } // namespace tesseract.
159 
160 #endif // TESSERACT_LSTM_NETWORKBUILDER_H_
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Definition: networkbuilder.cpp:45
Network * ParseC(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:266
Definition: helpers.h:42
Definition: static_shape.h:38
Definition: unicharset.h:146
Network * ParseSeries(const StaticShape &input_shape, Input *input_layer, char **str)
Definition: networkbuilder.cpp:146
Definition: baseapi.cpp:94
Network * ParseInput(char **str)
Definition: networkbuilder.cpp:123
Network * ParseR(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:190
NetworkBuilder(int num_softmax_outputs)
Definition: networkbuilder.h:38
Definition: networkbuilder.h:36
Definition: network.h:105
Network * ParseM(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:294
Definition: strngs.h:45
Network * ParseLSTM(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:305
static Network * BuildLSTMXYQuad(int num_inputs, int num_states)
Definition: networkbuilder.cpp:377
int num_softmax_outputs_
Definition: networkbuilder.h:155
Definition: input.h:28
Network * ParseParallel(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:171
Network * ParseOutput(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:439
Network * ParseS(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:225
Network * BuildFromString(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:86
Network * ParseFullyConnected(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:421