TensorRT  7.2.1.6
NVIDIA TensorRT
Looking for a C++ dev who knows TensorRT?
I'm looking for work. Hire me!
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Pages
bert Namespace Reference

Namespaces

 anonymous_namespace{embLayerNormVarSeqlenPlugin.cpp}
 
 anonymous_namespace{qkvToContextInt8InterleavedPlugin.cpp}
 
 anonymous_namespace{skipLayerNormInt8InterleavedPlugin.cpp}
 

Classes

class  EmbLayerNormVarSeqlenPlugin
 
class  EmbLayerNormVarSeqlenPluginCreator
 
struct  Fused_multihead_attention_params
 
struct  Fused_multihead_attention_params_v2
 
struct  FusedMultiHeadAttentionKernelMetaInfoV1
 
struct  FusedMultiHeadAttentionKernelMetaInfoV2
 
class  FusedMultiHeadAttentionXMMAKernelV2
 
class  QKVToContextInterleavedPlugin
 
class  QKVToContextInterleavedPluginCreator
 
class  SkipLayerNormInterleavedPlugin
 
class  SkipLayerNormInterleavedPluginCreator
 
class  TFusedMHAKernelFactory
 
class  TFusedMultiHeadAttentionXMMAKernel
 

Typedefs

using FusedMultiHeadAttentionXMMAKernel = TFusedMultiHeadAttentionXMMAKernel< FusedMultiHeadAttentionKernelMetaInfoV1, Fused_multihead_attention_params >
 
using FusedMHAKernelFactory = TFusedMHAKernelFactory< FusedMultiHeadAttentionXMMAKernel >
 
using FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory< FusedMultiHeadAttentionXMMAKernelV2 >
 

Enumerations

enum  Data_type {
  DATA_TYPE_BOOL,
  DATA_TYPE_E8M10,
  DATA_TYPE_E8M7,
  DATA_TYPE_FP16,
  DATA_TYPE_FP32,
  DATA_TYPE_INT4,
  DATA_TYPE_INT8,
  DATA_TYPE_INT32
}
 

Functions

static size_t get_size_in_bytes (size_t n, Data_type dtype)
 
const FusedMultiHeadAttentionXMMAKernel* getXMMAKernels (Data_type type, unsigned int sm)
 
const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2 (Data_type type, unsigned int sm)
 
 REGISTER_TENSORRT_PLUGIN (QKVToContextInterleavedPluginCreator)
 
 REGISTER_TENSORRT_PLUGIN (EmbLayerNormVarSeqlenPluginCreator)
 
template<typename T >
int embSkipLayerNormVarSeqlen (cudaStream_t stream, int ld, int B, int S, const uint32_t *cuSeqlens, const int *inputIds, const int *token_ids, const T *beta, const T *gamma, const T *wordEmb, const T *posEmb, const T *tokEmb, T *output)
 
template<typename T >
int embSkipLayerNorm2 (cudaStream_t stream, int ld, int B, int S, const int *inputIds, const int *tokenIds, const int *cuSeqlens, const float *beta, const float *gamma, const T *wordEmb, const T *posEmb, const T *tokEmb, T *output)
 
void cuSeqlensToPackedMask (const uint32_t S, const uint32_t B, const uint32_t warps_m, const uint32_t warps_n, const uint32_t warps_k, const int *cuSeqlens, uint32_t *inputMaskX, cudaStream_t stream)
 
void launch_small (cudaStream_t stream, const int ld, const int total, const int8_t *input, const int8_t *skip, const half *beta, const half *gamma, int8_t *output, const float dqScaleIn, const float dqScaleSkip, const float qScale)
 
void launch_large (cudaStream_t stream, const int ld, const int total, const int8_t *input, const int8_t *skip, const half *beta, const half *gamma, int8_t *output, const float dqScaleIn, const float dqScaleSkip, const float qScale)
 
 REGISTER_TENSORRT_PLUGIN (SkipLayerNormInterleavedPluginCreator)
 
static DataType getParamWordType (DataType cfgType)
 

Variables

unsigned char fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o []
 
unsigned char fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o []
 
unsigned char fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o []
 
unsigned char fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o []
 
unsigned char fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o []
 
unsigned char fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o []
 
