Encodes input sentences into output states.
Public Types
  typedef std::shared_ptr<Encoder> ptr

Public Member Functions
  Encoder() = default
  virtual void addToModel(nvinfer1::INetworkDefinition* network, int maxInputSequenceLength, nvinfer1::ITensor* inputEmbeddedData, nvinfer1::ITensor* actualInputSequenceLengths, nvinfer1::ITensor** inputStates, nvinfer1::ITensor** memoryStates, nvinfer1::ITensor** lastTimestepStates) = 0
    Adds the memory states and the last-timestep states to the network. lastTimestepStates points to the tensor in which the encoder stores the hidden states of all layers at the last timestep (which differs from sample to sample). The function should define this tensor; the pointer may be nullptr, indicating that this data is not needed.
  virtual int getMemoryStatesSize() = 0
    Gets the size of the memory state vector.
  virtual std::vector<nvinfer1::Dims> getStateSizes() = 0
    Gets the sizes of the hidden state vectors, returned as a vector with one entry per state.
  ~Encoder() override = default
  virtual std::string getInfo() = 0
    Gets the textual description of the component.
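The interface above is consumed polymorphically through `Encoder::ptr` (a `std::shared_ptr<Encoder>`), with concrete classes such as nmtSample::LSTMEncoder overriding every pure virtual function. The following is a minimal, self-contained sketch of that pattern under stated assumptions: `ComponentStub`, `DimsStub` (standing in for `nvinfer1::Dims`), `EncoderSketch` and `ToyEncoder` are hypothetical names, and `addToModel` is omitted because it requires the TensorRT network API. The values returned by `ToyEncoder` are illustrative only and do not describe what LSTMEncoder returns.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for nvinfer1::Dims: nbDims valid entries in d[].
struct DimsStub
{
    int nbDims = 0;
    int d[8] = {};
};

// Hypothetical base class declaring getInfo(); this page only says the
// function is inherited and does not name the real base class.
class ComponentStub
{
public:
    virtual std::string getInfo() = 0;
    virtual ~ComponentStub() = default;
};

// Mirrors the shape of nmtSample::Encoder, minus addToModel().
class EncoderSketch : public ComponentStub
{
public:
    typedef std::shared_ptr<EncoderSketch> ptr;
    EncoderSketch() = default;
    virtual int getMemoryStatesSize() = 0;
    virtual std::vector<DimsStub> getStateSizes() = 0;
    ~EncoderSketch() override = default;
};

// Hypothetical toy encoder: numLayers layers, each with a hidden state and a
// cell state of width hiddenSize.
class ToyEncoder : public EncoderSketch
{
public:
    ToyEncoder(int numLayers, int hiddenSize)
        : mNumLayers(numLayers)
        , mHiddenSize(hiddenSize)
    {
    }

    int getMemoryStatesSize() override { return mHiddenSize; }

    std::vector<DimsStub> getStateSizes() override
    {
        DimsStub state;
        state.nbDims = 2;
        state.d[0] = mNumLayers;
        state.d[1] = mHiddenSize;
        return {state, state}; // one entry for the hidden state, one for the cell state
    }

    std::string getInfo() override
    {
        return "ToyEncoder (layers = " + std::to_string(mNumLayers) + ", hidden size = "
            + std::to_string(mHiddenSize) + ")";
    }

private:
    int mNumLayers;
    int mHiddenSize;
};
```

A caller holds the concrete object through the base-class pointer type, for example `EncoderSketch::ptr enc = std::make_shared<ToyEncoder>(2, 512);`, so the rest of the pipeline depends only on the abstract interface.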
typedef std::shared_ptr<Encoder> nmtSample::Encoder::ptr

nmtSample::Encoder::Encoder() [default]

nmtSample::Encoder::~Encoder() [override, default]

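virtual void nmtSample::Encoder::addToModel(nvinfer1::INetworkDefinition* network, int maxInputSequenceLength, nvinfer1::ITensor* inputEmbeddedData, nvinfer1::ITensor* actualInputSequenceLengths, nvinfer1::ITensor** inputStates, nvinfer1::ITensor** memoryStates, nvinfer1::ITensor** lastTimestepStates) [pure virtual]
Adds the memory states and the last-timestep states to the network. lastTimestepStates points to the tensor in which the encoder stores the hidden states of all layers at the last timestep (which differs from sample to sample). The function should define this tensor; the pointer may be nullptr, indicating that this data is not needed.
Implemented in nmtSample::LSTMEncoder.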
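Because sentences in a batch have different actual lengths, the "last timestep" of sample b is presumably index actualInputSequenceLengths[b] - 1 rather than maxInputSequenceLength - 1. The sketch below illustrates that per-sample selection on plain host arrays, together with the convention that a null output pointer means the data is not needed. It is not TensorRT code: `gatherLastTimestep`, the row-major [numLayers][batch][maxLen][hidden] layout and the use of `std::vector` in place of `nvinfer1::ITensor` are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side illustration (not TensorRT code).
// states:  hidden states, row-major [numLayers][batch][maxLen][hidden].
// lengths: actual sequence length of each sample (1 <= lengths[b] <= maxLen).
// out:     if non-null, receives [numLayers][batch][hidden] holding each layer's
//          hidden state at timestep lengths[b] - 1 for sample b.
void gatherLastTimestep(const std::vector<float>& states, const std::vector<int>& lengths,
    int numLayers, int maxLen, int hidden, std::vector<float>* out)
{
    if (out == nullptr)
    {
        return; // the caller signalled that last-timestep states are not needed
    }
    const int batch = static_cast<int>(lengths.size());
    out->assign(static_cast<std::size_t>(numLayers) * batch * hidden, 0.0f);
    for (int l = 0; l < numLayers; ++l)
    {
        for (int b = 0; b < batch; ++b)
        {
            const int t = lengths[b] - 1; // last valid timestep of sample b
            assert(t >= 0 && t < maxLen);
            for (int h = 0; h < hidden; ++h)
            {
                const std::size_t src
                    = ((static_cast<std::size_t>(l) * batch + b) * maxLen + t) * hidden + h;
                const std::size_t dst = (static_cast<std::size_t>(l) * batch + b) * hidden + h;
                (*out)[dst] = states[src];
            }
        }
    }
}
```

With one layer, a batch of two samples of lengths 2 and 3, and states numbered 0 to 11, the sketch picks timestep 1 for the first sample and timestep 2 for the second, producing {2, 3, 10, 11}.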

virtual int nmtSample::Encoder::getMemoryStatesSize() [pure virtual]
Gets the size of the memory state vector.
Implemented in nmtSample::LSTMEncoder.
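
virtual std::vector<nvinfer1::Dims> nmtSample::Encoder::getStateSizes() [pure virtual]
Gets the sizes of the hidden state vectors, returned as a vector with one entry per state.
Implemented in nmtSample::LSTMEncoder.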
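Since getStateSizes returns one shape per hidden state vector, a caller might, for instance, total the element counts to size a host-side buffer. This is a hypothetical sketch: `DimsStub` stands in for `nvinfer1::Dims` (only its `nbDims` and `d` fields are mirrored), and `volume` and `totalStateElements` are illustrative helpers, not part of the sample.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for nvinfer1::Dims: nbDims valid entries in d[].
struct DimsStub
{
    int nbDims = 0;
    int d[8] = {};
};

// Number of elements in one tensor of the given shape (1 for a zero-rank shape).
std::int64_t volume(const DimsStub& dims)
{
    std::int64_t v = 1;
    for (int i = 0; i < dims.nbDims; ++i)
    {
        v *= dims.d[i];
    }
    return v;
}

// Total element count across all state shapes reported by getStateSizes().
std::int64_t totalStateElements(const std::vector<DimsStub>& stateSizes)
{
    std::int64_t total = 0;
    for (const DimsStub& dims : stateSizes)
    {
        total += volume(dims);
    }
    return total;
}
```

Accumulating in `std::int64_t` avoids overflowing `int` when several large state shapes are combined.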

virtual std::string getInfo() [pure virtual, inherited]
Gets the textual description of the component.
Implemented in nmtSample::SoftmaxLikelihood, nmtSample::BeamSearchPolicy, nmtSample::MultiplicativeAlignment, nmtSample::SLPEmbedder, nmtSample::SLPProjection, nmtSample::BLEUScoreWriter, nmtSample::Context, nmtSample::LSTMEncoder, nmtSample::TextWriter, nmtSample::SLPAttention, nmtSample::BenchmarkWriter, nmtSample::TextReader, nmtSample::LSTMDecoder, and nmtSample::LimitedSamplesDataReader.