/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 template <typename QDataType_,
13  typename KDataType_,
14  typename VDataType_,
15  typename SaccDataType_,
16  typename SMPLComputeDataType_,
17  typename BiasDataType_,
18  typename RandValOutputDataType_,
19  typename LSEDataType_,
20  typename PDataType_,
21  typename OaccDataType_,
22  typename ODataType_,
23  typename BlockFmhaShape_,
24  bool kIsGroupMode_,
25  typename AttentionVariant_,
26  typename FmhaMask_,
27  bool kUseTrLoad_,
28  typename Traits_>
30 {
46 
47  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
48  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
49  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
50 
51  static constexpr bool kIsGroupMode = kIsGroupMode_;
52  static constexpr bool kUseTrLoad = kUseTrLoad_;
53 
54  // attributes from traits
55  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
56  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
57  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
58  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
59  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
60  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
61  static constexpr auto BiasEnum = Traits::BiasEnum;
62  static constexpr bool kStoreLSE = Traits::kStoreLSE;
63  static constexpr bool kHasDropout = Traits::kHasDropout;
64  static constexpr auto QScaleEnum = Traits::QScaleEnum;
65  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
66  static constexpr bool kHasSink = Traits::kHasSink;
67 };
68 
69 template <typename QDataType_,
70  typename KDataType_,
71  typename VDataType_,
72  typename SaccDataType_,
73  typename SMPLComputeDataType_,
74  typename BiasDataType_,
75  typename RandValOutputDataType_,
76  typename LSEDataType_,
77  typename PDataType_,
78  typename OaccDataType_,
79  typename ODataType_,
80  typename BlockFmhaShape_,
81  bool kIsGroupMode_,
82  typename AttentionVariant_,
83  typename FmhaMask_,
84  bool kUseTrLoad_,
85  int kPageBlockSize_,
86  typename Traits_>
88  : public BlockFmhaPipelineProblem<QDataType_,
89  KDataType_,
90  VDataType_,
91  SaccDataType_,
92  SMPLComputeDataType_,
93  BiasDataType_,
94  RandValOutputDataType_,
95  LSEDataType_,
96  PDataType_,
97  OaccDataType_,
98  ODataType_,
99  BlockFmhaShape_,
100  kIsGroupMode_,
101  AttentionVariant_,
102  FmhaMask_,
103  kUseTrLoad_,
104  Traits_>
105 {
106  static constexpr index_t kPageBlockSize = kPageBlockSize_;
107  static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
108  static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
109  "kPageBlockSize must be power of two");
110 
111  static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
112  static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
113  static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
114  static constexpr bool kIsVectorizedLayout =
116 
117  static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
118  "kQKHeaddim must be divisible by kVectorSize");
119  static_assert(!(kPageBlockSize == 1 && kIsVectorizedLayout),
120  "page_size=1 only supports linear KV cache layout");
121  static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
122  "kPageBlockSize must be divisible by kVectorSize for vectorized layout");
123  static_assert(kIsGroupMode_, "Batch prefill requires group mode");
124 };
125 
126 template <typename QDataType_,
127  typename KDataType_,
128  typename VDataType_,
129  typename SaccDataType_,
130  typename SMPLComputeDataType_,
131  typename BiasDataType_,
132  typename LSEDataType_,
133  typename PDataType_,
134  typename OaccDataType_,
135  typename ODataType_,
136  typename BlockFmhaShape_,
137  bool kIsGroupMode_,
138  typename AttentionVariant_,
139  typename FmhaMask_,
140  typename Traits_>
142 {
157 
158  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
159  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
160  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
161 
162  static constexpr bool kIsGroupMode = kIsGroupMode_;
163 
164  // attributes from traits
165  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
166  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
167  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
168  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
169  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
170  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
171  static constexpr auto BiasEnum = Traits::BiasEnum;
172  static constexpr bool kStoreLSE = Traits::kStoreLSE;
173  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
174  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
175  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
176  static constexpr bool kHasSink = Traits::kHasSink;
177 };
178 
179 template <typename QDataType_,
180  typename KDataType_,
181  typename VDataType_,
182  typename SaccDataType_,
183  typename SMPLComputeDataType_,
184  typename BiasDataType_,
185  typename LSEDataType_,
186  typename PDataType_,
187  typename OaccDataType_,
188  typename ODataType_,
189  typename BlockFmhaShape_,
190  bool kIsGroupMode_,
191  typename AttentionVariant_,
192  typename FmhaMask_,
193  typename Traits_>
195 {
210 
211  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
212  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
213  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
214 
215  static constexpr bool kIsGroupMode = kIsGroupMode_;
216 
217  // attributes from traits
218  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
219  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
220  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
221  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
222  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
223  static constexpr auto BiasEnum = Traits::BiasEnum;
224  static constexpr bool kStoreLSE = Traits::kStoreLSE;
225  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
226  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
227  static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
228  static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
229  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
230  static constexpr bool kHasSink = Traits::kHasSink;
231 };
232 
233 // extract tile size attributes to remove dependency on traits
234 template <typename OaccDataType_, ck_tile::index_t kN1_>
236 {
237  static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
238 
239  static constexpr index_t kN1 = kN1_;
240  static constexpr index_t NThreads = kN1 / MaxVectorSize;
241  static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
242 };
243 
244 template <typename LSEDataType_,
245  typename OaccDataType_,
246  typename ODataType_,
247  index_t HeadDimV_,
248  bool kIsGroupMode_,
249  ck_tile::index_t kN1_,
250  typename Traits_>
252  : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
253 {
255 
260 
261  static_assert(std::is_same_v<LSEDataType, OaccDataType>);
262 
263  static constexpr index_t kHeadDimV = HeadDimV_;
264  static constexpr bool kIsGroupMode = kIsGroupMode_;
265 
266  using BaseType::kM0;
267  using BaseType::kN1;
268  using BaseType::NThreads;
269 
270  static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
271 
272  // attributes from traits
273  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
274  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
275  static constexpr bool kStoreLSE = Traits::kStoreLSE;
276  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
277  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
278  static constexpr index_t kMaxSplits = Traits::kMaxSplits;
279  static_assert(8 <= kMaxSplits);
280 
281  static constexpr index_t kNumWarps = 4;
282  static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
283 
284  static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
285  (kM0 * kMaxSplits) % get_warp_size() == 0);
286 };
287 
288 template <typename QDataType_,
289  typename KDataType_,
290  typename VDataType_,
291  index_t kM0_,
292  index_t kN0_,
293  index_t kK0_,
294  index_t kN1_,
295  bool kIsVLayoutRowMajor_,
296  RotaryEmbeddingEnum RotaryEnum_,
297  bool kIsPagedKV_,
298  typename Traits_>
300 {
305 
306  static constexpr index_t kBlockSize = 256;
307 
308  static constexpr index_t kM0 = kM0_;
309  static constexpr index_t kN0 = kN0_;
310  static constexpr index_t kK0 = kK0_;
311  static constexpr index_t kN1 = kN1_;
312 
313  using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
316 
317  static constexpr auto RotaryEnum = RotaryEnum_;
318  static constexpr bool kIsPagedKV = kIsPagedKV_;
319 
320  // attributes from traits
321  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
322  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
323  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
324  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
325  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
326 };
327 
328 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: block_fmha_pipeline_problem.hpp:105
static constexpr index_t kVectorSize
Definition: block_fmha_pipeline_problem.hpp:111
static constexpr auto kKVLookupTable
Definition: block_fmha_pipeline_problem.hpp:113
static constexpr auto kKVMemoryLayout
Definition: block_fmha_pipeline_problem.hpp:112
static constexpr bool kIsVectorizedLayout
Definition: block_fmha_pipeline_problem.hpp:114
static constexpr index_t kPageBlockSize
Definition: block_fmha_pipeline_problem.hpp:106
Definition: block_fmha_pipeline_problem.hpp:300
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:301
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:322
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:323
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:315
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:317
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:310
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:321
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:304
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:318
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:303
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:308
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:311
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:325
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:306
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:324
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:302
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:309
Definition: block_fmha_pipeline_problem.hpp:142
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:147
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:167
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:173
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:176
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:156
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:175
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:153
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:168
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:145
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:154
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:155
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:150
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:146
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:144
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:171
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:143
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:148
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:151
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:169
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:165
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:160
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:166
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:152
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:158
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:162
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:149
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:174
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:172
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:170
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:159
Definition: block_fmha_pipeline_problem.hpp:195
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:227
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:198
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:222
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:208
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:220
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:225
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:211
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:196
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:204
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:202
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:215
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:228
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:212
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:199
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:230
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:226
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:200
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:206
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:197
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:218
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:229
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:203
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:205
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:223
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:207
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:219
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:224
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:201
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:221
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:209
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:43
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:63
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:62
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:59
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:61
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr auto QScaleEnum
Definition: block_fmha_pipeline_problem.hpp:64
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:40
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:60
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:47
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:45
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:57
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:34
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:51
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:32
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:65
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:37
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:41
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:49
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:58
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:39
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr bool kUseTrLoad
Definition: block_fmha_pipeline_problem.hpp:52
static constexpr bool kHasSink
Definition: block_fmha_pipeline_problem.hpp:66
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:35
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:36
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:48
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:31
Definition: block_fmha_pipeline_problem.hpp:253
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:258
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:281
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:259
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:263
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:282
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:277
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:264
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:278
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:274
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:275
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:276
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:256
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:273
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:239
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:257
Definition: block_fmha_pipeline_problem.hpp:236
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:240
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:237
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:239
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17