unsigned char fused_multihead_attention_int8_128_64_kernel_sm75_cu_o []
 
unsigned char fused_multihead_attention_int8_384_64_kernel_sm75_cu_o []
 
unsigned char fused_multihead_attention_int8_384_64_kernel_sm80_cu_o []
 
unsigned char fused_multihead_attention_int8_128_64_kernel_sm80_cu_o []
 
unsigned char fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o []
 
unsigned char fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o []
 
unsigned int fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o_len = 17000
 
unsigned int fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o_len = 28648
 
unsigned int fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o_len = 16744
 
unsigned int fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o_len = 27880
 
unsigned int fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o_len = 50936
 
unsigned int fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o_len = 34152
 
unsigned int fused_multihead_attention_int8_128_64_kernel_sm75_cu_o_len = 67816
 
unsigned int fused_multihead_attention_int8_384_64_kernel_sm75_cu_o_len = 50664
 
unsigned int fused_multihead_attention_int8_384_64_kernel_sm80_cu_o_len = 51304
 
unsigned int fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len = 61672
 
unsigned int fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len = 42984
 
unsigned int fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o_len = 31208
 
static const struct bert::FusedMultiHeadAttentionKernelMetaInfoV1 sMhaKernelMetaInfos []
 
unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_int8_128_64_kernel_cubin []
 
unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_int8_192_64_kernel_cubin []
 
unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_int8_256_64_kernel_cubin []
 
unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin []
 
unsigned char fused_multihead_attention_v2_int8_384_64_kernel_cubin []
 
unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin []
 
unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin []
 
unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len = 78684
 
unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len = 72676
 
unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin_len = 69208
 
unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len = 64864
 
unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin_len = 56148
 
unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin_len = 71784
 
unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin_len = 19184
 
unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len = 19580
 
unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin_len = 33968
 
unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len = 33204
 
unsigned int fused_multihead_attention_v2_int8_128_64_kernel_cubin_len = 241780
 
unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin_len = 189888
 
unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len = 177232
 
unsigned int fused_multihead_attention_v2_int8_192_64_kernel_cubin_len = 191728
 
unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin_len = 243644
 
unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len = 179788
 
unsigned int fused_multihead_attention_v2_int8_256_64_kernel_cubin_len = 239856
 
unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin_len = 167356
 
unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len = 156492
 
unsigned int fused_multihead_attention_v2_int8_384_64_kernel_cubin_len = 235504
 
unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len = 224060
 
unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len = 206668
 
static const struct bert::FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2 []
 
constexpr uint32_t IIDX = 0
 
static constexpr int32_t kSM_XAVIER = 72
 
static constexpr int32_t kSM_TURING = 75
 
static constexpr int32_t kSM_AMPERE = 80
 
constexpr size_t threadsPerCta128 = 2 * 2 * 32
 
constexpr size_t threadsPerCta256 = 1 * 4 * 32
 
constexpr size_t threadsPerCta384 = 1 * 8 * 32
 
constexpr size_t xmmasM128 = 4
 
constexpr size_t xmmasM256 = 16
 
constexpr size_t xmmasM384 = 24
 
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128
 
constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256
 
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384
 
constexpr auto param_type = DataType::kHALF
 

Typedef Documentation

◆ FusedMultiHeadAttentionXMMAKernel

◆ FusedMHAKernelFactory

◆ FusedMHAKernelFactoryV2

Enumeration Type Documentation

◆ Data_type

Enumerator
DATA_TYPE_BOOL 
DATA_TYPE_E8M10 
DATA_TYPE_E8M7 
DATA_TYPE_FP16 
DATA_TYPE_FP32 
DATA_TYPE_INT4 
DATA_TYPE_INT8 
DATA_TYPE_INT32 

Function Documentation

◆ get_size_in_bytes()

static size_t bert::get_size_in_bytes ( size_t  n,
Data_type  dtype 
)
inline static

◆ getXMMAKernels()

const FusedMultiHeadAttentionXMMAKernel* bert::getXMMAKernels ( Data_type  type,
unsigned int  sm 
)
inline
Here is the call graph for this function:

◆ getXMMAKernelsV2()

const FusedMultiHeadAttentionXMMAKernelV2* bert::getXMMAKernelsV2 ( Data_type  type,
unsigned int  sm 
)
inline
Here is the call graph for this function:
Here is the caller graph for this function:

