Calculates the context vector from raw alignment scores and memory states.
Public Types
typedef std::shared_ptr<Context> ptr
Public Member Functions
Context() = default
void addToModel(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* actualInputSequenceLengths, nvinfer1::ITensor* memoryStates, nvinfer1::ITensor* alignmentScores, nvinfer1::ITensor** contextOutput)
    Adds the context vector calculation to the network.
std::string getInfo() override
    Gets the textual description of the component.
~Context() override = default
typedef std::shared_ptr<Context> nmtSample::Context::ptr
nmtSample::Context::Context() [default]
nmtSample::Context::~Context() [override], [default]
void nmtSample::Context::addToModel(nvinfer1::INetworkDefinition* network,
                                    nvinfer1::ITensor* actualInputSequenceLengths,
                                    nvinfer1::ITensor* memoryStates,
                                    nvinfer1::ITensor* alignmentScores,
                                    nvinfer1::ITensor** contextOutput)
Adds the context vector calculation to the network. The resulting tensor is returned through contextOutput.
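A CPU reference sketch can make the computation concrete. The code below assumes the conventional attention step: a softmax over the raw alignment scores, restricted to the first actualLength source positions, followed by a weighted sum of the memory states. The function computeContext and its std::vector layout are hypothetical; they illustrate the math for a single batch element only and are not the TensorRT layers that addToModel actually builds.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference (not part of nmtSample): one context vector
// for a single batch element, assuming a length-masked softmax + weighted sum.
//   scores:       raw alignment scores, one per source position
//   memory:       memory states, one row of hiddenSize values per source position
//   actualLength: number of valid (non-padded) source positions
std::vector<float> computeContext(const std::vector<float>& scores,
                                  const std::vector<std::vector<float>>& memory,
                                  std::size_t actualLength)
{
    if (actualLength == 0 || actualLength > scores.size() || actualLength > memory.size())
        throw std::invalid_argument("invalid actualLength");

    // Softmax over the valid positions only; padded positions get no weight.
    const float maxScore = *std::max_element(scores.begin(), scores.begin() + actualLength);
    std::vector<float> weights(actualLength);
    float sum = 0.0f;
    for (std::size_t t = 0; t < actualLength; ++t)
    {
        weights[t] = std::exp(scores[t] - maxScore); // subtract max for numerical stability
        sum += weights[t];
    }
    for (float& w : weights)
        w /= sum;

    // Context vector = attention-weighted sum of the valid memory states.
    const std::size_t hiddenSize = memory[0].size();
    std::vector<float> context(hiddenSize, 0.0f);
    for (std::size_t t = 0; t < actualLength; ++t)
        for (std::size_t h = 0; h < hiddenSize; ++h)
            context[h] += weights[t] * memory[t][h];
    return context;
}
```

For example, with scores {0, 0, 100} and an actual length of 2, the padded third position is ignored, both valid positions get weight 0.5, and memory rows {1, 0} and {3, 2} yield the context vector {2, 1}.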
std::string nmtSample::Context::getInfo() [override], [virtual]
Gets the textual description of the component.
Implements nmtSample::Component.
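The ptr typedef and the getInfo() override follow a common pattern: components are held through std::shared_ptr and describe themselves through a virtual method declared in the base class. The sketch below models that pattern with hypothetical ComponentSketch and ContextSketch classes standing in for nmtSample::Component and nmtSample::Context; the returned string is an assumption, since this page does not say what Context::getInfo() returns.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-in for nmtSample::Component: only the getInfo()
// contract documented on this page is modeled.
class ComponentSketch
{
public:
    virtual std::string getInfo() = 0;
    virtual ~ComponentSketch() = default;
};

// Mirrors the declaration shape of nmtSample::Context (ptr typedef,
// defaulted constructor/destructor, getInfo override); the body is illustrative.
class ContextSketch : public ComponentSketch
{
public:
    typedef std::shared_ptr<ContextSketch> ptr;

    ContextSketch() = default;

    std::string getInfo() override
    {
        return "Context"; // assumed description text, not the real one
    }

    ~ContextSketch() override = default;
};
```

Because getInfo() is virtual, calling it through a std::shared_ptr<ComponentSketch> that holds a ContextSketch dispatches to the override.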