Namespaces

anonymous_namespace{embLayerNormVarSeqlenPlugin.cpp}
anonymous_namespace{qkvToContextInt8InterleavedPlugin.cpp}
anonymous_namespace{skipLayerNormInt8InterleavedPlugin.cpp}
Enumerations

enum Data_type { DATA_TYPE_BOOL, DATA_TYPE_E8M10, DATA_TYPE_E8M7, DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT4, DATA_TYPE_INT8, DATA_TYPE_INT32 }
Functions

static size_t get_size_in_bytes(size_t n, Data_type dtype)
const FusedMultiHeadAttentionXMMAKernel *getXMMAKernels(Data_type type, unsigned int sm)
const FusedMultiHeadAttentionXMMAKernelV2 *getXMMAKernelsV2(Data_type type, unsigned int sm)
REGISTER_TENSORRT_PLUGIN(QKVToContextInterleavedPluginCreator)
REGISTER_TENSORRT_PLUGIN(EmbLayerNormVarSeqlenPluginCreator)
template<typename T>
int embSkipLayerNormVarSeqlen(cudaStream_t stream, int ld, int B, int S, const uint32_t *cuSeqlens, const int *inputIds, const int *token_ids, const T *beta, const T *gamma, const T *wordEmb, const T *posEmb, const T *tokEmb, T *output)
template<typename T>
int embSkipLayerNorm2(cudaStream_t stream, int ld, int B, int S, const int *inputIds, const int *tokenIds, const int *cuSeqlens, const float *beta, const float *gamma, const T *wordEmb, const T *posEmb, const T *tokEmb, T *output)
void cuSeqlensToPackedMask(const uint32_t S, const uint32_t B, const uint32_t warps_m, const uint32_t warps_n, const uint32_t warps_k, const int *cuSeqlens, uint32_t *inputMaskX, cudaStream_t stream)
void launch_small(cudaStream_t stream, const int ld, const int total, const int8_t *input, const int8_t *skip, const half *beta, const half *gamma, int8_t *output, const float dqScaleIn, const float dqScaleSkip, const float qScale)
void launch_large(cudaStream_t stream, const int ld, const int total, const int8_t *input, const int8_t *skip, const half *beta, const half *gamma, int8_t *output, const float dqScaleIn, const float dqScaleSkip, const float qScale)
REGISTER_TENSORRT_PLUGIN(SkipLayerNormInterleavedPluginCreator)
static DataType getParamWordType(DataType cfgType)
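
getParamWordType presumably selects the precision in which LayerNorm parameter words (beta/gamma) are stored for a given configuration type. A minimal sketch of a plausible mapping, assuming INT8 configurations keep their parameters in half precision; the actual rule lives in the plugin source, not on this page:

    // Hypothetical mapping for illustration only.
    static nvinfer1::DataType getParamWordType(nvinfer1::DataType cfgType)
    {
        // Assumption: INT8 configs store beta/gamma as half; all other
        // configs keep parameters in the same type as the config itself.
        return (cfgType == nvinfer1::DataType::kINT8) ? nvinfer1::DataType::kHALF
                                                      : cfgType;
    }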
Typedef Documentation

using bert::FusedMultiHeadAttentionXMMAKernel = TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV1, Fused_multihead_attention_params>
using bert::FusedMHAKernelFactory = TFusedMHAKernelFactory<FusedMultiHeadAttentionXMMAKernel>
using bert::FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionXMMAKernelV2>
Enumeration Type Documentation

enum bert::Data_type

Function Documentation

static size_t bert::get_size_in_bytes(size_t n, Data_type dtype)    [inline, static]
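
A helper like get_size_in_bytes maps an element count and a Data_type tag to a byte count. The following is a minimal sketch of that mapping; the widths for the sub-byte and truncated-float tags (DATA_TYPE_INT4, DATA_TYPE_E8M10, DATA_TYPE_E8M7) are assumptions inferred from the names, not taken from this page:

    // Hypothetical re-implementation for illustration only.
    static inline size_t get_size_in_bytes(size_t n, Data_type dtype)
    {
        switch (dtype)
        {
        case DATA_TYPE_BOOL:
        case DATA_TYPE_INT8:  return n;          // 1 byte per element
        case DATA_TYPE_FP16:
        case DATA_TYPE_E8M7:  return n * 2;      // 16-bit: half / bfloat16-like
        case DATA_TYPE_FP32:
        case DATA_TYPE_INT32:
        case DATA_TYPE_E8M10: return n * 4;      // 32-bit: float / int32 / TF32-like
        case DATA_TYPE_INT4:  return n / 2;      // packed two per byte (n assumed even)
        }
        return 0; // unreachable for valid tags
    }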
|
const FusedMultiHeadAttentionXMMAKernel *bert::getXMMAKernels(Data_type type, unsigned int sm)    [inline]
|
const FusedMultiHeadAttentionXMMAKernelV2 *bert::getXMMAKernelsV2(Data_type type, unsigned int sm)    [inline]
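
Both lookups key a set of precompiled fused multi-head attention kernels on a (data type, SM architecture) pair. A hedged usage sketch; treating a null return as "no kernel for this combination" is an assumption:

    // Illustration only: request the V2 FP16 kernel set for SM 8.0 (Ampere).
    bool hasFusedKernel()
    {
        const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernels
            = getXMMAKernelsV2(DATA_TYPE_FP16, 80);
        return xmmaKernels != nullptr; // caller would otherwise use an unfused path
    }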
bert::REGISTER_TENSORRT_PLUGIN(QKVToContextInterleavedPluginCreator)

bert::REGISTER_TENSORRT_PLUGIN(EmbLayerNormVarSeqlenPluginCreator)
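
REGISTER_TENSORRT_PLUGIN is TensorRT's standard registration macro: it instantiates a static registrar whose constructor adds the creator to the global plugin registry at library load time. Registered creators can then be found by name and version; the name/version values in this sketch are placeholders, since this page does not list what the creators report:

    #include "NvInferRuntimeCommon.h"

    // Hedged sketch: look up a registered creator at runtime. Real values for
    // name/version come from the creator's getPluginName()/getPluginVersion().
    nvinfer1::IPluginCreator* findCreator(const char* name, const char* version)
    {
        return getPluginRegistry()->getPluginCreator(name, version);
    }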
template<typename T>
int bert::embSkipLayerNormVarSeqlen(cudaStream_t stream, int ld, int B, int S, const uint32_t *cuSeqlens, const int *inputIds, const int *token_ids, const T *beta, const T *gamma, const T *wordEmb, const T *posEmb, const T *tokEmb, T *output)
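
A hedged host-side call sketch. The parameter roles (ld = hidden size, B = batch, S = maximum sequence length, cuSeqlens = cumulative sequence offsets) are inferred from common variable-seqlen conventions rather than stated on this page, and all d_* device buffers are hypothetical:

    // Illustration only: fused embedding lookup + LayerNorm in FP16.
    int runEmbLayerNorm(cudaStream_t stream, const uint32_t* d_cuSeqlens,
                        const int* d_inputIds, const int* d_tokenIds,
                        const half* d_beta, const half* d_gamma,
                        const half* d_wordEmb, const half* d_posEmb,
                        const half* d_tokEmb, half* d_output)
    {
        // A non-zero return presumably signals a launch failure (assumption).
        return bert::embSkipLayerNormVarSeqlen<half>(
            stream,
            /*ld=*/768,  // hidden size (BERT-base, an assumption)
            /*B=*/8,     // batch size
            /*S=*/384,   // maximum sequence length
            d_cuSeqlens, d_inputIds, d_tokenIds,
            d_beta, d_gamma, d_wordEmb, d_posEmb, d_tokEmb, d_output);
    }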
template<typename T>
int bert::embSkipLayerNorm2(cudaStream_t stream, int ld, int B, int S, const int *inputIds, const int *tokenIds, const int *cuSeqlens, const float *beta, const float *gamma, const T *wordEmb, const T *posEmb, const T *tokEmb, T *output)
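
Note the contrast with embSkipLayerNormVarSeqlen above: here beta and gamma are always float, so the LayerNorm parameters stay in full precision even when T is half, and cuSeqlens is passed as const int* rather than const uint32_t*. Both observations come directly from the two signatures on this page.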
void bert::cuSeqlensToPackedMask(const uint32_t S, const uint32_t B, const uint32_t warps_m, const uint32_t warps_n, const uint32_t warps_k, const int *cuSeqlens, uint32_t *inputMaskX, cudaStream_t stream)
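
The warps_m/warps_n/warps_k arguments suggest the mask bits are packed to match the warp tiling of a specific fused attention kernel, so the same cuSeqlens array can produce differently laid-out inputMaskX buffers per kernel configuration; that reading is an inference from the signature, not stated here. A call sketch with a guessed tiling:

    // Illustration only: the (1, 8, 1) warp tiling is a hypothetical
    // configuration, not taken from this page.
    void buildPackedMask(const int* d_cuSeqlens, uint32_t* d_inputMaskX,
                         cudaStream_t stream)
    {
        bert::cuSeqlensToPackedMask(/*S=*/384, /*B=*/8,
            /*warps_m=*/1, /*warps_n=*/8, /*warps_k=*/1,
            d_cuSeqlens, d_inputMaskX, stream);
    }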
void bert::launch_small(cudaStream_t stream, const int ld, const int total, const int8_t *input, const int8_t *skip, const half *beta, const half *gamma, int8_t *output, const float dqScaleIn, const float dqScaleSkip, const float qScale)
void bert::launch_large(cudaStream_t stream, const int ld, const int total, const int8_t *input, const int8_t *skip, const half *beta, const half *gamma, int8_t *output, const float dqScaleIn, const float dqScaleSkip, const float qScale)
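
The dqScaleIn/dqScaleSkip/qScale parameters follow the usual INT8 pattern: dequantize both INT8 inputs, compute skip-add plus LayerNorm in higher precision, then requantize the output. A sketch of how the scale arguments would be derived and the two launchers dispatched; the per-tensor scales and the small/large cutoff are assumptions:

    // Illustration only. inScale/skipScale/outScale are hypothetical per-tensor
    // quantization scales (amax / 127) taken from the network's dynamic ranges.
    void runSkipLayerNormInt8(cudaStream_t stream, int ld, int total,
                              const int8_t* d_in, const int8_t* d_skip,
                              const half* d_beta, const half* d_gamma,
                              int8_t* d_out, float inScale, float skipScale,
                              float outScale)
    {
        const float dqScaleIn   = inScale;         // int8 -> float for `input`
        const float dqScaleSkip = skipScale;       // int8 -> float for `skip`
        const float qScale      = 1.F / outScale;  // float -> int8 for `output`

        if (ld <= 1024) // assumed heuristic; the real cutoff is internal
        {
            bert::launch_small(stream, ld, total, d_in, d_skip, d_beta, d_gamma,
                               d_out, dqScaleIn, dqScaleSkip, qScale);
        }
        else
        {
            bert::launch_large(stream, ld, total, d_in, d_skip, d_beta, d_gamma,
                               d_out, dqScaleIn, dqScaleSkip, qScale);
        }
    }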
bert::REGISTER_TENSORRT_PLUGIN(SkipLayerNormInterleavedPluginCreator)

Variable Documentation

unsigned char bert::fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o
unsigned char bert::fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o |
unsigned char bert::fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o |
unsigned char bert::fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o |
unsigned char bert::fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o |
unsigned char bert::fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o |
unsigned char bert::fused_multihead_attention_int8_128_64_kernel_sm75_cu_o |
unsigned char bert::fused_multihead_attention_int8_384_64_kernel_sm75_cu_o |
unsigned char bert::fused_multihead_attention_int8_384_64_kernel_sm80_cu_o |
unsigned char bert::fused_multihead_attention_int8_128_64_kernel_sm80_cu_o |
unsigned char bert::fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o |
unsigned char bert::fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o |
unsigned int bert::fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o_len = 17000 |
unsigned int bert::fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o_len = 28648 |
unsigned int bert::fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o_len = 16744 |
unsigned int bert::fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o_len = 27880 |
unsigned int bert::fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o_len = 50936 |
unsigned int bert::fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o_len = 34152 |
unsigned int bert::fused_multihead_attention_int8_128_64_kernel_sm75_cu_o_len = 67816 |
unsigned int bert::fused_multihead_attention_int8_384_64_kernel_sm75_cu_o_len = 50664 |
unsigned int bert::fused_multihead_attention_int8_384_64_kernel_sm80_cu_o_len = 51304 |
unsigned int bert::fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len = 61672 |
unsigned int bert::fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len = 42984 |
unsigned int bert::fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o_len = 31208 |
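
Each *_cu_o byte array above is a precompiled cubin embedded in the library, paired with a *_len length in bytes. A hedged sketch of how such a blob could be loaded with the CUDA driver API; the kernel entry-point name is hypothetical, and the length variable is presumably consumed by the kernel factory's lookup tables rather than by the loader itself:

    #include <cuda.h>

    // Illustration only: cuModuleLoadData takes just the image pointer.
    CUfunction loadFusedMhaKernel()
    {
        CUmodule module;
        cuModuleLoadData(&module,
            bert::fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o);

        CUfunction kernel;
        cuModuleGetFunction(&kernel, module,
            "fused_multihead_attention_fp16_384_64_kernel_sm80"); // hypothetical entry name
        return kernel;
    }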
|
unsigned char bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_128_64_kernel_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_192_64_kernel_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_256_64_kernel_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_384_64_kernel_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin |
unsigned char bert::fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin |
unsigned int bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len = 78684 |
unsigned int bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len = 72676 |
unsigned int bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin_len = 69208 |
unsigned int bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len = 64864 |
unsigned int bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin_len = 56148 |
unsigned int bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin_len = 71784 |
unsigned int bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin_len = 19184 |
unsigned int bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len = 19580 |
unsigned int bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin_len = 33968 |
unsigned int bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len = 33204 |
unsigned int bert::fused_multihead_attention_v2_int8_128_64_kernel_cubin_len = 241780 |
unsigned int bert::fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin_len = 189888 |
unsigned int bert::fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len = 177232 |
unsigned int bert::fused_multihead_attention_v2_int8_192_64_kernel_cubin_len = 191728 |
unsigned int bert::fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin_len = 243644 |
unsigned int bert::fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len = 179788 |
unsigned int bert::fused_multihead_attention_v2_int8_256_64_kernel_cubin_len = 239856 |
unsigned int bert::fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin_len = 167356 |
unsigned int bert::fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len = 156492 |
unsigned int bert::fused_multihead_attention_v2_int8_384_64_kernel_cubin_len = 235504 |
unsigned int bert::fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len = 224060 |
unsigned int bert::fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len = 206668 |
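
The V2 variable names encode precision, maximum sequence length, head size, and target SM in that order (for example, v2_int8_384_64_kernel_sm80 is the INT8, S = 384, head-size-64 kernel for SM 8.0); entries without an smXX suffix are left unqualified on this page. This decoding is inferred from the naming pattern, not stated here.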
|