◆ REGISTER_TENSORRT_PLUGIN() [1/3]

bert::REGISTER_TENSORRT_PLUGIN ( QKVToContextInterleavedPluginCreator  )

◆ REGISTER_TENSORRT_PLUGIN() [2/3]

bert::REGISTER_TENSORRT_PLUGIN ( EmbLayerNormVarSeqlenPluginCreator  )

◆ embSkipLayerNormVarSeqlen()

template<typename T >
int bert::embSkipLayerNormVarSeqlen ( cudaStream_t  stream,
int  ld,
int  B,
int  S,
const uint32_t *  cuSeqlens,
const int *  inputIds,
const int *  token_ids,
const T *  beta,
const T *  gamma,
const T *  wordEmb,
const T *  posEmb,
const T *  tokEmb,
T *  output 
)

◆ embSkipLayerNorm2()

template<typename T >
int bert::embSkipLayerNorm2 ( cudaStream_t  stream,
int  ld,
int  B,
int  S,
const int *  inputIds,
const int *  tokenIds,
const int *  cuSeqlens,
const float *  beta,
const float *  gamma,
const T *  wordEmb,
const T *  posEmb,
const T *  tokEmb,
T *  output 
)

◆ cuSeqlensToPackedMask()

void bert::cuSeqlensToPackedMask ( const uint32_t  S,
const uint32_t  B,
const uint32_t  warps_m,
const uint32_t  warps_n,
const uint32_t  warps_k,
const int *  cuSeqlens,
uint32_t *  inputMaskX,
cudaStream_t  stream 
)

◆ launch_small()

void bert::launch_small ( cudaStream_t  stream,
const int  ld,
const int  total,
const int8_t *  input,
const int8_t *  skip,
const half *  beta,
const half *  gamma,
int8_t *  output,
const float  dqScaleIn,
const float  dqScaleSkip,
const float  qScale 
)
Here is the caller graph for this function:

◆ launch_large()

void bert::launch_large ( cudaStream_t  stream,
const int  ld,
const int  total,
const int8_t *  input,
const int8_t *  skip,
const half *  beta,
const half *  gamma,
int8_t *  output,
const float  dqScaleIn,
const float  dqScaleSkip,
const float  qScale 
)
Here is the caller graph for this function:

◆ REGISTER_TENSORRT_PLUGIN() [3/3]

bert::REGISTER_TENSORRT_PLUGIN ( SkipLayerNormInterleavedPluginCreator  )

◆ getParamWordType()

static DataType bert::getParamWordType ( DataType  cfgType)
inline static
Here is the call graph for this function:
Here is the caller graph for this function:

Variable Documentation

◆ fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o

unsigned char bert::fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o

◆ fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o

unsigned char bert::fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o

◆ fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o

unsigned char bert::fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o

◆ fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o

unsigned char bert::fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o

◆ fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o

unsigned char bert::fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o

◆ fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o

unsigned char bert::fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o

◆ fused_multihead_attention_int8_128_64_kernel_sm75_cu_o

unsigned char bert::fused_multihead_attention_int8_128_64_kernel_sm75_cu_o

◆ fused_multihead_attention_int8_384_64_kernel_sm75_cu_o

unsigned char bert::fused_multihead_attention_int8_384_64_kernel_sm75_cu_o

◆ fused_multihead_attention_int8_384_64_kernel_sm80_cu_o

unsigned char bert::fused_multihead_attention_int8_384_64_kernel_sm80_cu_o

◆ fused_multihead_attention_int8_128_64_kernel_sm80_cu_o

unsigned char bert::fused_multihead_attention_int8_128_64_kernel_sm80_cu_o

◆ fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o

unsigned char bert::fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o

◆ fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o

unsigned char bert::fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o

◆ fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o_len = 17000

◆ fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o_len = 28648

◆ fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_64_64_kernel_sm80_cu_o_len = 16744

◆ fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_96_64_kernel_sm80_cu_o_len = 27880

◆ fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_128_64_kernel_sm75_cu_o_len = 50936

◆ fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_384_64_kernel_sm75_cu_o_len = 34152

◆ fused_multihead_attention_int8_128_64_kernel_sm75_cu_o_len

