/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
@ VECTORIZED_LAYOUT
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:105
static constexpr index_t kVectorSize
Definition: block_fmha_pipeline_problem.hpp:111
static constexpr auto kKVLookupTable
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr auto kKVMemoryLayout
Definition: block_fmha_pipeline_problem.hpp:112
static constexpr bool kIsVectorizedLayout
Definition: block_fmha_pipeline_problem.hpp:114
static constexpr index_t kPageBlockSize
Definition: block_fmha_pipeline_problem.hpp:106
Definition: block_fmha_pipeline_problem.hpp:300
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:301
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:322
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:323
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:315
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:317
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:310
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:321
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:304
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:318
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:303
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:308
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:311
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:325
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:306
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:324
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:302
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:309
Definition: block_fmha_pipeline_problem.hpp:142
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:147
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:167
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:173
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:176
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:156
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:175
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:168
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:145
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:154
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:155
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:150
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:146
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:144
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:171
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:143
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:148
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:151
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:169
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:165
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:166
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:152
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:149
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:174
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:172
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:170
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:159
Definition: block_fmha_pipeline_problem.hpp:195
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:227
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:222
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:208
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:220
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:225
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:211
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:196
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:204
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:202
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:215
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:228
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:212
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:199
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:230
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:226
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:200
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:206
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:197
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:218
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:229
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:203
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:205
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:223
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:207
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:219
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:224
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:201
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:221
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:209
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:63
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:61
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr auto QScaleEnum
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:60
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:47
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:45
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:34
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:65
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:37
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:41
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:49
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:58
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:39
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:52
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:66
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:35
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:48
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:31
Definition: block_fmha_pipeline_problem.hpp:253
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:258
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:281
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:259
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:263
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:282
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:277
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:264
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:278
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:274
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:275
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:276
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:256
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:273
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:239
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:257
Definition: block_fmha_pipeline_problem.hpp:236
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:240
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:237
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:239
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17