Public Types | |
using | KernelMeta = FusedMultiHeadAttentionKernelMetaInfoV2 |
using | KernelParam = Fused_multihead_attention_params_v2 |
Public Member Functions | |
FusedMultiHeadAttentionXMMAKernelV2 (const FusedMultiHeadAttentionKernelMetaInfoV2 *pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm) | |
uint64_t | hashID (unsigned int s, bool interleaved, bool unroll) const |
virtual uint64_t | hashID (const KernelMeta &kernelMeta) const |
virtual void | run (Fused_multihead_attention_params_v2 &params, cudaStream_t ss) const |
uint64_t | hashID (unsigned int s, unsigned int d) const |
void | loadXMMAKernels () |
bool | isValid (int s) const |
Protected Attributes | |
nvinfer1::CUDADriverWrapper | mDriver |
Data_type | mDataType |
const FusedMultiHeadAttentionKernelMetaInfoV2 * | mKernelMeta |
unsigned int | mKernelMetaCount |
unsigned int | mSM |
std::unordered_map< const unsigned char *, CUmodule > | mModules |
std::unordered_map< uint64_t, FusedMultiHeadAttentionKernelInfo > | mFunctions |
std::set< int > | mValidSequences |
|
inherited |
|
inherited |
|
inline |
|
inline |
|
inline virtual |
Reimplemented from bert::TFusedMultiHeadAttentionXMMAKernel< FusedMultiHeadAttentionKernelMetaInfoV2, Fused_multihead_attention_params_v2 >.
|
inline virtual |
Reimplemented from bert::TFusedMultiHeadAttentionXMMAKernel< FusedMultiHeadAttentionKernelMetaInfoV2, Fused_multihead_attention_params_v2 >.
|
inline inherited |
|
inline inherited |
|
inline inherited |
|
protected inherited |
|
protected inherited |
|
protected inherited |
|
protected inherited |
|
protected inherited |
|
protected inherited |
|
protected inherited |
|
protected inherited |