unsigned int bert::fused_multihead_attention_int8_128_64_kernel_sm75_cu_o_len = 67816

◆ fused_multihead_attention_int8_384_64_kernel_sm75_cu_o_len

unsigned int bert::fused_multihead_attention_int8_384_64_kernel_sm75_cu_o_len = 50664

◆ fused_multihead_attention_int8_384_64_kernel_sm80_cu_o_len

unsigned int bert::fused_multihead_attention_int8_384_64_kernel_sm80_cu_o_len = 51304

◆ fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len

unsigned int bert::fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len = 61672

◆ fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len = 42984

◆ fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o_len

unsigned int bert::fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o_len = 31208

◆ sMhaKernelMetaInfos

const struct bert::FusedMultiHeadAttentionKernelMetaInfoV1 bert::sMhaKernelMetaInfos[]
static

◆ fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_int8_128_64_kernel_cubin

unsigned char bert::fused_multihead_attention_v2_int8_128_64_kernel_cubin

◆ fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_int8_192_64_kernel_cubin

unsigned char bert::fused_multihead_attention_v2_int8_192_64_kernel_cubin

◆ fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_int8_256_64_kernel_cubin

unsigned char bert::fused_multihead_attention_v2_int8_256_64_kernel_cubin

◆ fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_int8_384_64_kernel_cubin

unsigned char bert::fused_multihead_attention_v2_int8_384_64_kernel_cubin

◆ fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin

unsigned char bert::fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin

◆ fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin

unsigned char bert::fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin

◆ fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len = 78684

◆ fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len = 72676

◆ fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin_len = 69208

◆ fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len = 64864

◆ fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin_len = 56148

◆ fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin_len = 71784

◆ fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin_len = 19184

◆ fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len = 19580

◆ fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin_len = 33968

◆ fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len = 33204

◆ fused_multihead_attention_v2_int8_128_64_kernel_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_128_64_kernel_cubin_len = 241780

◆ fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin_len = 189888

◆ fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len = 177232

◆ fused_multihead_attention_v2_int8_192_64_kernel_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_192_64_kernel_cubin_len = 191728

◆ fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin_len = 243644

◆ fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len = 179788

◆ fused_multihead_attention_v2_int8_256_64_kernel_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_256_64_kernel_cubin_len = 239856

◆ fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin_len = 167356

◆ fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len = 156492

◆ fused_multihead_attention_v2_int8_384_64_kernel_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_384_64_kernel_cubin_len = 235504

◆ fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len = 224060

◆ fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len

unsigned int bert::fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len = 206668

◆ sMhaKernelMetaInfosV2

const struct bert::FusedMultiHeadAttentionKernelMetaInfoV2 bert::sMhaKernelMetaInfosV2[]
static

◆ IIDX

constexpr uint32_t bert::IIDX = 0
constexpr

◆ kSM_XAVIER

constexpr int32_t bert::kSM_XAVIER = 72
static constexpr

◆ kSM_TURING

constexpr int32_t bert::kSM_TURING = 75
static constexpr

◆ kSM_AMPERE

constexpr int32_t bert::kSM_AMPERE = 80
static constexpr

◆ threadsPerCta128

constexpr size_t bert::threadsPerCta128 = 2 * 2 * 32
constexpr

◆ threadsPerCta256

constexpr size_t bert::threadsPerCta256 = 1 * 4 * 32
constexpr

◆ threadsPerCta384

constexpr size_t bert::threadsPerCta384 = 1 * 8 * 32
constexpr

◆ xmmasM128

constexpr size_t bert::xmmasM128 = 4
constexpr

◆ xmmasM256

constexpr size_t bert::xmmasM256 = 16
constexpr

◆ xmmasM384

constexpr size_t bert::xmmasM384 = 24
constexpr

◆ packedMaskSize128

constexpr size_t bert::packedMaskSize128 = xmmasM128 * threadsPerCta128
constexpr

◆ packedMaskSize256

constexpr size_t bert::packedMaskSize256 = xmmasM256 * threadsPerCta256
constexpr

◆ packedMaskSize384

constexpr size_t bert::packedMaskSize384 = xmmasM384 * threadsPerCta384
constexpr

◆ param_type

constexpr auto bert::param_type = DataType::kHALF
constexpr