TensorRT 7.2.1.6
NVIDIA TensorRT
nmtSample::BeamSearchPolicy Class Reference

Processes the results of one iteration of the generator with beam search and produces input for the next iteration.


Classes

struct Ray
 

Public Types

typedef std::shared_ptr< BeamSearchPolicy > ptr
 

Public Member Functions

 BeamSearchPolicy (int endSequenceId, LikelihoodCombinationOperator::ptr likelihoodCombinationOperator, int beamWidth)
 
void initialize (int sampleCount, int *maxOutputSequenceLengths)
 
void processTimestep (int validSampleCount, const float *hCombinedLikelihoods, const int *hVocabularyIndices, const int *hRayOptionIndices, int *hSourceRayIndices, float *hSourceLikelihoods)
 
int getTailWithNoWorkRemaining ()
 
void readGeneratedResult (int sampleCount, int maxOutputSequenceLength, int *hOutputData, int *hActualOutputSequenceLengths)
 
std::string getInfo () override
 Get the textual description of the component.
 
 ~BeamSearchPolicy () override=default
 

Protected Member Functions

void backtrack (int lastTimestepId, int sampleId, int lastTimestepRayId, int *hOutputData, int lastTimestepWriteId) const
 

Protected Attributes

int mEndSequenceId
 
LikelihoodCombinationOperator::ptr mLikelihoodCombinationOperator
 
int mBeamWidth
 
std::vector< bool > mValidSamples
 
std::vector< float > mCurrentLikelihoods
 
std::vector< Ray > mBeamSearchTable
 
int mSampleCount
 
std::vector< int > mMaxOutputSequenceLengths
 
int mTimestepId
 
std::vector< std::vector< int > > mCandidates
 
std::vector< float > mCandidateLikelihoods
 

Detailed Description

Processes the results of one iteration of the generator with beam search and produces input for the next iteration.
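To make the idea of one beam-search iteration concrete, here is a minimal, self-contained sketch of a single step for one sample. It is not the nmtSample implementation: it assumes likelihoods are log-probabilities combined by addition (the documented class delegates that choice to a LikelihoodCombinationOperator), and the names `BeamChoice` and `beamStep` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration, not the nmtSample implementation: one beam-search
// step for a single sample, assuming likelihoods are log-probabilities that
// combine by addition.
struct BeamChoice
{
    int sourceRay;    // ray from the previous step that this option extends
    int token;        // vocabulary id appended to that ray
    float likelihood; // combined likelihood of the extended ray
};

// rayLikelihoods[r]: accumulated log-likelihood of ray r.
// tokenLogProbs[r * vocabSize + v]: log-probability of token v after ray r.
// Returns the beamWidth best extensions, best first.
std::vector<BeamChoice> beamStep(const std::vector<float>& rayLikelihoods,
    const std::vector<float>& tokenLogProbs, int vocabSize, int beamWidth)
{
    assert(vocabSize > 0 && beamWidth > 0);
    assert(tokenLogProbs.size() == rayLikelihoods.size() * static_cast<std::size_t>(vocabSize));

    // Enumerate every (ray, token) extension with its combined likelihood.
    std::vector<BeamChoice> options;
    options.reserve(tokenLogProbs.size());
    for (std::size_t r = 0; r < rayLikelihoods.size(); ++r)
        for (int v = 0; v < vocabSize; ++v)
            options.push_back({static_cast<int>(r), v,
                rayLikelihoods[r] + tokenLogProbs[r * static_cast<std::size_t>(vocabSize) + v]});

    // Order only the best `keep` entries, then drop the rest.
    const std::size_t keep = std::min(options.size(), static_cast<std::size_t>(beamWidth));
    std::partial_sort(options.begin(), options.begin() + static_cast<std::ptrdiff_t>(keep), options.end(),
        [](const BeamChoice& a, const BeamChoice& b) { return a.likelihood > b.likelihood; });
    options.resize(keep);
    return options;
}
```

The sketch folds candidate scoring and selection into one function. In the documented class, processTimestep() receives already-combined likelihoods together with vocabulary and ray-option indices, which suggests the selection happens before the policy is called.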

Member Typedef Documentation

◆ ptr

typedef std::shared_ptr<BeamSearchPolicy> nmtSample::BeamSearchPolicy::ptr

Constructor & Destructor Documentation

◆ BeamSearchPolicy()

nmtSample::BeamSearchPolicy::BeamSearchPolicy ( int  endSequenceId,
LikelihoodCombinationOperator::ptr  likelihoodCombinationOperator,
int  beamWidth 
)
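The constructor takes the end-of-sequence token id, a shared pointer to a LikelihoodCombinationOperator, and the beam width, and the class exposes a `ptr` typedef for `std::shared_ptr<BeamSearchPolicy>`. The sketch below shows that construction pattern with hypothetical stand-ins: `CombinationOp`, `LogProbCombination`, and `PolicySketch` are illustrative names, the real LikelihoodCombinationOperator interface is not documented on this page, and the additive log-probability rule is an assumption.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-in for a likelihood combination operator.
class CombinationOp
{
public:
    typedef std::shared_ptr<CombinationOp> ptr; // same idiom as BeamSearchPolicy::ptr
    virtual ~CombinationOp() = default;
    // Folds one step's likelihood into the running total.
    virtual float combine(float accumulated, float step) const = 0;
};

// Assumes likelihoods are log-probabilities, so combining is addition.
class LogProbCombination : public CombinationOp
{
public:
    float combine(float accumulated, float step) const override { return accumulated + step; }
};

// Hypothetical policy skeleton mirroring the documented constructor arguments.
class PolicySketch
{
public:
    typedef std::shared_ptr<PolicySketch> ptr;

    PolicySketch(int endSequenceId, CombinationOp::ptr op, int beamWidth)
        : mEndSequenceId(endSequenceId), mOp(std::move(op)), mBeamWidth(beamWidth)
    {
        assert(mOp != nullptr && mBeamWidth > 0);
    }

    int endSequenceId() const { return mEndSequenceId; }
    int beamWidth() const { return mBeamWidth; }
    float combine(float accumulated, float step) const { return mOp->combine(accumulated, step); }

private:
    int mEndSequenceId;
    CombinationOp::ptr mOp;
    int mBeamWidth;
};
```

Passing the operator as a shared pointer lets several components share one combination rule without copying it.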

◆ ~BeamSearchPolicy()

nmtSample::BeamSearchPolicy::~BeamSearchPolicy ( )
override default

Member Function Documentation

◆ initialize()

void nmtSample::BeamSearchPolicy::initialize ( int  sampleCount,
int *  maxOutputSequenceLengths 
)
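initialize() receives the batch size and a per-sample cap on output length. A plausible reset of per-batch state is sketched below against a hypothetical `BeamState` whose fields mirror the protected members listed above (mSampleCount, mMaxOutputSequenceLengths, mValidSamples, mCurrentLikelihoods, mTimestepId). The real body is not shown on this page; the starting likelihood is passed in explicitly because its value would depend on the combination operator, and the sketch takes `const int*` where the documented signature uses `int*`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical batch state modeled on the protected members of BeamSearchPolicy.
struct BeamState
{
    int beamWidth = 1;
    int sampleCount = 0;
    int timestepId = 0;
    std::vector<int> maxOutputSequenceLengths; // per sample
    std::vector<bool> validSamples;            // per sample: still decoding?
    std::vector<float> currentLikelihoods;     // per (sample, ray)

    // Resets all per-batch state; maxLengths must point at `count` ints.
    void initialize(int count, const int* maxLengths, float initialLikelihood)
    {
        assert(count >= 0 && (count == 0 || maxLengths != nullptr));
        sampleCount = count;
        maxOutputSequenceLengths.assign(maxLengths, maxLengths + count);
        validSamples.assign(static_cast<std::size_t>(count), true);
        currentLikelihoods.assign(static_cast<std::size_t>(count) * beamWidth, initialLikelihood);
        timestepId = 0;
    }
};
```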

◆ processTimestep()

void nmtSample::BeamSearchPolicy::processTimestep ( int  validSampleCount,
const float *  hCombinedLikelihoods,
const int *  hVocabularyIndices,
const int *  hRayOptionIndices,
int *  hSourceRayIndices,
float *  hSourceLikelihoods 
)
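processTimestep() takes the top-ranked options for each sample (combined likelihood, vocabulary id, ray-option index) and writes the source ray and likelihood of every surviving ray for the next iteration. The sketch below is one way such a pass could work; it is not the nmtSample code, and `StepSketch` is a hypothetical name. It assumes that every array is laid out as `[sample * beamWidth + k]` and sorted best-first per sample, that likelihoods are summed log-probabilities (so they never increase as a ray grows), and that `rayOption / beamWidth` recovers the source ray. The maximum-length check is omitted.

```cpp
#include <cassert>
#include <limits>
#include <vector>

// Hypothetical sketch of one processTimestep() pass under the assumptions above.
struct StepSketch
{
    int beamWidth;
    int endId;
    std::vector<bool> valid;         // per sample: any live ray left?
    std::vector<float> bestFinished; // per sample: best finished likelihood so far
    std::vector<int> tokens;         // per (sample, ray): token chosen this step

    static float lowest() { return std::numeric_limits<float>::lowest(); }

    void process(int validSampleCount, const float* combined, const int* vocab,
        const int* rayOption, int* srcRay, float* srcLikelihood)
    {
        assert(validSampleCount >= 0 && validSampleCount <= static_cast<int>(valid.size()));
        tokens.assign(valid.size() * beamWidth, endId);
        for (int s = 0; s < validSampleCount; ++s)
        {
            int k = 0;
            if (valid[s])
            {
                for (; k < beamWidth; ++k)
                {
                    const int i = s * beamWidth + k;
                    if (combined[i] <= bestFinished[s])
                        break; // nothing ranked below can beat the finished sequence
                    if (vocab[i] == endId)
                    {
                        bestFinished[s] = combined[i]; // new best finished sequence
                        break;
                    }
                    srcRay[i] = rayOption[i] / beamWidth; // ray to extend next step
                    srcLikelihood[i] = combined[i];
                    tokens[i] = vocab[i];
                }
                if (k == 0)
                    valid[s] = false; // no live ray survived for this sample
            }
            for (; k < beamWidth; ++k)
            {
                // Pad dead slots so the next step never prefers them.
                srcRay[s * beamWidth + k] = 0;
                srcLikelihood[s * beamWidth + k] = lowest();
            }
        }
    }
};
```

Stopping at the first option that is no better than the best finished sequence relies on the monotonicity assumption: under additive log-probabilities, a live ray can only lose likelihood as it grows, so lower-ranked options cannot overtake that sequence.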

◆ getTailWithNoWorkRemaining()

int nmtSample::BeamSearchPolicy::getTailWithNoWorkRemaining ( )
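This page does not describe getTailWithNoWorkRemaining(). Going by its name and by the validSampleCount parameter of processTimestep(), one reading is that it reports how much of the batch still needs processing, so that finished samples at the end can be skipped. The hypothetical helper below implements that interpretation. It is an assumption, not the documented behavior.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: returns the length of the shortest batch prefix that
// still contains every unfinished sample. Samples at or beyond the returned
// index form a tail with no work remaining.
int tailWithNoWorkRemaining(const std::vector<bool>& validSamples, int sampleCount)
{
    assert(sampleCount >= 0 && sampleCount <= static_cast<int>(validSamples.size()));
    for (int s = sampleCount - 1; s >= 0; --s)
        if (validSamples[s])
            return s + 1;
    return 0;
}
```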

◆ readGeneratedResult()

void nmtSample::BeamSearchPolicy::readGeneratedResult ( int  sampleCount,
int  maxOutputSequenceLength,
int *  hOutputData,
int *  hActualOutputSequenceLengths 
)
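readGeneratedResult() fills hOutputData and hActualOutputSequenceLengths for sampleCount samples, with at most maxOutputSequenceLength tokens each. The sketch below assumes a row-major output buffer with one row of maxOutputSequenceLength ints per sample and truncates longer sequences. Both the layout and the hypothetical `writeResults` helper are illustrative, not taken from the sample's source.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: copies one finished token sequence per sample into a
// row-major buffer with maxLen ints per row, truncating longer sequences, and
// records how many entries of each row are meaningful.
void writeResults(const std::vector<std::vector<int>>& sequences, int maxLen,
    int* outputData, int* actualLengths)
{
    assert(maxLen >= 0);
    for (std::size_t s = 0; s < sequences.size(); ++s)
    {
        const int len = std::min(static_cast<int>(sequences[s].size()), maxLen);
        std::copy_n(sequences[s].begin(), len, outputData + s * static_cast<std::size_t>(maxLen));
        actualLengths[s] = len;
    }
}
```

Entries past a row's recorded length are left untouched, so the caller should read only the first `actualLengths[s]` values of each row.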

◆ getInfo()

std::string nmtSample::BeamSearchPolicy::getInfo ( )
override virtual

get the textual description of the component

Implements nmtSample::Component.
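"Implements nmtSample::Component" means the base class declares getInfo() as a pure virtual function and BeamSearchPolicy supplies the override. The sketch below shows that relationship with hypothetical classes `ComponentSketch` and `BeamPolicySketch`. The description string is made up for illustration and is not what the sample returns.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical base mirroring nmtSample::Component: a pure virtual getInfo().
class ComponentSketch
{
public:
    virtual ~ComponentSketch() = default;
    virtual std::string getInfo() = 0;
};

class BeamPolicySketch : public ComponentSketch
{
public:
    explicit BeamPolicySketch(int beamWidth)
        : mBeamWidth(beamWidth)
    {
        assert(beamWidth > 0);
    }

    // Textual description of the component (format is illustrative only).
    std::string getInfo() override
    {
        return "Beam Search Policy, beam width = " + std::to_string(mBeamWidth);
    }

private:
    int mBeamWidth;
};
```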

◆ backtrack()

void nmtSample::BeamSearchPolicy::backtrack ( int  lastTimestepId,
int  sampleId,
int  lastTimestepRayId,
int *  hOutputData,
int  lastTimestepWriteId 
) const
protected
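backtrack() rebuilds a token sequence by walking mBeamSearchTable from a ray at lastTimestepId back to the first timestep, and lastTimestepWriteId limits which positions are written. The sketch below illustrates that walk. The table layout `[(t * sampleCount + sample) * beamWidth + ray]` and the `RayEntry` fields are assumptions, since this page does not list the members of Ray, and `backtrackSketch` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical table entry; the members of nmtSample's Ray are not listed here.
struct RayEntry
{
    int token;     // vocabulary id emitted at this timestep
    int sourceRay; // ray at the previous timestep that this one extends
};

// Assumed layout: table[(t * sampleCount + sample) * beamWidth + ray].
// Walks back from (lastT, lastRay) to t = 0, writing the token of timestep t
// into out[t] for every t <= lastWriteT.
void backtrackSketch(const std::vector<RayEntry>& table, int sampleCount, int beamWidth,
    int lastT, int sample, int lastRay, int* out, int lastWriteT)
{
    assert(lastRay >= 0 && lastRay < beamWidth);
    int ray = lastRay;
    for (int t = lastT; t >= 0; --t)
    {
        const RayEntry& e = table[(static_cast<std::size_t>(t) * sampleCount + sample) * beamWidth + ray];
        if (t <= lastWriteT)
            out[t] = e.token;
        ray = e.sourceRay; // follow the back-pointer to the previous timestep
    }
}
```

Storing only a token and a back-pointer per (timestep, sample, ray) keeps each step cheap; the full sequence is reconstructed once, when a result is read.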

Member Data Documentation

◆ mEndSequenceId

int nmtSample::BeamSearchPolicy::mEndSequenceId
protected

◆ mLikelihoodCombinationOperator

LikelihoodCombinationOperator::ptr nmtSample::BeamSearchPolicy::mLikelihoodCombinationOperator
protected

◆ mBeamWidth

int nmtSample::BeamSearchPolicy::mBeamWidth
protected

◆ mValidSamples

std::vector<bool> nmtSample::BeamSearchPolicy::mValidSamples
protected

◆ mCurrentLikelihoods

std::vector<float> nmtSample::BeamSearchPolicy::mCurrentLikelihoods
protected

◆ mBeamSearchTable

std::vector<Ray> nmtSample::BeamSearchPolicy::mBeamSearchTable
protected

◆ mSampleCount

int nmtSample::BeamSearchPolicy::mSampleCount
protected

◆ mMaxOutputSequenceLengths

std::vector<int> nmtSample::BeamSearchPolicy::mMaxOutputSequenceLengths
protected

◆ mTimestepId

int nmtSample::BeamSearchPolicy::mTimestepId
protected

◆ mCandidates

std::vector<std::vector<int> > nmtSample::BeamSearchPolicy::mCandidates
protected

◆ mCandidateLikelihoods

std::vector<float> nmtSample::BeamSearchPolicy::mCandidateLikelihoods
protected
