represents the core of the attention mechanism
Public Types
typedef std::shared_ptr<Alignment> ptr
Public Member Functions
Alignment() = default
virtual void addToModel(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* attentionKeys, nvinfer1::ITensor* queryStates, nvinfer1::ITensor** alignmentScores) = 0
add the alignment scores calculation to the network
virtual void addAttentionKeys(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* memoryStates, nvinfer1::ITensor** attentionKeys) = 0
add attention keys calculation (from source memory states) to the network
virtual int getSourceStatesSize() = 0
get the size of the source states
virtual int getAttentionKeySize() = 0
get the size of the attention keys
~Alignment() override = default
virtual std::string getInfo() = 0
get the textual description of the component
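The summary above describes an abstract interface that callers hold through Alignment::ptr. The following is a minimal C++ sketch of that pattern, not the sample's own code: Component is a stand-in for the base class from which getInfo and the virtual destructor are inherited, ToyAlignment is a hypothetical implementation, and the TensorRT-dependent methods addToModel and addAttentionKeys are left out so the sketch compiles without TensorRT.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Stand-in (assumed) for the sample's base class: it supplies getInfo() and
// the virtual destructor that ~Alignment() overrides.
class Component
{
public:
    virtual std::string getInfo() = 0;
    virtual ~Component() = default;
};

// Simplified mirror of the Alignment interface; addToModel and
// addAttentionKeys are omitted because they need TensorRT types.
class Alignment : public Component
{
public:
    typedef std::shared_ptr<Alignment> ptr;

    Alignment() = default;
    virtual int getSourceStatesSize() = 0;
    virtual int getAttentionKeySize() = 0;
    ~Alignment() override = default;
};

// Hypothetical concrete alignment, used only for illustration.
class ToyAlignment : public Alignment
{
public:
    ToyAlignment(int sourceStatesSize, int attentionKeySize)
        : mSourceStatesSize(sourceStatesSize)
        , mAttentionKeySize(attentionKeySize)
    {
        assert(sourceStatesSize > 0); // source states must have a positive width
    }

    int getSourceStatesSize() override { return mSourceStatesSize; }
    int getAttentionKeySize() override { return mAttentionKeySize; }

    std::string getInfo() override
    {
        return "ToyAlignment (source states size = " + std::to_string(mSourceStatesSize)
            + ", attention key size = " + std::to_string(mAttentionKeySize) + ")";
    }

private:
    int mSourceStatesSize;
    int mAttentionKeySize;
};
```

A caller would then work only through the base interface, for example Alignment::ptr a = std::make_shared<ToyAlignment>(4, 3); followed by a->getAttentionKeySize() or a->getInfo().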
typedef std::shared_ptr<Alignment> nmtSample::Alignment::ptr
nmtSample::Alignment::Alignment() [default]
nmtSample::Alignment::~Alignment() [override, default]
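virtual void nmtSample::Alignment::addToModel(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* attentionKeys, nvinfer1::ITensor* queryStates, nvinfer1::ITensor** alignmentScores) [pure virtual]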
add the alignment scores calculation to the network
Implemented in nmtSample::MultiplicativeAlignment.
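Judging by the name of the only listed implementation, MultiplicativeAlignment, the score for each source position is assumed here to be the dot product of the decoder query state with that position's attention key (Luong-style multiplicative attention). The sketch below computes those scores in plain C++ for a single sentence with no batch dimension; it illustrates the arithmetic that addToModel adds to the network, not the TensorRT layers it actually uses. The function and type names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vector = std::vector<float>;
using Matrix = std::vector<Vector>; // row-major: one row per source position

// Returns one alignment score per source position: the dot product of the
// query state with that position's attention key.
Vector alignmentScores(const Matrix& attentionKeys, const Vector& queryState)
{
    Vector scores;
    scores.reserve(attentionKeys.size());
    for (const Vector& key : attentionKeys)
    {
        assert(key.size() == queryState.size()); // key and query widths must match
        float score = 0.0f;
        for (std::size_t i = 0; i < key.size(); ++i)
        {
            score += key[i] * queryState[i];
        }
        scores.push_back(score);
    }
    return scores;
}
```

With keys {1, 0}, {0, 1}, {1, 1} and query {2, 3}, the scores are 2, 3 and 5; in a full attention step these raw scores would typically be normalized (for example with a softmax) before weighting the source states.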
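virtual void nmtSample::Alignment::addAttentionKeys(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* memoryStates, nvinfer1::ITensor** attentionKeys) [pure virtual]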
add attention keys calculation (from source memory states) to the network
This function is called if getAttentionKeySize returns a positive value.
Implemented in nmtSample::MultiplicativeAlignment.
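A minimal sketch of the key computation, assuming that keys are obtained by multiplying each source memory state (width getSourceStatesSize()) with a learned weight matrix of shape getSourceStatesSize() x getAttentionKeySize(). The helper names are hypothetical, and the fallback in keysOrMemory (using the memory states directly as keys when the key size is not positive) is an assumption; the documentation only states that addAttentionKeys is called when getAttentionKeySize returns a positive value.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vector = std::vector<float>;
using Matrix = std::vector<Vector>; // row-major

// Projects each memory state (length = number of rows in weights) to an
// attention key (length = number of columns in weights): key = state * W.
Matrix attentionKeys(const Matrix& memoryStates, const Matrix& weights)
{
    const std::size_t keySize = weights.empty() ? std::size_t{0} : weights[0].size();
    Matrix keys;
    keys.reserve(memoryStates.size());
    for (const Vector& state : memoryStates)
    {
        assert(state.size() == weights.size()); // state width must equal W's row count
        Vector key(keySize, 0.0f);
        for (std::size_t i = 0; i < state.size(); ++i)
        {
            for (std::size_t j = 0; j < keySize; ++j)
            {
                key[j] += state[i] * weights[i][j];
            }
        }
        keys.push_back(key);
    }
    return keys;
}

// Runs the projection only for a positive key size, as documented; the
// else-branch (memory states reused as keys) is an assumption.
Matrix keysOrMemory(const Matrix& memoryStates, const Matrix& weights, int attentionKeySize)
{
    if (attentionKeySize > 0)
    {
        return attentionKeys(memoryStates, weights);
    }
    return memoryStates;
}
```

Precomputing the keys once per source sentence, separately from the per-step score calculation, avoids repeating the projection at every decoder step; this is a plausible reason for splitting addAttentionKeys from addToModel, though the documentation does not state it.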
virtual int nmtSample::Alignment::getSourceStatesSize() [pure virtual]
get the size of the source states
Implemented in nmtSample::MultiplicativeAlignment.
virtual int nmtSample::Alignment::getAttentionKeySize() [pure virtual]
get the size of the attention keys
Implemented in nmtSample::MultiplicativeAlignment.
virtual std::string getInfo() [pure virtual, inherited]
get the textual description of the component
Implemented in nmtSample::SoftmaxLikelihood, nmtSample::BeamSearchPolicy, nmtSample::MultiplicativeAlignment, nmtSample::SLPEmbedder, nmtSample::SLPProjection, nmtSample::BLEUScoreWriter, nmtSample::Context, nmtSample::LSTMEncoder, nmtSample::TextWriter, nmtSample::SLPAttention, nmtSample::BenchmarkWriter, nmtSample::TextReader, nmtSample::LSTMDecoder, and nmtSample::LimitedSamplesDataReader.