alignment scores from the Luong attention mechanism
Public Types | |
typedef std::shared_ptr< Alignment > | ptr |
Public Member Functions | |
MultiplicativeAlignment (ComponentWeights::ptr weights) | |
void | addToModel (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *attentionKeys, nvinfer1::ITensor *queryStates, nvinfer1::ITensor **alignmentScores) override |
add the alignment scores calculation to the network | |
void | addAttentionKeys (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *memoryStates, nvinfer1::ITensor **attentionKeys) override |
add attention keys calculation (from source memory states) to the network | |
int | getSourceStatesSize () override |
get the size of the source states | |
int | getAttentionKeySize () override |
get the size of the attention keys | |
std::string | getInfo () override |
get the textual description of the component | |
~MultiplicativeAlignment () override=default | |
Protected Attributes | |
ComponentWeights::ptr | mWeights |
nvinfer1::Weights | mKernelWeights |
int | mInputChannelCount |
int | mOutputChannelCount |
typedef std::shared_ptr< Alignment > ptr | inherited |
nmtSample::MultiplicativeAlignment::MultiplicativeAlignment (ComponentWeights::ptr weights)
nmtSample::MultiplicativeAlignment::~MultiplicativeAlignment () | override, default |
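void nmtSample::MultiplicativeAlignment::addToModel (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *attentionKeys, nvinfer1::ITensor *queryStates, nvinfer1::ITensor **alignmentScores) | override, virtual |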
add the alignment scores calculation to the network
Implements nmtSample::Alignment.
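As a rough illustration of what a multiplicative (Luong-style) alignment computes, the sketch below scores each attention key against the query state with a dot product. It is plain C++ rather than TensorRT network code. The helper name `multiplicativeScores` and the vector-based data layout are assumptions made for this example and are not part of nmtSample.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper for illustration; not part of nmtSample or TensorRT.
// keys:  one attention key per source position, each of length keySize.
// query: the decoder query state for the current step, also of length keySize.
// Returns one unnormalized alignment score per source position.
std::vector<float> multiplicativeScores(const std::vector<std::vector<float>>& keys,
                                        const std::vector<float>& query)
{
    std::vector<float> scores;
    scores.reserve(keys.size());
    for (const auto& key : keys)
    {
        assert(key.size() == query.size()); // a dot product needs equal lengths
        float dot = 0.0f;
        for (std::size_t i = 0; i < key.size(); ++i)
        {
            dot += key[i] * query[i];
        }
        scores.push_back(dot);
    }
    return scores;
}
```

In an attention layer these scores would typically be normalized (for example with a softmax) before being used to weight the source states; that step is outside what this member is documented to do.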
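void nmtSample::MultiplicativeAlignment::addAttentionKeys (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *memoryStates, nvinfer1::ITensor **attentionKeys) | override, virtual |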
add attention keys calculation (from source memory states) to the network
This function is called only if getAttentionKeySize returns a positive value.
Implements nmtSample::Alignment.
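The calling rule above can be sketched in plain C++ as follows. This is a minimal illustration under stated assumptions: the key projection is modeled as a matrix-vector product with a weight matrix `W`, and the fallback of using the memory states directly when the key size is not positive is an assumption, not something this page states. The names `Matrix`, `projectAttentionKeys`, and `keysOrMemory` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major matrix type used only for this illustration (hypothetical).
using Matrix = std::vector<std::vector<float>>;

// Hypothetical stand-in for the key projection: keys[s] = W * memoryStates[s].
// W has keySize rows and sourceStatesSize columns.
Matrix projectAttentionKeys(const Matrix& W, const Matrix& memoryStates)
{
    Matrix keys;
    keys.reserve(memoryStates.size());
    for (const auto& state : memoryStates)
    {
        std::vector<float> key(W.size(), 0.0f);
        for (std::size_t row = 0; row < W.size(); ++row)
        {
            assert(W[row].size() == state.size()); // columns must match the state size
            for (std::size_t col = 0; col < state.size(); ++col)
            {
                key[row] += W[row][col] * state[col];
            }
        }
        keys.push_back(key);
    }
    return keys;
}

// Mirrors the documented rule: keys are computed only for a positive key size.
// Assumption: otherwise the memory states themselves serve as the keys.
Matrix keysOrMemory(int attentionKeySize, const Matrix& W, const Matrix& memoryStates)
{
    if (attentionKeySize > 0)
    {
        return projectAttentionKeys(W, memoryStates);
    }
    return memoryStates;
}
```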
int nmtSample::MultiplicativeAlignment::getSourceStatesSize () | override, virtual |
get the size of the source states
Implements nmtSample::Alignment.
int nmtSample::MultiplicativeAlignment::getAttentionKeySize () | override, virtual |
get the size of the attention keys
Implements nmtSample::Alignment.
std::string nmtSample::MultiplicativeAlignment::getInfo () | override, virtual |
get the textual description of the component
Implements nmtSample::Component.
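ComponentWeights::ptr nmtSample::MultiplicativeAlignment::mWeights | protected |
nvinfer1::Weights nmtSample::MultiplicativeAlignment::mKernelWeights | protected |
int nmtSample::MultiplicativeAlignment::mInputChannelCount | protected |
int nmtSample::MultiplicativeAlignment::mOutputChannelCount | protected |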