Encodes input sentences into output states using an LSTM.
Public Types
typedef std::shared_ptr<Encoder> ptr
Public Member Functions
LSTMEncoder(ComponentWeights::ptr weights)
void addToModel(nvinfer1::INetworkDefinition* network, int maxInputSequenceLength, nvinfer1::ITensor* inputEmbeddedData, nvinfer1::ITensor* actualInputSequenceLengths, nvinfer1::ITensor** inputStates, nvinfer1::ITensor** memoryStates, nvinfer1::ITensor** lastTimestepStates) override
Add the memory and last-timestep states to the network. lastTimestepStates points to the tensor in which the encoder stores the hidden states of all layers for the last timestep (the last timestep differs from sample to sample). The function should define that tensor; the pointer may be nullptr, meaning these data are not needed.
int getMemoryStatesSize() override
Get the size of the memory state vector.
std::vector<nvinfer1::Dims> getStateSizes() override
Get the sizes of the hidden state vectors, returned as a vector with one nvinfer1::Dims per hidden state.
std::string getInfo() override
Get the textual description of the component.
~LSTMEncoder() override = default
Protected Attributes
ComponentWeights::ptr mWeights
std::vector<nvinfer1::Weights> mGateKernelWeights
std::vector<nvinfer1::Weights> mGateBiasWeights
bool mRNNKind
int mNumLayers
int mNumUnits
Detailed Description
Encodes input sentences into output states using an LSTM.
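To make "encoding with an LSTM" concrete, here is a minimal, self-contained sketch of a single-layer LSTM forward pass in plain C++, without TensorRT. It produces one hidden vector per timestep (analogous to the encoder's memory states) plus the final hidden and cell states. The types ToyLstmLayer and EncoderOutput, the function encode, and the weight layout (one kernel per gate over the concatenation [x_t; h_{t-1}], gate order i, f, g, o) are hypothetical choices for illustration, not the layout LSTMEncoder uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical single-layer LSTM used only for illustration (not TensorRT).
struct ToyLstmLayer {
    std::size_t inputSize;
    std::size_t numUnits;
    // Per gate (order: input i, forget f, candidate g, output o): a row-major
    // [numUnits x (inputSize + numUnits)] kernel applied to [x_t; h_{t-1}].
    std::vector<std::vector<float>> kernels;  // 4 entries
    std::vector<std::vector<float>> biases;   // 4 entries, numUnits each
};

struct EncoderOutput {
    std::vector<std::vector<float>> memoryStates;  // hidden vector per timestep
    std::vector<float> lastHidden;                 // h after the final timestep
    std::vector<float> lastCell;                   // c after the final timestep
};

inline float sigmoid(float v) { return 1.0f / (1.0f + std::exp(-v)); }

EncoderOutput encode(const ToyLstmLayer& layer,
                     const std::vector<std::vector<float>>& inputs) {
    const std::size_t n = layer.numUnits;
    const std::size_t cols = layer.inputSize + n;
    assert(layer.kernels.size() == 4 && layer.biases.size() == 4);
    for (std::size_t g = 0; g < 4; ++g)
        assert(layer.kernels[g].size() == n * cols && layer.biases[g].size() == n);

    std::vector<float> h(n, 0.0f), c(n, 0.0f);  // zero initial state
    EncoderOutput out;
    for (const auto& x : inputs) {
        assert(x.size() == layer.inputSize);
        // Concatenate the current input with the previous hidden state.
        std::vector<float> xh(x);
        xh.insert(xh.end(), h.begin(), h.end());

        // Pre-activations for all four gates, computed from the old h.
        std::vector<std::vector<float>> z(4, std::vector<float>(n));
        for (std::size_t g = 0; g < 4; ++g)
            for (std::size_t u = 0; u < n; ++u) {
                float acc = layer.biases[g][u];
                for (std::size_t k = 0; k < cols; ++k)
                    acc += layer.kernels[g][u * cols + k] * xh[k];
                z[g][u] = acc;
            }
        // Elementwise gate activations and state update.
        for (std::size_t u = 0; u < n; ++u) {
            const float i = sigmoid(z[0][u]);
            const float f = sigmoid(z[1][u]);
            const float cand = std::tanh(z[2][u]);
            const float o = sigmoid(z[3][u]);
            c[u] = f * c[u] + i * cand;
            h[u] = o * std::tanh(c[u]);
        }
        out.memoryStates.push_back(h);
    }
    out.lastHidden = h;
    out.lastCell = c;
    return out;
}
```

The recurrence is the standard one, c_t = f * c_{t-1} + i * g and h_t = o * tanh(c_t), with all products elementwise; a multi-layer encoder stacks such layers and feeds each layer's hidden sequence to the next.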
Member Typedef Documentation
typedef std::shared_ptr<Encoder> nmtSample::Encoder::ptr [inherited]
Constructor & Destructor Documentation
nmtSample::LSTMEncoder::LSTMEncoder(ComponentWeights::ptr weights)
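Since the constructor takes only a ComponentWeights::ptr, the values held in mRNNKind, mNumLayers, and mNumUnits presumably come from that object. This page does not document its format, so the sketch below assumes a hypothetical ToyComponentWeights whose integer metaData holds {rnnKind, numLayers, numUnits}, and shows the general pattern of validating that metadata and unpacking it into members. The real ComponentWeights fields and their ordering may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for ComponentWeights: metadata integers plus values.
struct ToyComponentWeights {
    using ptr = std::shared_ptr<ToyComponentWeights>;
    std::vector<std::int32_t> metaData;  // assumed layout: {rnnKind, numLayers, numUnits}
    std::vector<float> values;           // flattened weights (unused here)
};

class ToyLstmEncoder {
public:
    explicit ToyLstmEncoder(ToyComponentWeights::ptr weights)
        : mWeights(std::move(weights)) {
        // Reject missing or truncated metadata up front.
        if (!mWeights || mWeights->metaData.size() < 3)
            throw std::invalid_argument("missing encoder metadata");
        mRNNKind = mWeights->metaData[0] != 0;  // hypothetical flag encoding
        mNumLayers = mWeights->metaData[1];
        mNumUnits = mWeights->metaData[2];
        if (mNumLayers <= 0 || mNumUnits <= 0)
            throw std::invalid_argument("layer and unit counts must be positive");
    }

    int numLayers() const { return mNumLayers; }
    int numUnits() const { return mNumUnits; }
    bool rnnKind() const { return mRNNKind; }

protected:
    ToyComponentWeights::ptr mWeights;  // kept alive for later use by the model
    bool mRNNKind = false;
    int mNumLayers = 0;
    int mNumUnits = 0;
};
```

Validating in the constructor makes a malformed weights object fail when the component is built, rather than later while the network is being assembled.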
nmtSample::LSTMEncoder::~LSTMEncoder() [override, default]
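Member Function Documentation
void nmtSample::LSTMEncoder::addToModel(nvinfer1::INetworkDefinition* network, int maxInputSequenceLength, nvinfer1::ITensor* inputEmbeddedData, nvinfer1::ITensor* actualInputSequenceLengths, nvinfer1::ITensor** inputStates, nvinfer1::ITensor** memoryStates, nvinfer1::ITensor** lastTimestepStates) [override, virtual]
Add the memory and last-timestep states to the network. lastTimestepStates points to the tensor in which the encoder stores the hidden states of all layers for the last timestep (the last timestep differs from sample to sample). The function should define that tensor; the pointer may be nullptr, meaning these data are not needed.
Implements nmtSample::Encoder.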
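Since sequences in a batch have different actual lengths (the actualInputSequenceLengths input), the last timestep differs per sample: the state for sample b has to be taken at timestep actualLength[b] - 1, not at maxInputSequenceLength - 1. The plain-C++ sketch below shows that per-sample selection on host vectors. It is not how LSTMEncoder expresses this with TensorRT layers; the helper gatherLastTimestep and its [batch][time][units] layout are hypothetical, and for brevity it handles a single layer, whereas the documented tensor holds the states of all layers. A nullptr output mirrors the documented "not needed" case.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Sequence = std::vector<std::vector<float>>;  // [time][units]

// Hypothetical helper: for each sample b, copy the hidden state at timestep
// actualLengths[b] - 1 into *lastTimestepStates. A nullptr output means the
// caller does not need these data, so nothing is computed.
void gatherLastTimestep(const std::vector<Sequence>& memoryStates,
                        const std::vector<int>& actualLengths,
                        std::vector<std::vector<float>>* lastTimestepStates) {
    if (lastTimestepStates == nullptr)
        return;
    assert(memoryStates.size() == actualLengths.size());
    lastTimestepStates->clear();
    for (std::size_t b = 0; b < memoryStates.size(); ++b) {
        const int len = actualLengths[b];
        assert(len >= 1 && static_cast<std::size_t>(len) <= memoryStates[b].size());
        // Timesteps at or beyond len are padding and are ignored.
        lastTimestepStates->push_back(memoryStates[b][len - 1]);
    }
}
```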
int nmtSample::LSTMEncoder::getMemoryStatesSize() [override, virtual]
Get the size of the memory state vector.
Implements nmtSample::Encoder.
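A downstream component needs getMemoryStatesSize() to interpret the memory-states tensor. The sketch below assumes the states are stored contiguously, row-major, as [batch][maxInputSequenceLength][memoryStatesSize], and shows the index arithmetic for such a buffer in plain C++; MemoryStatesView and this layout are hypothetical. For a unidirectional LSTM whose memory states are the top layer's per-timestep outputs, the size would typically equal the number of units, but this page does not state that.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical view over a flat buffer laid out row-major as
// [batch][maxInputSequenceLength][memoryStatesSize].
class MemoryStatesView {
public:
    MemoryStatesView(std::vector<float>& data, std::size_t batch,
                     std::size_t maxLen, std::size_t stateSize)
        : mData(data), mBatch(batch), mMaxLen(maxLen), mStateSize(stateSize) {
        assert(data.size() == batch * maxLen * stateSize);
    }

    // Element k of the memory state for sample b at timestep t.
    float& at(std::size_t b, std::size_t t, std::size_t k) {
        assert(b < mBatch && t < mMaxLen && k < mStateSize);
        return mData[(b * mMaxLen + t) * mStateSize + k];
    }

private:
    std::vector<float>& mData;
    std::size_t mBatch;
    std::size_t mMaxLen;
    std::size_t mStateSize;  // e.g. the value getMemoryStatesSize() reports
};
```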
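std::vector<nvinfer1::Dims> nmtSample::LSTMEncoder::getStateSizes() [override, virtual]
Get the sizes of the hidden state vectors, returned as a vector with one nvinfer1::Dims per hidden state.
Implements nmtSample::Encoder.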
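An LSTM's recurrent state has two parts, the hidden state h and the cell state c, so a natural answer for getStateSizes() is two shapes of [numLayers, numUnits] each. That is an assumption; this page does not list the shapes. The sketch uses a hypothetical ToyDims in place of nvinfer1::Dims (which likewise pairs a rank, nbDims, with an extent array d) and adds a helper that totals the elements, for example to size a state buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for nvinfer1::Dims: a rank plus per-axis extents.
struct ToyDims {
    int nbDims;
    std::vector<int> d;
};

// Assumed shapes: one hidden state h and one cell state c, each holding
// numUnits values for every layer.
std::vector<ToyDims> lstmStateSizes(int numLayers, int numUnits) {
    assert(numLayers > 0 && numUnits > 0);
    const ToyDims hidden{2, {numLayers, numUnits}};
    const ToyDims cell{2, {numLayers, numUnits}};
    return {hidden, cell};
}

// Total number of values across all state tensors.
std::size_t totalStateElements(const std::vector<ToyDims>& dims) {
    std::size_t total = 0;
    for (const auto& dim : dims) {
        std::size_t n = 1;
        for (int i = 0; i < dim.nbDims; ++i)
            n *= static_cast<std::size_t>(dim.d[i]);
        total += n;
    }
    return total;
}
```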
std::string nmtSample::LSTMEncoder::getInfo() [override, virtual]
Get the textual description of the component.
Implements nmtSample::Component.
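getInfo() is declared at the Component level and overridden by LSTMEncoder, so any component can describe itself. The sketch mirrors that arrangement with hypothetical classes ToyComponent and ToyLstmEncoderInfo; the exact wording of the returned string is an assumption, not the text LSTMEncoder actually produces.

```cpp
#include <cassert>
#include <memory>
#include <sstream>
#include <string>

// Hypothetical mirror of the Component interface.
class ToyComponent {
public:
    virtual std::string getInfo() = 0;
    virtual ~ToyComponent() = default;
};

class ToyLstmEncoderInfo : public ToyComponent {
public:
    ToyLstmEncoderInfo(int numLayers, int numUnits)
        : mNumLayers(numLayers), mNumUnits(numUnits) {}

    // Assumed wording; the real LSTMEncoder::getInfo() text may differ.
    std::string getInfo() override {
        std::ostringstream ss;
        ss << "LSTM encoder, num layers = " << mNumLayers
           << ", num units = " << mNumUnits;
        return ss.str();
    }

private:
    int mNumLayers;
    int mNumUnits;
};
```

Because getInfo() is virtual, a pipeline can log each component through a base-class pointer without knowing its concrete type.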
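Member Data Documentation
ComponentWeights::ptr nmtSample::LSTMEncoder::mWeights [protected]
std::vector<nvinfer1::Weights> nmtSample::LSTMEncoder::mGateKernelWeights [protected]
std::vector<nvinfer1::Weights> nmtSample::LSTMEncoder::mGateBiasWeights [protected]
bool nmtSample::LSTMEncoder::mRNNKind [protected]
int nmtSample::LSTMEncoder::mNumLayers [protected]
int nmtSample::LSTMEncoder::mNumUnits [protected]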
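LSTMEncoder keeps gate kernels and gate biases as separate vectors of nvinfer1::Weights, which pairs a data pointer with an element count. How they are carved out of mWeights is not documented on this page. The sketch assumes a hypothetical flat layout (for each layer, four gate kernels of numUnits x (layerInput + numUnits) values, followed by all biases, numUnits per gate per layer) and builds per-gate views into the shared buffer without copying. ToyWeights, splitGateWeights, and the layout are illustrative assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for nvinfer1::Weights: pointer into a buffer plus count.
struct ToyWeights {
    const float* values;
    std::size_t count;
};

constexpr int kNumGates = 4;  // input, forget, cell candidate, output

// Assumed layout: all kernels (layer-major, then gate) followed by all biases.
void splitGateWeights(const std::vector<float>& blob, int numLayers, int inputSize,
                      int numUnits, std::vector<ToyWeights>& kernels,
                      std::vector<ToyWeights>& biases) {
    kernels.clear();
    biases.clear();
    std::size_t offset = 0;
    for (int layer = 0; layer < numLayers; ++layer) {
        // Layers above the first consume the previous layer's hidden state.
        const int layerInput = (layer == 0) ? inputSize : numUnits;
        const std::size_t kernelSize =
            static_cast<std::size_t>(numUnits) * static_cast<std::size_t>(layerInput + numUnits);
        for (int g = 0; g < kNumGates; ++g) {
            kernels.push_back({blob.data() + offset, kernelSize});
            offset += kernelSize;
        }
    }
    for (int i = 0; i < numLayers * kNumGates; ++i) {
        biases.push_back({blob.data() + offset, static_cast<std::size_t>(numUnits)});
        offset += static_cast<std::size_t>(numUnits);
    }
    assert(offset == blob.size() && "blob size does not match the assumed layout");
}
```

Views rather than copies keep a single owner for the raw data (here the blob, in LSTMEncoder presumably mWeights), which must outlive every view.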