Processes the results of one iteration of the generator with beam search and produces the input for the next iteration. More...
Classes | |
struct | Ray |
Public Types | |
typedef std::shared_ptr< BeamSearchPolicy > | ptr |
Public Member Functions | |
BeamSearchPolicy (int endSequenceId, LikelihoodCombinationOperator::ptr likelihoodCombinationOperator, int beamWidth) | |
void | initialize (int sampleCount, int *maxOutputSequenceLengths) |
void | processTimestep (int validSampleCount, const float *hCombinedLikelihoods, const int *hVocabularyIndices, const int *hRayOptionIndices, int *hSourceRayIndices, float *hSourceLikelihoods) |
int | getTailWithNoWorkRemaining () |
void | readGeneratedResult (int sampleCount, int maxOutputSequenceLength, int *hOutputData, int *hActualOutputSequenceLengths) |
std::string | getInfo () override |
Gets the textual description of the component. More... | |
~BeamSearchPolicy () override=default | |
Protected Member Functions | |
void | backtrack (int lastTimestepId, int sampleId, int lastTimestepRayId, int *hOutputData, int lastTimestepWriteId) const |
Protected Attributes | |
int | mEndSequenceId |
LikelihoodCombinationOperator::ptr | mLikelihoodCombinationOperator |
int | mBeamWidth |
std::vector< bool > | mValidSamples |
std::vector< float > | mCurrentLikelihoods |
std::vector< Ray > | mBeamSearchTable |
int | mSampleCount |
std::vector< int > | mMaxOutputSequenceLengths |
int | mTimestepId |
std::vector< std::vector< int > > | mCandidates |
std::vector< float > | mCandidateLikelihoods |
Processes the results of one iteration of the generator with beam search and produces the input for the next iteration.
typedef std::shared_ptr<BeamSearchPolicy> nmtSample::BeamSearchPolicy::ptr |
nmtSample::BeamSearchPolicy::BeamSearchPolicy | ( | int | endSequenceId, |
LikelihoodCombinationOperator::ptr | likelihoodCombinationOperator, | ||
int | beamWidth | ||
) |
|
override default |
void nmtSample::BeamSearchPolicy::processTimestep | ( | int | validSampleCount, |
const float * | hCombinedLikelihoods, | ||
const int * | hVocabularyIndices, | ||
const int * | hRayOptionIndices, | ||
int * | hSourceRayIndices, | ||
float * | hSourceLikelihoods | ||
) |
int nmtSample::BeamSearchPolicy::getTailWithNoWorkRemaining | ( | ) |
void nmtSample::BeamSearchPolicy::readGeneratedResult | ( | int | sampleCount, |
int | maxOutputSequenceLength, | ||
int * | hOutputData, | ||
int * | hActualOutputSequenceLengths | ||
) |
|
override virtual |
Gets the textual description of the component.
Implements nmtSample::Component.
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |