TensorRT 7.2.1.6
NVIDIA TensorRT
nmtSample::SLPAttention Class Reference

Linear attention calculation.

Public Types

typedef std::shared_ptr< Attention > ptr

Public Member Functions

 SLPAttention (ComponentWeights::ptr weights)
void addToModel (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *inputFromDecoder, nvinfer1::ITensor *context, nvinfer1::ITensor **attentionOutput) override
 Adds the attention vector calculation to the network.
int getAttentionSize () override
 Gets the size of the attention vector.
std::string getInfo () override
 Gets the textual description of the component.

Protected Attributes

ComponentWeights::ptr mWeights
nvinfer1::Weights mKernelWeights
int mInputChannelCount
int mOutputChannelCount

Detailed Description

Linear attention calculation.

Calculates the attention vector by concatenating the input from the decoder with the context vector and projecting the result into the attention space through multiplication by a weight matrix.
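The computation described above can be sketched on plain vectors, independent of TensorRT. This is a minimal illustration under stated assumptions, not the sample's implementation: the function name linearAttention is hypothetical, the weight matrix is assumed to be stored row-major as [input channels x output channels], and the decoder input is assumed to come before the context vector in the concatenation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of the sample): concatenates the decoder input
// with the context vector and projects the result with a weight matrix W.
// Assumptions: W is row-major [inputChannels x outputChannels], where
// inputChannels = decoderInput.size() + context.size(), and the decoder
// values come first in the concatenation.
std::vector<float> linearAttention(const std::vector<float>& decoderInput,
                                   const std::vector<float>& context,
                                   const std::vector<float>& weights,
                                   std::size_t outputChannels)
{
    // Step 1: concatenate [decoderInput, context].
    std::vector<float> concat(decoderInput);
    concat.insert(concat.end(), context.begin(), context.end());

    const std::size_t inputChannels = concat.size();
    assert(weights.size() == inputChannels * outputChannels);

    // Step 2: attention[j] = sum_i concat[i] * W[i][j] (row vector times matrix).
    std::vector<float> attention(outputChannels, 0.0f);
    for (std::size_t i = 0; i < inputChannels; ++i)
        for (std::size_t j = 0; j < outputChannels; ++j)
            attention[j] += concat[i] * weights[i * outputChannels + j];
    return attention;
}
```

For example, with decoder input {1, 2}, context {3}, and a 3 x 2 matrix W = {1, 0, 0, 1, 1, 1}, the concatenated vector is {1, 2, 3} and the projected attention vector is {4, 5}.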

Member Typedef Documentation

◆ ptr

typedef std::shared_ptr<Attention> nmtSample::Attention::ptr
inherited
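Because ptr is declared in nmtSample::Attention and inherited, SLPAttention::ptr names std::shared_ptr<Attention>, so a concrete attention object is handled through the base-class smart pointer. The sketch below shows this with simplified stand-ins for the hierarchy; the sketch namespace, the omission of addToModel, and the placeholder return values are assumptions made for illustration only.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>

// Simplified stand-ins for the sample's classes (assumption: only the parts
// needed to show the inherited typedef; addToModel is omitted).
namespace sketch {

struct Component {
    virtual ~Component() = default;
    virtual std::string getInfo() = 0;
};

struct Attention : Component {
    typedef std::shared_ptr<Attention> ptr;   // declared once, in the base class
    virtual int getAttentionSize() = 0;
};

struct SLPAttention : Attention {
    int getAttentionSize() override { return 4; }             // placeholder value
    std::string getInfo() override { return "SLPAttention"; } // placeholder text
};

// SLPAttention::ptr resolves to the inherited typedef, std::shared_ptr<Attention>.
static_assert(std::is_same<SLPAttention::ptr, std::shared_ptr<Attention>>::value,
              "ptr is inherited from Attention");

// Returns a derived object through the base-class smart pointer.
inline Attention::ptr makeAttention() { return std::make_shared<SLPAttention>(); }

} // namespace sketch
```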

Constructor & Destructor Documentation

◆ SLPAttention()

nmtSample::SLPAttention::SLPAttention ( ComponentWeights::ptr weights)
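This page does not describe how the constructor reads ComponentWeights. The sketch below shows one plausible shape under loudly stated assumptions: WeightsBlob is a hypothetical stand-in for ComponentWeights whose metadata holds {input channel count, output channel count}, and whose buffer must contain exactly one float per (input, output) channel pair. The constructor keeps the shared pointer (like mWeights) and derives the two channel counts (like mInputChannelCount and mOutputChannelCount).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for ComponentWeights; the real layout is not shown here.
// Assumption: metaData = {inputChannelCount, outputChannelCount}.
struct WeightsBlob {
    typedef std::shared_ptr<WeightsBlob> ptr;
    std::vector<int> metaData;
    std::vector<float> weights;
};

class SLPAttentionSketch {
public:
    explicit SLPAttentionSketch(WeightsBlob::ptr weights)
        : mWeights(std::move(weights))
    {
        if (!mWeights || mWeights->metaData.size() < 2)
            throw std::invalid_argument("missing weight metadata");
        mInputChannelCount = mWeights->metaData[0];
        mOutputChannelCount = mWeights->metaData[1];
        if (mInputChannelCount <= 0 || mOutputChannelCount <= 0)
            throw std::invalid_argument("channel counts must be positive");

        // The kernel must hold one value per (input, output) channel pair.
        const std::size_t expected = static_cast<std::size_t>(mInputChannelCount)
            * static_cast<std::size_t>(mOutputChannelCount);
        if (mWeights->weights.size() != expected)
            throw std::invalid_argument("weight count does not match channel counts");
    }

    int inputChannelCount() const { return mInputChannelCount; }
    int outputChannelCount() const { return mOutputChannelCount; }

private:
    WeightsBlob::ptr mWeights;     // keeps the weight buffer alive
    int mInputChannelCount = 0;
    int mOutputChannelCount = 0;
};
```

Validating the element count up front turns a mismatched weight file into an immediate, descriptive error instead of an out-of-bounds read later.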

Member Function Documentation

◆ addToModel()

void nmtSample::SLPAttention::addToModel ( nvinfer1::INetworkDefinition * network,
nvinfer1::ITensor * inputFromDecoder,
nvinfer1::ITensor * context,
nvinfer1::ITensor ** attentionOutput
)
override virtual

Adds the attention vector calculation to the network; the resulting tensor is returned through attentionOutput.

Implements nmtSample::Attention.
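This page does not say which layers addToModel creates. In TensorRT terms, one plausible mapping is a concatenation layer (INetworkDefinition::addConcatenation) followed by a matrix multiply (INetworkDefinition::addMatrixMultiply) against a constant built from mKernelWeights (INetworkDefinition::addConstant), but that is an assumption. The sketch below reproduces only the data flow on a hypothetical MiniNetwork that owns every tensor it creates, and returns the result through a Tensor** out-parameter, mirroring the ITensor** attentionOutput convention. Tensor, MiniNetwork, and addToModelSketch are illustrative names, not TensorRT types.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical tensor: row-major [rows x cols] float data.
struct Tensor {
    std::size_t rows = 0, cols = 0;
    std::vector<float> data;
};

// Hypothetical network that owns every tensor it creates, loosely mimicking
// how tensors produced by layers are owned by the network definition.
struct MiniNetwork {
    std::vector<std::unique_ptr<Tensor>> tensors;
    Tensor* newTensor(std::size_t rows, std::size_t cols) {
        tensors.push_back(std::make_unique<Tensor>());
        Tensor* t = tensors.back().get();
        t->rows = rows;
        t->cols = cols;
        t->data.assign(rows * cols, 0.0f);
        return t;
    }
};

// Data flow only: concatenate along the channel axis, then multiply by the
// kernel [inputChannels x outputChannels]; the result goes to *attentionOutput.
void addToModelSketch(MiniNetwork* network, const Tensor* inputFromDecoder,
                      const Tensor* context, const Tensor* kernel,
                      Tensor** attentionOutput)
{
    assert(inputFromDecoder->rows == context->rows);
    const std::size_t batch = context->rows;
    const std::size_t decCh = inputFromDecoder->cols;
    const std::size_t inCh = decCh + context->cols;
    assert(kernel->rows == inCh);

    // Concatenation step: each row becomes [decoder row, context row].
    Tensor* concat = network->newTensor(batch, inCh);
    for (std::size_t b = 0; b < batch; ++b) {
        for (std::size_t i = 0; i < decCh; ++i)
            concat->data[b * inCh + i] = inputFromDecoder->data[b * decCh + i];
        for (std::size_t i = 0; i < context->cols; ++i)
            concat->data[b * inCh + decCh + i] = context->data[b * context->cols + i];
    }

    // Matrix multiply step: [batch x inCh] * [inCh x outCh].
    const std::size_t outCh = kernel->cols;
    Tensor* out = network->newTensor(batch, outCh);
    for (std::size_t b = 0; b < batch; ++b)
        for (std::size_t i = 0; i < inCh; ++i)
            for (std::size_t j = 0; j < outCh; ++j)
                out->data[b * outCh + j] += concat->data[b * inCh + i] * kernel->data[i * outCh + j];

    *attentionOutput = out;   // caller receives a tensor owned by the network
}
```

The out-parameter lets the caller receive a pointer to a tensor whose lifetime is managed by the network, so the caller never frees it.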

◆ getAttentionSize()

int nmtSample::SLPAttention::getAttentionSize ( )
override virtual

Gets the size of the attention vector.

Implements nmtSample::Attention.
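The page does not state how the size is computed. Given the projection described in the detailed description, a reasonable assumption is that the attention vector has one element per output channel of the weight matrix, so getAttentionSize() would report mOutputChannelCount. The hypothetical AttentionSizeSketch below encodes that assumption, plus a check that the decoder and context widths add up to the input channel count.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical class: assumes the attention size equals the output channel
// count of the projection (not confirmed by this page).
class AttentionSizeSketch {
public:
    AttentionSizeSketch(int inputChannelCount, int outputChannelCount)
        : mInputChannelCount(inputChannelCount), mOutputChannelCount(outputChannelCount)
    {
        if (inputChannelCount <= 0 || outputChannelCount <= 0)
            throw std::invalid_argument("channel counts must be positive");
    }

    // One attention element per output channel of the weight matrix.
    int getAttentionSize() const { return mOutputChannelCount; }

    // The concatenated input must match the kernel's input channel count.
    bool acceptsInputs(int decoderWidth, int contextWidth) const
    {
        return decoderWidth + contextWidth == mInputChannelCount;
    }

private:
    int mInputChannelCount;
    int mOutputChannelCount;
};
```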

◆ getInfo()

std::string nmtSample::SLPAttention::getInfo ( )
override virtual

Gets the textual description of the component.

Implements nmtSample::Component.
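The exact text returned by getInfo() is not shown on this page. A minimal sketch, assuming the description reports the two channel counts, can build the string with std::ostringstream; the function name describeSLPAttention and the wording of the message are hypothetical.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical formatter: the message wording is an illustrative assumption,
// not the sample's actual getInfo() output.
std::string describeSLPAttention(int inputChannelCount, int outputChannelCount)
{
    std::ostringstream ss;
    ss << "SLP Attention, input channels = " << inputChannelCount
       << ", output channels = " << outputChannelCount;
    return ss.str();
}
```

For instance, describeSLPAttention(512, 256) yields "SLP Attention, input channels = 512, output channels = 256".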

Member Data Documentation

◆ mWeights

ComponentWeights::ptr nmtSample::SLPAttention::mWeights
protected

◆ mKernelWeights

nvinfer1::Weights nmtSample::SLPAttention::mKernelWeights
protected
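nvinfer1::Weights is a non-owning descriptor (a data type, a pointer to the values, and an element count), and TensorRT expects the memory it points to to stay valid until the engine is built. A likely reason the class keeps mWeights alongside mKernelWeights is to keep that buffer alive; this reading is an inference, not something this page states. The sketch below shows the pattern with hypothetical stand-ins: WeightsView (modeled loosely on nvinfer1::Weights, type field omitted) borrows from a buffer held by a shared_ptr.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for a non-owning weights descriptor (pointer + element count).
struct WeightsView {
    const void* values = nullptr;
    std::int64_t count = 0;
};

// Hypothetical holder: the shared_ptr (like mWeights) owns the storage, and the
// view (like mKernelWeights) only borrows it, so the raw pointer never dangles
// while the holder exists.
class KernelHolder {
public:
    explicit KernelHolder(std::shared_ptr<std::vector<float>> buffer)
        : mWeights(std::move(buffer))
    {
        mKernelWeights.values = mWeights->data();
        mKernelWeights.count = static_cast<std::int64_t>(mWeights->size());
    }

    const WeightsView& kernel() const { return mKernelWeights; }

private:
    std::shared_ptr<std::vector<float>> mWeights; // owns the storage
    WeightsView mKernelWeights;                   // borrows from mWeights
};
```

Even if every other owner releases the buffer, the holder's shared_ptr keeps the memory referenced by the view valid.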

◆ mInputChannelCount

int nmtSample::SLPAttention::mInputChannelCount
protected

◆ mOutputChannelCount

int nmtSample::SLPAttention::mOutputChannelCount
protected
