Linear attention calculation.
Public Types
typedef std::shared_ptr<Attention> ptr
Public Member Functions
SLPAttention(ComponentWeights::ptr weights)
void addToModel(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* inputFromDecoder, nvinfer1::ITensor* context, nvinfer1::ITensor** attentionOutput) override
add the attention vector calculation to the network
int getAttentionSize() override
get the size of the attention vector
std::string getInfo() override
get the textual description of the component
Protected Attributes
ComponentWeights::ptr mWeights
nvinfer1::Weights mKernelWeights
int mInputChannelCount
int mOutputChannelCount
Linear attention calculation.
Calculates the attention vector by concatenating the input from the decoder with the context vector and projecting the result into attention space by multiplying it with a weight matrix.
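The computation described above can be sketched on the host for a single example, without TensorRT. This is a minimal illustration, not the sample's implementation: the function name `slpAttentionSketch`, the row-major kernel layout of shape [inputChannelCount x outputChannelCount], and the absence of any bias or activation are all assumptions made here for clarity.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical host-side sketch of the SLP attention projection for one example.
// Assumptions (not taken from the sample's source):
//   * `kernel` is row-major with shape [inputChannelCount x outputChannelCount],
//     where inputChannelCount = decoderInput.size() + context.size();
//   * no bias term and no activation are applied (a purely linear projection).
std::vector<float> slpAttentionSketch(const std::vector<float>& decoderInput,
                                      const std::vector<float>& context,
                                      const std::vector<float>& kernel,
                                      std::size_t outputChannelCount)
{
    // Step 1: concatenate [decoderInput ; context] along the channel axis.
    std::vector<float> concatenated(decoderInput);
    concatenated.insert(concatenated.end(), context.begin(), context.end());

    const std::size_t inputChannelCount = concatenated.size();
    if (kernel.size() != inputChannelCount * outputChannelCount)
        throw std::invalid_argument("kernel size does not match channel counts");

    // Step 2: project into attention space: attention = concatenated x kernel.
    std::vector<float> attention(outputChannelCount, 0.0f);
    for (std::size_t i = 0; i < inputChannelCount; ++i)
        for (std::size_t j = 0; j < outputChannelCount; ++j)
            attention[j] += concatenated[i] * kernel[i * outputChannelCount + j];
    return attention;
}
```

For example, with a decoder input of {1, 2}, a context of {3}, and the 3x2 kernel {1, 0, 0, 1, 1, 1}, the concatenated vector is {1, 2, 3} and the resulting attention vector is {4, 5}.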
typedef std::shared_ptr<Attention> ptr [inherited]
nmtSample::SLPAttention::SLPAttention(ComponentWeights::ptr weights)
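void nmtSample::SLPAttention::addToModel(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* inputFromDecoder, nvinfer1::ITensor* context, nvinfer1::ITensor** attentionOutput) [override, virtual]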
add the attention vector calculation to the network
Implements nmtSample::Attention.
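The tensors that addToModel wires together must agree in shape. The sketch below checks that contract in plain C++ so it compiles without TensorRT. It assumes both inputs are 2-D tensors laid out as [rows, channels] (rows being, for instance, a batch or time dimension), that they are concatenated along the channel axis, and that the result is multiplied by a constant kernel of shape [inputChannelCount, outputChannelCount]. `Shape2D` and `attentionOutputShape` are hypothetical names, not part of the sample or of the TensorRT API; in a TensorRT network, a concatenation layer followed by a matrix multiply against constant weights is one plausible way to express these steps.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for a 2-D tensor shape [rows, channels]; this is NOT
// nvinfer1::Dims, just a placeholder so the example compiles without TensorRT.
struct Shape2D
{
    std::int64_t rows;
    std::int64_t channels;
};

// Returns the shape of the attention output, or throws if the inputs and the
// kernel dimensions are inconsistent (assumed layout described in the lead-in).
Shape2D attentionOutputShape(Shape2D decoderInput, Shape2D context,
                             std::int64_t inputChannelCount, std::int64_t outputChannelCount)
{
    if (decoderInput.rows != context.rows)
        throw std::invalid_argument("decoder input and context must have the same number of rows");

    // Concatenation along channels: [rows, d] ++ [rows, c] -> [rows, d + c]
    const Shape2D concatenated{decoderInput.rows, decoderInput.channels + context.channels};

    // Matrix multiply: [rows, d + c] x [d + c, out] -> [rows, out]
    if (concatenated.channels != inputChannelCount)
        throw std::invalid_argument("concatenated width must equal the kernel's input channel count");
    return Shape2D{concatenated.rows, outputChannelCount};
}
```

Under these assumptions, the kernel's input channel count must equal the sum of the decoder and context widths, while its output channel count fixes the width of the attention vector.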
int nmtSample::SLPAttention::getAttentionSize() [override, virtual]
get the size of the attention vector
Implements nmtSample::Attention.
std::string nmtSample::SLPAttention::getInfo() [override, virtual]
get the textual description of the component
Implements nmtSample::Component.
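The override relationships listed above (getAttentionSize implementing nmtSample::Attention, getInfo implementing nmtSample::Component) can be illustrated with a stripped-down hierarchy. Everything here is a hypothetical sketch: the classes omit addToModel, `SLPAttentionSketch` takes channel counts directly instead of a ComponentWeights::ptr, the attention size is assumed to equal the output channel count, and the info string format is invented.

```cpp
#include <memory>
#include <sstream>
#include <string>

// Hypothetical, reduced stand-ins for nmtSample::Component and nmtSample::Attention,
// showing only the two virtual members documented above.
class Component
{
public:
    virtual ~Component() = default;
    virtual std::string getInfo() = 0;
};

class Attention : public Component
{
public:
    typedef std::shared_ptr<Attention> ptr;
    virtual int getAttentionSize() = 0;
};

// Sketch of an SLP attention component implementing both interfaces.
class SLPAttentionSketch : public Attention
{
public:
    SLPAttentionSketch(int inputChannelCount, int outputChannelCount)
        : mInputChannelCount(inputChannelCount), mOutputChannelCount(outputChannelCount)
    {
    }

    // Assumption: the attention vector is as wide as the kernel's output.
    int getAttentionSize() override { return mOutputChannelCount; }

    // Invented format, for illustration only.
    std::string getInfo() override
    {
        std::ostringstream ss;
        ss << "SLPAttention(in=" << mInputChannelCount << ", out=" << mOutputChannelCount << ")";
        return ss.str();
    }

protected:
    int mInputChannelCount;
    int mOutputChannelCount;
};
```

Because callers hold the component through `Attention::ptr`, they can query the attention size and description without knowing the concrete attention type.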