TensorRT  7.2.1.6
NVIDIA TensorRT
nmtSample::MultiplicativeAlignment Class Reference

Alignment scores from the Luong attention mechanism.


Public Types

typedef std::shared_ptr< Alignment > ptr
 

Public Member Functions

 MultiplicativeAlignment (ComponentWeights::ptr weights)
 
void addToModel (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *attentionKeys, nvinfer1::ITensor *queryStates, nvinfer1::ITensor **alignmentScores) override
 Adds the alignment-score calculation to the network.
 
void addAttentionKeys (nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *memoryStates, nvinfer1::ITensor **attentionKeys) override
 Adds the attention-key calculation (from the source memory states) to the network.
 
int getSourceStatesSize () override
 Returns the size of the source states.
 
int getAttentionKeySize () override
 Returns the size of the attention keys.
 
std::string getInfo () override
 Returns a textual description of the component.
 
 ~MultiplicativeAlignment () override=default
 

Protected Attributes

ComponentWeights::ptr mWeights
 
nvinfer1::Weights mKernelWeights
 
int mInputChannelCount
 
int mOutputChannelCount
 

Detailed Description

Alignment scores from the Luong attention mechanism.
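For orientation, Luong et al. define the multiplicative ("general") alignment score below, where $\mathbf{h}_t$ is the current target (query) state, $\bar{\mathbf{h}}_s$ is a source state and $W_a$ is a learned matrix. Factoring it into a key projection $\mathbf{k}_s$ followed by a dot product is an assumption inferred from the separate addAttentionKeys() and addToModel() methods, and identifying $W_a$ with mKernelWeights is likewise an assumption; this page states neither.

```latex
\mathbf{k}_s = W_a \bar{\mathbf{h}}_s, \qquad
\operatorname{score}(\mathbf{h}_t, \bar{\mathbf{h}}_s)
  = \mathbf{h}_t^{\top} W_a \bar{\mathbf{h}}_s
  = \mathbf{h}_t^{\top} \mathbf{k}_s
```

Because $\mathbf{k}_s$ depends only on the source side, it can be computed once per source sequence and reused for every decoder step.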

Member Typedef Documentation

◆ ptr

typedef std::shared_ptr<Alignment> nmtSample::Alignment::ptr
inherited
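The inherited ptr alias lets callers hold any concrete alignment through a shared pointer to the base class. The sketch below uses hypothetical stand-in classes in a `sketch` namespace that only mimic the shape of the documented hierarchy (Component, Alignment, MultiplicativeAlignment); the sizes and the info string are placeholders, not values from the sample.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins mirroring the documented hierarchy; NOT the real nmtSample classes.
namespace sketch
{
struct Component
{
    virtual std::string getInfo() = 0;
    virtual ~Component() = default;
};

struct Alignment : Component
{
    typedef std::shared_ptr<Alignment> ptr; // same alias pattern as nmtSample::Alignment::ptr
    virtual int getSourceStatesSize() = 0;
    virtual int getAttentionKeySize() = 0;
};

struct MultiplicativeAlignment : Alignment
{
    int getSourceStatesSize() override { return 4; } // placeholder size
    int getAttentionKeySize() override { return 3; } // placeholder size
    std::string getInfo() override { return "multiplicative alignment (sketch)"; }
};
} // namespace sketch

// A factory can return the concrete type through the base-class alias.
sketch::Alignment::ptr makeAlignment()
{
    return std::make_shared<sketch::MultiplicativeAlignment>();
}
```

Calls such as `makeAlignment()->getInfo()` then dispatch virtually to the derived override.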

Constructor & Destructor Documentation

◆ MultiplicativeAlignment()

nmtSample::MultiplicativeAlignment::MultiplicativeAlignment ( ComponentWeights::ptr  weights)
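The page does not show what the constructor does with weights, but the protected members suggest it derives a kernel (mKernelWeights) and its two dimensions (mInputChannelCount, mOutputChannelCount) from them. The following is a hypothetical sketch of that bookkeeping: SketchComponentWeights and SketchWeights are invented stand-ins (the latter mirrors the values/count fields of nvinfer1::Weights), and both the metadata layout {input channels, output channels} and the mapping of the two getters onto the channel counts are assumptions, not the sample's documented format.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for ComponentWeights: integer metadata plus a flat float buffer.
struct SketchComponentWeights
{
    typedef std::shared_ptr<SketchComponentWeights> ptr;
    std::vector<int> metaData;  // assumed layout: {inputChannels, outputChannels}
    std::vector<float> weights; // assumed: row-major [inputChannels x outputChannels]
};

// Hypothetical stand-in mirroring the values/count fields of nvinfer1::Weights.
struct SketchWeights
{
    const void* values;
    std::int64_t count;
};

class SketchMultiplicativeAlignment
{
public:
    explicit SketchMultiplicativeAlignment(SketchComponentWeights::ptr weights)
        : mWeights(std::move(weights))
    {
        if (mWeights->metaData.size() < 2)
            throw std::invalid_argument("metadata must hold two channel counts");
        mInputChannelCount = mWeights->metaData[0];
        mOutputChannelCount = mWeights->metaData[1];
        // Non-owning view into the buffer; mWeights keeps that buffer alive.
        mKernelWeights.values = mWeights->weights.data();
        mKernelWeights.count = static_cast<std::int64_t>(mInputChannelCount) * mOutputChannelCount;
        if (mKernelWeights.count != static_cast<std::int64_t>(mWeights->weights.size()))
            throw std::invalid_argument("weight buffer size does not match channel counts");
    }

    int getSourceStatesSize() const { return mInputChannelCount; }  // assumed mapping
    int getAttentionKeySize() const { return mOutputChannelCount; } // assumed mapping
    std::int64_t kernelCount() const { return mKernelWeights.count; }

private:
    SketchComponentWeights::ptr mWeights;
    SketchWeights mKernelWeights{nullptr, 0};
    int mInputChannelCount = 0;
    int mOutputChannelCount = 0;
};
```

Keeping the shared pointer as a member matters because the kernel is only a view: nvinfer1::Weights likewise points at memory it does not own.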

◆ ~MultiplicativeAlignment()

nmtSample::MultiplicativeAlignment::~MultiplicativeAlignment ( )
override, default

Member Function Documentation

◆ addToModel()

void nmtSample::MultiplicativeAlignment::addToModel ( nvinfer1::INetworkDefinition *network,
nvinfer1::ITensor *attentionKeys,
nvinfer1::ITensor *queryStates,
nvinfer1::ITensor **alignmentScores
)
override, virtual

Adds the alignment-score calculation to the network.

Implements nmtSample::Alignment.
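As a CPU reference for the arithmetic, assume (this page does not say) that queryStates is laid out as [T x K] and attentionKeys as [S x K], row-major with the batch dimension dropped. The scores are then the dot product of every query row with every key row, i.e. queryStates * attentionKeys^T, giving a [T x S] result. Matrix and alignmentScores() are hypothetical names; inside a real network the same product would be expressed with a TensorRT matrix-multiply layer rather than loops.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix, standing in for nvinfer1::ITensor in this sketch.
struct Matrix
{
    std::size_t rows, cols;
    std::vector<float> data;
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// scores[t][s] = dot(query[t], keys[s]), i.e. scores = query * keys^T.
Matrix alignmentScores(const Matrix& query, const Matrix& keys)
{
    assert(query.cols == keys.cols && "query and key widths must match");
    Matrix scores{query.rows, keys.rows, std::vector<float>(query.rows * keys.rows, 0.0f)};
    for (std::size_t t = 0; t < query.rows; ++t)
        for (std::size_t s = 0; s < keys.rows; ++s)
        {
            float acc = 0.0f;
            for (std::size_t k = 0; k < query.cols; ++k)
                acc += query.at(t, k) * keys.at(s, k);
            scores.data[t * scores.cols + s] = acc;
        }
    return scores;
}
```

For one query [1, 2] against keys [1, 0], [0, 1] and [1, 1], this yields the scores [1, 2, 3].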


◆ addAttentionKeys()

void nmtSample::MultiplicativeAlignment::addAttentionKeys ( nvinfer1::INetworkDefinition *network,
nvinfer1::ITensor *memoryStates,
nvinfer1::ITensor **attentionKeys
)
override, virtual

Adds the attention-key calculation (from the source memory states) to the network.

This function is called if getAttentionKeySize() returns a positive value.

Implements nmtSample::Alignment.
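A CPU sketch of a multiplicative key projection, assuming (not stated here) that memoryStates is [S x D_in], the kernel is [D_in x D_key], and the keys are their product [S x D_key]. The gate mirrors the rule that this function is called only for a positive key size; passing the memory states through unchanged otherwise is a hypothetical caller behavior chosen for illustration, not something this page documents. Mat, projectKeys() and attentionKeysFor() are invented names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix standing in for nvinfer1::ITensor (batch omitted).
struct Mat
{
    std::size_t rows, cols;
    std::vector<float> data;
};

// keys = memory * kernel, with memory [S x D_in] and kernel [D_in x D_key].
Mat projectKeys(const Mat& memory, const Mat& kernel)
{
    assert(memory.cols == kernel.rows && "memory width must equal kernel input size");
    Mat keys{memory.rows, kernel.cols, std::vector<float>(memory.rows * kernel.cols, 0.0f)};
    for (std::size_t s = 0; s < memory.rows; ++s)
        for (std::size_t i = 0; i < memory.cols; ++i)
        {
            const float m = memory.data[s * memory.cols + i];
            for (std::size_t j = 0; j < kernel.cols; ++j)
                keys.data[s * keys.cols + j] += m * kernel.data[i * kernel.cols + j];
        }
    return keys;
}

// Project only when the key size is positive; the pass-through branch is a hypothetical choice.
Mat attentionKeysFor(const Mat& memory, const Mat& kernel, int attentionKeySize)
{
    return attentionKeySize > 0 ? projectKeys(memory, kernel) : memory;
}
```

Computing the keys once per source sequence, rather than once per decoder step, is what makes this split worthwhile.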


◆ getSourceStatesSize()

int nmtSample::MultiplicativeAlignment::getSourceStatesSize ( )
override, virtual

Returns the size of the source states.

Implements nmtSample::Alignment.

◆ getAttentionKeySize()

int nmtSample::MultiplicativeAlignment::getAttentionKeySize ( )
override, virtual

Returns the size of the attention keys.

Implements nmtSample::Alignment.
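Assuming, hypothetically, that getSourceStatesSize() reports the width of each source memory state and getAttentionKeySize() the width of the projected keys (which the query states must match for the dot product), a caller could validate tensor widths before building the network. AlignmentShapes and checkShapes() are invented for this sketch.

```cpp
#include <cassert>
#include <string>

// Hypothetical description of the tensor widths seen by an alignment component.
struct AlignmentShapes
{
    int memoryStateWidth; // last dimension of the source memory states
    int queryStateWidth;  // last dimension of the decoder query states
};

// Returns an empty string when the widths are consistent, otherwise the reason they are not.
std::string checkShapes(const AlignmentShapes& shapes, int sourceStatesSize, int attentionKeySize)
{
    if (shapes.memoryStateWidth != sourceStatesSize)
        return "memory state width != getSourceStatesSize()";
    // With a positive key size the keys are projected, so queries must match the key width.
    if (attentionKeySize > 0 && shapes.queryStateWidth != attentionKeySize)
        return "query state width != getAttentionKeySize()";
    return "";
}
```

Checking widths up front turns a shape mismatch into a readable message instead of a failure deep inside network construction.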

◆ getInfo()

std::string nmtSample::MultiplicativeAlignment::getInfo ( )
override, virtual

Returns a textual description of the component.

Implements nmtSample::Component.

Member Data Documentation

◆ mWeights

ComponentWeights::ptr nmtSample::MultiplicativeAlignment::mWeights
protected

◆ mKernelWeights

nvinfer1::Weights nmtSample::MultiplicativeAlignment::mKernelWeights
protected

◆ mInputChannelCount

int nmtSample::MultiplicativeAlignment::mInputChannelCount
protected

◆ mOutputChannelCount

int nmtSample::MultiplicativeAlignment::mOutputChannelCount
protected

The documentation for this class was generated from the following files: