The SampleCharRNNBase class implements the char_rnn sample.
Public Types
template<typename T> using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>
Public Member Functions
SampleCharRNNBase(const SampleCharRNNParams &params)
virtual ~SampleCharRNNBase() = default
bool build()
    Builds the network engine.
bool infer()
    Runs the TensorRT inference engine for this sample.
bool teardown()
    Used to clean up any state created in the sample class.
Protected Member Functions
virtual nvinfer1::ILayer * addLSTMLayers(SampleUniquePtr<nvinfer1::INetworkDefinition> &network) = 0
    Adds inputs to the TensorRT network and configures the LSTM layers using the network definition API.
nvinfer1::Weights convertRNNWeights(nvinfer1::Weights input, int dataSize)
    Converts RNN weights from TensorFlow's format to TensorRT's format.
nvinfer1::Weights convertRNNBias(nvinfer1::Weights input)
    Converts RNN biases from TensorFlow's format to TensorRT's format.
nvinfer1::ITensor * addReshape(SampleUniquePtr<nvinfer1::INetworkDefinition> &network, nvinfer1::ITensor &tensor, nvinfer1::Dims dims)
Protected Attributes
std::map<std::string, nvinfer1::Weights> mWeightMap
std::vector<SampleUniquePtr<nvinfer1::IHostMemory>> weightsMemory
SampleCharRNNParams mParams
Private Member Functions
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file)
    Loads the requested weights from a formatted file into a map.
void constructNetwork(SampleUniquePtr<nvinfer1::IBuilder> &builder, SampleUniquePtr<nvinfer1::INetworkDefinition> &network, SampleUniquePtr<nvinfer1::IBuilderConfig> &config)
    Creates the full model using the TensorRT network definition API and builds the engine.
void copyEmbeddingToInput(samplesCommon::BufferManager &buffers, const char &c)
    Looks up the embedding tensor for a given char and copies it to the input buffer.
bool stepOnce(samplesCommon::BufferManager &buffers, SampleUniquePtr<nvinfer1::IExecutionContext> &context, cudaStream_t &stream)
    Performs one time step of inference with the TensorRT execution context.
void copyRNNOutputsToInputs(samplesCommon::BufferManager &buffers)
    Copies the Ct/Ht outputs from the RNN to the Ct-1/Ht-1 input buffers for the next time step.
Private Attributes
std::shared_ptr<nvinfer1::ICudaEngine> mEngine {nullptr}
    The TensorRT engine used to run the network.
The SampleCharRNNBase class implements the char_rnn sample.
It uses weights from a trained TensorFlow model and creates the network using the TensorRT network definition API.
template<typename T> using SampleCharRNNBase::SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>
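The alias pairs std::unique_ptr with a custom deleter type so that TensorRT objects are released through samplesCommon::InferDeleter instead of the default deleter. A minimal stand-alone sketch of the same pattern follows; TrackedObject and CountingDeleter are hypothetical stand-ins for a TensorRT interface and for InferDeleter, whose exact release behavior is not described on this page.

```cpp
#include <memory>

// Hypothetical stand-in for a TensorRT interface object.
struct TrackedObject
{
    int id;
};

// Hypothetical stand-in for samplesCommon::InferDeleter: a stateless
// function object that frees the object and counts how many it freed.
struct CountingDeleter
{
    static inline int released = 0;

    template <typename T>
    void operator()(T* obj) const
    {
        ++released;
        delete obj;
    }
};

// Same shape as SampleCharRNNBase::SampleUniquePtr.
template <typename T>
using UniquePtr = std::unique_ptr<T, CountingDeleter>;

int useAndRelease()
{
    UniquePtr<TrackedObject> p{new TrackedObject{7}};
    return p->id; // CountingDeleter runs when p goes out of scope
}
```

Because the deleter is part of the pointer's type, every SampleUniquePtr<T> releases its object the same way without callers having to remember a cleanup call.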
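SampleCharRNNBase::SampleCharRNNBase(const SampleCharRNNParams &params) [inline]
virtual SampleCharRNNBase::~SampleCharRNNBase() [virtual], [default]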
bool SampleCharRNNBase::build()
Builds the network engine.
Creates the network, configures the builder and creates the network engine.
This function loads weights from a trained TensorFlow model, creates the network using the TensorRT network definition API, and builds a TensorRT engine.
bool SampleCharRNNBase::infer()
Runs the TensorRT inference engine for this sample.
This function is the main execution function of the sample. It allocates the buffers, sets the inputs, executes the engine, and verifies the output.
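Because each execution performs a single time step, inference over a character sequence has to run the engine once per character and carry the cell and hidden state from one step to the next. The sketch below models that control flow in plain C++ under stated assumptions: toyStep is a hypothetical stand-in for stepOnce() and the engine (it simply "predicts" the next vocabulary character), the `state = next` copy stands in for copyRNNOutputsToInputs(), and greedy argmax decoding is assumed. Verification against an expected output is omitted.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

struct State
{
    std::vector<float> h; // hidden state (Ht / Ht-1)
    std::vector<float> c; // cell state (Ct / Ct-1)
};

// Hypothetical toy time step: writes the new state into `out` and returns
// logits over the vocabulary that favor the character after `charIndex`.
std::vector<float> toyStep(std::size_t charIndex, const State& in, State& out, std::size_t vocabSize)
{
    out.c[0] = in.c[0] + 1.0f;                 // cell state counts steps
    out.h[0] = static_cast<float>(charIndex);  // hidden state remembers the last input
    std::vector<float> logits(vocabSize, 0.0f);
    logits[(charIndex + 1) % vocabSize] = 1.0f;
    return logits;
}

// Primes the state with `prime`, then greedily generates `numChars` characters.
std::string generate(const std::string& vocab, const std::string& prime, int numChars, State& state)
{
    if (prime.empty() || vocab.empty())
        return {};
    if (state.h.empty() || state.c.empty())
        throw std::invalid_argument("state must hold at least one value");

    State next = state;
    std::vector<float> logits;
    auto feed = [&](std::size_t idx) {
        logits = toyStep(idx, state, next, vocab.size()); // stepOnce() analogue
        state = next;                                      // copyRNNOutputsToInputs() analogue
    };

    for (char ch : prime)
    {
        const std::size_t idx = vocab.find(ch);
        if (idx == std::string::npos)
            throw std::invalid_argument("character not in vocabulary");
        feed(idx);
    }

    std::string result;
    for (int i = 0; i < numChars; ++i)
    {
        const auto best = static_cast<std::size_t>(
            std::max_element(logits.begin(), logits.end()) - logits.begin());
        result += vocab[best];
        feed(best); // the prediction becomes the next input
    }
    return result;
}
```

With vocabulary "abcd", prime "ab", and three generated characters, this toy returns "cda" after five time steps.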
bool SampleCharRNNBase::teardown()
Used to clean up any state created in the sample class.
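virtual nvinfer1::ILayer * SampleCharRNNBase::addLSTMLayers(SampleUniquePtr<nvinfer1::INetworkDefinition> &network) [protected], [pure virtual]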
Adds inputs to the TensorRT network and configures the LSTM layers using the network definition API.
Implemented in SampleCharRNNLoop and SampleCharRNNv2.
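Declaring addLSTMLayers() as a pure virtual hook lets the shared build pipeline stay in the base class while each subclass decides how the LSTM is expressed. The sketch below shows that template-method arrangement with hypothetical classes; the "loop" versus "rnnv2" labels are only inferred from the subclass names, and the exact point at which the real build pipeline calls the hook is an assumption.

```cpp
#include <memory>
#include <string>

// Hypothetical base: a shared pipeline around a subclass-specific hook.
class CharRnnBaseSketch
{
public:
    virtual ~CharRnnBaseSketch() = default;

    std::string build()
    {
        std::string log = "load-weights;";
        log += addLstmLayers(); // subclass decides how the LSTM is built
        log += "build-engine;";
        return log;
    }

protected:
    virtual std::string addLstmLayers() = 0;
};

// Hypothetical stand-in for SampleCharRNNLoop.
class LoopVariantSketch : public CharRnnBaseSketch
{
protected:
    std::string addLstmLayers() override { return "lstm-via-loop;"; }
};

// Hypothetical stand-in for SampleCharRNNv2.
class RnnV2VariantSketch : public CharRnnBaseSketch
{
protected:
    std::string addLstmLayers() override { return "lstm-via-rnnv2;"; }
};
```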
nvinfer1::Weights SampleCharRNNBase::convertRNNWeights(nvinfer1::Weights input, int dataSize) [protected]
Converts RNN weights from TensorFlow's format to TensorRT's format.
Parameters:
input: Weights that are stored in TensorFlow's format.
TensorRT expects the weights to be laid out in memory as CellN: Wi, Wc, Wf, Wo, Ri, Rc, Rf, Ro.
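The reordering can be sketched for a single cell under two assumptions that this page does not state: the TensorFlow kernel is a row-major [inputSize + hiddenSize][4 * hiddenSize] matrix whose column blocks are ordered i, c, f, o, and each TensorRT gate matrix is stored row-major as [hiddenSize][inputSize] (W) or [hiddenSize][hiddenSize] (R). convertLstmKernel is a hypothetical helper, not the sample's implementation.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: reorders one LSTM cell's kernel into the gate order
// Wi, Wc, Wf, Wo, Ri, Rc, Rf, Ro, transposing each gate block.
std::vector<float> convertLstmKernel(const std::vector<float>& tf, int inputSize, int hiddenSize)
{
    const std::size_t cols = 4 * static_cast<std::size_t>(hiddenSize);
    const std::size_t rows = static_cast<std::size_t>(inputSize) + static_cast<std::size_t>(hiddenSize);
    if (tf.size() != rows * cols)
        throw std::invalid_argument("kernel size does not match inputSize and hiddenSize");

    std::vector<float> out;
    out.reserve(tf.size());
    // Part 0: input weights (TF rows 0..inputSize-1); part 1: recurrent weights.
    const int rowStart[2] = {0, inputSize};
    const int rowCount[2] = {inputSize, hiddenSize};
    for (int part = 0; part < 2; ++part)
    {
        for (int gate = 0; gate < 4; ++gate) // assumed TF column order: i, c, f, o
        {
            for (int h = 0; h < hiddenSize; ++h)
            {
                for (int k = 0; k < rowCount[part]; ++k)
                {
                    const std::size_t row = static_cast<std::size_t>(rowStart[part] + k);
                    const std::size_t col = static_cast<std::size_t>(gate * hiddenSize + h);
                    out.push_back(tf[row * cols + col]);
                }
            }
        }
    }
    return out;
}
```

For inputSize = 2 and hiddenSize = 1, a kernel holding 1 through 12 becomes {1, 5, 2, 6, 3, 7, 4, 8, 9, 10, 11, 12}: each gate's two input weights are gathered from a TensorFlow column, followed by the four recurrent weights.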
nvinfer1::Weights SampleCharRNNBase::convertRNNBias(nvinfer1::Weights input) [protected]
Converts RNN biases from TensorFlow's format to TensorRT's format.
Parameters:
input: Biases that are stored in TensorFlow's format.
TensorRT expects the format to be CellN: Wi, Wc, Wf, Wo, Ri, Rc, Rf, Ro.
Since TensorFlow already combines U and W into a single bias, the size is doubled and all of U is set to zero.
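The doubling step can be sketched as copying the combined TensorFlow bias into the first half of a buffer twice its length and leaving the second half (U) at zero, so the two halves still sum to the original bias. This assumes the gate order needs no further reordering; expandLstmBias is a hypothetical name.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: expands a combined bias of length 4 * hiddenSize into
// 8 * hiddenSize values. First half: the TensorFlow bias (W part);
// second half: zeros (U part).
std::vector<float> expandLstmBias(const std::vector<float>& tfBias)
{
    std::vector<float> out(tfBias.size() * 2, 0.0f);
    std::copy(tfBias.begin(), tfBias.end(), out.begin());
    return out;
}
```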
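nvinfer1::ITensor * SampleCharRNNBase::addReshape(SampleUniquePtr<nvinfer1::INetworkDefinition> &network, nvinfer1::ITensor &tensor, nvinfer1::Dims dims) [protected]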
std::map<std::string, nvinfer1::Weights> SampleCharRNNBase::loadWeights(const std::string file) [private]
Loads the requested weights from a formatted file into a map.
Parameters:
file: Path to the weights file. The file must be the formatted dump produced by the dumpTFWts.py script; otherwise, this function will not work as intended.
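A minimal parser sketch, assuming a text layout that this page does not confirm for dumpTFWts.py: a blob count, then for each blob its name, a numeric type code (0 taken to mean 32-bit float), an element count, and that many hexadecimal words holding IEEE-754 float bit patterns. To avoid TensorRT headers, the sketch returns std::vector<float> instead of nvinfer1::Weights; loadWeightsFromStream and loadWeightsFromString are hypothetical names.

```cpp
#include <cstdint>
#include <cstring>
#include <istream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32 bits");

std::map<std::string, std::vector<float>> loadWeightsFromStream(std::istream& input)
{
    std::map<std::string, std::vector<float>> weights;
    int count = 0;
    if (!(input >> std::dec >> count))
        throw std::runtime_error("missing blob count");
    while (count-- > 0)
    {
        std::string name;
        std::uint32_t type = 0;
        std::uint32_t size = 0;
        // Blob header: name, type code, element count (decimal).
        if (!(input >> name >> std::dec >> type >> size))
            throw std::runtime_error("truncated blob header");
        if (type != 0)
            throw std::runtime_error("only float32 blobs are handled in this sketch");
        std::vector<float> values(size);
        for (std::uint32_t i = 0; i < size; ++i)
        {
            std::uint32_t bits = 0;
            if (!(input >> std::hex >> bits))
                throw std::runtime_error("truncated blob data");
            std::memcpy(&values[i], &bits, sizeof(float)); // reinterpret the bit pattern
        }
        weights[name] = std::move(values);
    }
    return weights;
}

// Convenience overload for in-memory text.
std::map<std::string, std::vector<float>> loadWeightsFromString(const std::string& text)
{
    std::istringstream in(text);
    return loadWeightsFromStream(in);
}
```

For example, the text "1\nw 0 2 3f800000 40000000\n" yields one blob named "w" holding {1.0f, 2.0f}.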
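void SampleCharRNNBase::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder> &builder, SampleUniquePtr<nvinfer1::INetworkDefinition> &network, SampleUniquePtr<nvinfer1::IBuilderConfig> &config) [private]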
Creates the full model using the TensorRT network definition API and builds the engine.
Parameters:
weightMap: Map that contains all the weights required by the model.
modelStream: The stream within which the engine is serialized once built.
void SampleCharRNNBase::copyEmbeddingToInput(samplesCommon::BufferManager &buffers, const char &c) [private]
Looks up the embedding tensor for a given char and copies it to the input buffer.
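The lookup amounts to selecting one row of an embedding matrix and copying it into the step's input buffer. A minimal host-side sketch, assuming a row-major [vocab.size()][embedDim] embedding whose row i belongs to the character vocab[i]; copyEmbeddingRow is a hypothetical helper, and the transfer to device memory that samplesCommon::BufferManager handles is left out.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: copies the embedding row for `c` into `inputBuffer`.
void copyEmbeddingRow(const std::string& vocab, const std::vector<float>& embedding,
    std::size_t embedDim, char c, std::vector<float>& inputBuffer)
{
    const std::size_t row = vocab.find(c);
    if (row == std::string::npos)
        throw std::invalid_argument("character not in vocabulary");
    if (embedding.size() < (row + 1) * embedDim || inputBuffer.size() < embedDim)
        throw std::out_of_range("embedding or input buffer too small");

    const auto first = embedding.begin() + static_cast<std::ptrdiff_t>(row * embedDim);
    std::copy(first, first + static_cast<std::ptrdiff_t>(embedDim), inputBuffer.begin());
}
```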
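bool SampleCharRNNBase::stepOnce(samplesCommon::BufferManager &buffers, SampleUniquePtr<nvinfer1::IExecutionContext> &context, cudaStream_t &stream) [private]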
Performs one time step of inference with the TensorRT execution context.
void SampleCharRNNBase::copyRNNOutputsToInputs(samplesCommon::BufferManager &buffers) [private]
Copies the Ct/Ht outputs from the RNN to the Ct-1/Ht-1 input buffers for the next time step.
std::map<std::string, nvinfer1::Weights> SampleCharRNNBase::mWeightMap [protected]
std::vector<SampleUniquePtr<nvinfer1::IHostMemory>> SampleCharRNNBase::weightsMemory [protected]
SampleCharRNNParams SampleCharRNNBase::mParams [protected]
std::shared_ptr<nvinfer1::ICudaEngine> SampleCharRNNBase::mEngine {nullptr} [private]
The TensorRT engine used to run the network.