Composable Kernel — fmha_fwd_pagedkv_kernel.hpp
(source listing of include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp)
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
22 namespace ck_tile {
23 
24 // TODO: This class is a variant of the existing FmhaFwdSplitKVKernel pipeline.
25 // Refactoring to extract shared logic is recommended as future work.
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33 
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
44 
46 
47  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
48  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
52  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
56  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
57  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
58  static constexpr bool kHasSink = FmhaPipeline::kHasSink;
59 
62  static constexpr bool kHasMask = FmhaMask::IsMasking;
63 
64  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
65 
66  // clang-format off
67  template <typename T> struct t2s;
68  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
69  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
70  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
71  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
72  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
73  // clang-format on
74 
75  CK_TILE_HOST static std::string GetName()
76  {
77  // sync with generate.py
78  // clang-format off
79  using bfs = typename FmhaPipeline::BlockFmhaShape;
80  using g0br = typename bfs::Gemm0BlockWarps;
81  using g1br = typename bfs::Gemm1BlockWarps;
82  using g0wt = typename bfs::Gemm0WarpTile;
83  using g1wt = typename bfs::Gemm1WarpTile;
84  #define _SS_ std::string
85  #define _TS_ std::to_string
86  auto pn = [&] () {
87  std::string n;
88  if (kPadSeqLenQ) n += "s";
89  if (kPadSeqLenK) n += "sk";
90  if (kPadHeadDimQ) n += "d";
91  if (kPadHeadDimV) n += "dv";
92  return n.empty() ? n : std::string("p") + n; }();
93  return
94  _SS_("fmha_fwd_pagedkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
95  "_" + (kIsGroupMode ? "group" : "batch") + "_"
96  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
97  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
98  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
99  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
100  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
101  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
102  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
103  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
104  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
105  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
106  #undef _SS_
107  #undef _TS_
108  // clang-format on
109  }
110 
111  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
112  // arg
114  {
115  };
116 
117  // kargs use aggregate initializer, so no constructor will provided
118  // use inheritance to minimize karg size
119  // user need to use MakeKargs() function to create kargs.
121  {
122  const void* q_ptr;
123  const void* k_ptr;
124  const void* v_ptr;
125  void* o_ptr;
126  const void* sink_ptr;
127 
132 
134  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
135  // if this param is larger than 1, indicate MQA/GQA case
137  float scale_s;
138 
143 
148  };
149 
151  {
153 
154  void init_logits_soft_cap(float logits_soft_cap_)
155  {
156  if(0 < logits_soft_cap_)
157  {
158  logits_soft_cap = logits_soft_cap_;
160  }
161  else
162  {
163  logits_soft_cap = 0.f;
164  logits_soft_cap_rcp = 0.f;
165  }
166  }
167 
170  };
171 
173  {
174  const void* bias_ptr = nullptr;
177  };
178 
180  {
182  };
183 
185  {
186  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
187  const void* alibi_slope_ptr;
188  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
189  };
190 
192  {
193  // ck_tile::index_t window_size_left, window_size_right;
196  };
197 
199  {
200  float scale_p;
201  float scale_o;
202  };
203 
205  {
206  void* lse_ptr = nullptr;
209  };
210 
212  {
214  };
215 
217  {
221  };
222 
224  {
225  bool is_gappy = false;
226  };
227 
229  {
231  };
232 
235  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
236  FmhaFwdBatchModeBiasKargs,
237  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
238  FmhaFwdAlibiKargs,
239  FmhaFwdEmptyKargs<0>>>,
240  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
241  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
242  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
243  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
244  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
245  {
247 
249  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
250  // single kcache page-block
251  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
252  // single vcache page-block
254  };
255 
258  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
259  FmhaFwdCommonBiasKargs,
260  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
261  FmhaFwdAlibiKargs,
262  FmhaFwdEmptyKargs<0>>>,
263  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
264  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
265  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
266  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>,
267  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, FmhaFwdEmptyKargs<5>>,
268  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
269  {
273 
274  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
275  // for single kcache page-block
276  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
277  // for single vcache page-block
278  };
279 
280  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
281 
283  {
287  };
288 
289  template <bool Cond = !kIsGroupMode>
290  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
291  MakeKargsImpl(const void* q_ptr,
292  const void* k_ptr,
293  const void* v_ptr,
294  const void* bias_ptr,
295  void* lse_ptr,
296  void* o_ptr,
297  ck_tile::index_t seqlen_q,
298  ck_tile::index_t seqlen_k,
299  const void* seqlen_k_ptr, // only used for (paged-) kvcache
300  ck_tile::index_t hdim_q,
301  ck_tile::index_t hdim_v,
302  ck_tile::index_t num_head_q,
303  ck_tile::index_t nhead_ratio_qk,
304  const void* block_table_ptr,
305  ck_tile::index_t batch_stride_block_table,
306  ck_tile::index_t page_block_size,
307  const void* cache_batch_idx,
308  float scale_s,
309  float scale_p,
310  float scale_o,
311  float logits_soft_cap,
312  ck_tile::index_t stride_q,
313  ck_tile::index_t stride_k,
314  ck_tile::index_t stride_v,
315  ck_tile::index_t stride_bias,
316  ck_tile::index_t stride_o,
317  ck_tile::index_t nhead_stride_q,
318  ck_tile::index_t nhead_stride_k,
319  ck_tile::index_t nhead_stride_v,
320  ck_tile::index_t nhead_stride_bias,
321  ck_tile::index_t nhead_stride_lse,
322  ck_tile::index_t nhead_stride_o,
323  ck_tile::index_t batch_stride_q,
324  ck_tile::index_t batch_stride_k,
325  ck_tile::index_t batch_stride_v,
326  ck_tile::index_t batch_stride_bias,
327  ck_tile::index_t batch_stride_lse,
328  ck_tile::index_t batch_stride_o,
329  ck_tile::index_t window_size_left,
330  ck_tile::index_t window_size_right,
331  ck_tile::index_t sink_size,
332  ck_tile::index_t mask_type,
333  const void* sink_ptr = nullptr)
334  {
335  Kargs kargs{{q_ptr,
336  k_ptr,
337  v_ptr,
338  o_ptr,
339  sink_ptr,
340  seqlen_q,
341  seqlen_k,
342  hdim_q,
343  hdim_v,
344  num_head_q,
345  nhead_ratio_qk,
346 #if CK_TILE_FMHA_FWD_FAST_EXP2
347  static_cast<float>(scale_s * ck_tile::log2e_v<>),
348 #else
349  scale_s,
350 #endif
351  stride_q,
352  stride_k,
353  stride_v,
354  stride_o,
355  nhead_stride_q,
356  nhead_stride_k,
357  nhead_stride_v,
358  nhead_stride_o}, // args for common karg
359  {}, // placeholder for bias
360  {}, // placeholder for mask
361  {}, // placeholder for lse
362  {}, // placeholder for fp8_static_quant args
363  {}, // placeholder for pagedkv
364  {}, // placeholder for logits_soft_cap
365  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
366  batch_stride_q,
367  batch_stride_k,
368  batch_stride_v,
369  batch_stride_o};
370 
372  {
373  kargs.bias_ptr = bias_ptr;
374  kargs.stride_bias = stride_bias;
375  kargs.nhead_stride_bias = nhead_stride_bias;
376  kargs.batch_stride_bias = batch_stride_bias;
377  }
378  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
379  {
380  kargs.alibi_slope_ptr = bias_ptr;
381  kargs.alibi_slope_stride = stride_bias;
382  }
383  if constexpr(kHasMask)
384  {
385  kargs.window_size_left = window_size_left;
386  kargs.window_size_right = window_size_right;
387  kargs.sink_size = sink_size;
388  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
389  }
390  if constexpr(kStoreLSE)
391  {
392  kargs.lse_ptr = lse_ptr;
393  kargs.nhead_stride_lse = nhead_stride_lse;
394  kargs.batch_stride_lse = batch_stride_lse;
395  }
396  if constexpr(kDoFp8StaticQuant)
397  {
398  kargs.scale_p = scale_p;
399  kargs.scale_o = scale_o;
400  }
401  if constexpr(kIsPagedKV)
402  {
403  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
404  kargs.batch_stride_block_table = batch_stride_block_table;
405  kargs.page_block_size = page_block_size;
406  }
407  else
408  {
409  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
410  }
411  if constexpr(kHasLogitsSoftCap)
412  {
413  kargs.init_logits_soft_cap(logits_soft_cap);
414  }
415 
416  return kargs;
417  }
418 
419  // std::variant<> can't take in a list initializer, overload for backward compatibility
420  template <bool Cond = !kIsGroupMode>
421  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
422  MakeKargs(const void* q_ptr,
423  const void* k_ptr,
424  const void* v_ptr,
425  const void* bias_ptr,
426  void* lse_ptr,
427  void* o_ptr,
428  ck_tile::index_t seqlen_q,
429  ck_tile::index_t seqlen_k,
430  const void* seqlen_k_ptr, // only used for (paged-) kvcache
431  ck_tile::index_t hdim_q,
432  ck_tile::index_t hdim_v,
433  ck_tile::index_t num_head_q,
434  ck_tile::index_t nhead_ratio_qk,
435  const void* block_table_ptr,
436  ck_tile::index_t batch_stride_block_table,
437  ck_tile::index_t page_block_size,
438  const void* cache_batch_idx,
439  float scale_s,
440  float scale_p,
441  float scale_o,
442  float logits_soft_cap,
443  ck_tile::index_t stride_q,
444  ck_tile::index_t stride_k,
445  ck_tile::index_t stride_v,
446  ck_tile::index_t stride_bias,
447  ck_tile::index_t stride_o,
448  ck_tile::index_t nhead_stride_q,
449  ck_tile::index_t nhead_stride_k,
450  ck_tile::index_t nhead_stride_v,
451  ck_tile::index_t nhead_stride_bias,
452  ck_tile::index_t nhead_stride_lse,
453  ck_tile::index_t nhead_stride_o,
454  ck_tile::index_t batch_stride_q,
455  ck_tile::index_t batch_stride_k,
456  ck_tile::index_t batch_stride_v,
457  ck_tile::index_t batch_stride_bias,
458  ck_tile::index_t batch_stride_lse,
459  ck_tile::index_t batch_stride_o,
460  ck_tile::index_t window_size_left,
461  ck_tile::index_t window_size_right,
462  ck_tile::index_t sink_size,
463  ck_tile::index_t mask_type,
464  const void* sink_ptr = nullptr)
465  {
466  return MakeKargsImpl(q_ptr,
467  k_ptr,
468  v_ptr,
469  bias_ptr,
470  lse_ptr,
471  o_ptr,
472  seqlen_q,
473  seqlen_k,
474  seqlen_k_ptr,
475  hdim_q,
476  hdim_v,
477  num_head_q,
478  nhead_ratio_qk,
479  block_table_ptr,
480  batch_stride_block_table,
481  page_block_size,
482  cache_batch_idx,
483  scale_s,
484  scale_p,
485  scale_o,
486  logits_soft_cap,
487  stride_q,
488  stride_k,
489  stride_v,
490  stride_bias,
491  stride_o,
492  nhead_stride_q,
493  nhead_stride_k,
494  nhead_stride_v,
495  nhead_stride_bias,
496  nhead_stride_lse,
497  nhead_stride_o,
498  batch_stride_q,
499  batch_stride_k,
500  batch_stride_v,
501  batch_stride_bias,
502  batch_stride_lse,
503  batch_stride_o,
504  window_size_left,
505  window_size_right,
506  sink_size,
507  mask_type,
508  sink_ptr);
509  }
510 
511  template <bool Cond = kIsGroupMode>
512  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
513  MakeKargsImpl(const void* q_ptr,
514  const void* k_ptr,
515  const void* v_ptr,
516  const void* bias_ptr,
517  void* lse_ptr,
518  void* o_ptr,
519  const void* seqstart_q_ptr,
520  const void* seqstart_k_ptr,
521  const void* seqlen_k_ptr,
522  ck_tile::index_t hdim_q,
523  ck_tile::index_t hdim_v,
524  ck_tile::index_t num_head_q,
525  ck_tile::index_t nhead_ratio_qk,
526  const void* block_table_ptr,
527  ck_tile::index_t batch_stride_block_table,
528  ck_tile::index_t page_block_size,
529  bool is_gappy,
530  float scale_s,
531  float scale_p,
532  float scale_o,
533  float logits_soft_cap,
534  ck_tile::index_t stride_q,
535  ck_tile::index_t stride_k,
536  ck_tile::index_t stride_v,
537  ck_tile::index_t stride_bias,
538  ck_tile::index_t stride_o,
539  ck_tile::index_t nhead_stride_q,
540  ck_tile::index_t nhead_stride_k,
541  ck_tile::index_t nhead_stride_v,
542  ck_tile::index_t nhead_stride_bias,
543  ck_tile::index_t nhead_stride_lse,
544  ck_tile::index_t nhead_stride_o,
545  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
546  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
547  ck_tile::index_t window_size_left,
548  ck_tile::index_t window_size_right,
549  ck_tile::index_t sink_size,
550  ck_tile::index_t mask_type,
551  ck_tile::index_t min_seqlen_q,
552  const void* sink_ptr = nullptr)
553  {
554  Kargs kargs{{q_ptr,
555  k_ptr,
556  v_ptr,
557  o_ptr,
558  sink_ptr,
559  -1, // seqlen will be updated by another pointer
560  -1, //
561  hdim_q,
562  hdim_v,
563  num_head_q,
564  nhead_ratio_qk,
565 #if CK_TILE_FMHA_FWD_FAST_EXP2
566  static_cast<float>(scale_s * ck_tile::log2e_v<>),
567 #else
568  scale_s,
569 #endif
570  stride_q,
571  stride_k,
572  stride_v,
573  stride_o,
574  nhead_stride_q,
575  nhead_stride_k,
576  nhead_stride_v,
577  nhead_stride_o}, // args for common karg
578  {}, // placeholder for bias
579  {}, // placeholder for mask
580  {}, // placeholder for lse
581  {}, // placeholder for fp8_static_quant args
582  {}, // placeholder for logits_soft_cap
583  {}, // placeholder for pagdkv
584  {}, // placeholder for min_seqlen_q
585  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
586  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
587  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
588  batch_stride_k,
589  batch_stride_v};
590 
592  {
593  kargs.bias_ptr = bias_ptr;
594  kargs.stride_bias = stride_bias;
595  kargs.nhead_stride_bias = nhead_stride_bias;
596  }
597  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
598  {
599  kargs.alibi_slope_ptr = bias_ptr;
600  kargs.alibi_slope_stride = stride_bias;
601  }
602  if constexpr(kHasMask)
603  {
604  kargs.window_size_left = window_size_left;
605  kargs.window_size_right = window_size_right;
606  kargs.sink_size = sink_size;
607  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
608  }
609  if constexpr(kStoreLSE)
610  {
611  kargs.lse_ptr = lse_ptr;
612  kargs.nhead_stride_lse = nhead_stride_lse;
613  }
614  if constexpr(kDoFp8StaticQuant)
615  {
616  kargs.scale_p = scale_p;
617  kargs.scale_o = scale_o;
618  }
619  if constexpr(kHasLogitsSoftCap)
620  {
621  kargs.init_logits_soft_cap(logits_soft_cap);
622  }
623  if constexpr(kIsPagedKV)
624  {
625  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
626  kargs.batch_stride_block_table = batch_stride_block_table;
627  kargs.page_block_size = page_block_size;
628  kargs.is_gappy = is_gappy;
629  }
630  if constexpr(kSkipMinSeqlenQ)
631  {
632  kargs.min_seqlen_q = min_seqlen_q;
633  }
634 
635  return kargs;
636  }
637 
638  // std::variant<> can't take in a list initializer, overload for backward compatibility
639  template <bool Cond = kIsGroupMode>
640  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
641  MakeKargs(const void* q_ptr,
642  const void* k_ptr,
643  const void* v_ptr,
644  const void* bias_ptr,
645  void* lse_ptr,
646  void* o_ptr,
647  const void* seqstart_q_ptr,
648  const void* seqstart_k_ptr,
649  const void* seqlen_k_ptr,
650  ck_tile::index_t hdim_q,
651  ck_tile::index_t hdim_v,
652  ck_tile::index_t num_head_q,
653  ck_tile::index_t nhead_ratio_qk,
654  const void* block_table_ptr,
655  ck_tile::index_t batch_stride_block_table,
656  ck_tile::index_t page_block_size,
657  bool is_gappy,
658  float scale_s,
659  float scale_p,
660  float scale_o,
661  float logits_soft_cap,
662  ck_tile::index_t stride_q,
663  ck_tile::index_t stride_k,
664  ck_tile::index_t stride_v,
665  ck_tile::index_t stride_bias,
666  ck_tile::index_t stride_o,
667  ck_tile::index_t nhead_stride_q,
668  ck_tile::index_t nhead_stride_k,
669  ck_tile::index_t nhead_stride_v,
670  ck_tile::index_t nhead_stride_bias,
671  ck_tile::index_t nhead_stride_lse,
672  ck_tile::index_t nhead_stride_o,
673  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
674  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
675  ck_tile::index_t window_size_left,
676  ck_tile::index_t window_size_right,
677  ck_tile::index_t sink_size,
678  ck_tile::index_t mask_type,
679  ck_tile::index_t min_seqlen_q,
680  const void* sink_ptr = nullptr)
681  {
682  return MakeKargsImpl(q_ptr,
683  k_ptr,
684  v_ptr,
685  bias_ptr,
686  lse_ptr,
687  o_ptr,
688  seqstart_q_ptr,
689  seqstart_k_ptr,
690  seqlen_k_ptr,
691  hdim_q,
692  hdim_v,
693  num_head_q,
694  nhead_ratio_qk,
695  block_table_ptr,
696  batch_stride_block_table,
697  page_block_size,
698  is_gappy,
699  scale_s,
700  scale_p,
701  scale_o,
702  logits_soft_cap,
703  stride_q,
704  stride_k,
705  stride_v,
706  stride_bias,
707  stride_o,
708  nhead_stride_q,
709  nhead_stride_k,
710  nhead_stride_v,
711  nhead_stride_bias,
712  nhead_stride_lse,
713  nhead_stride_o,
714  batch_stride_k,
715  batch_stride_v,
716  window_size_left,
717  window_size_right,
718  sink_size,
719  mask_type,
720  min_seqlen_q,
721  sink_ptr);
722  }
723 
724  CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
725  {
726  static bool dummy = [&]() {
727  std::cout << std::endl;
728 
729  std::cout << " q_ptr: " << kargs.q_ptr << " k_ptr:" << kargs.k_ptr
730  << " v_ptr: " << kargs.v_ptr << " o_ptr:" << kargs.o_ptr
731  << " hdim_q: " << kargs.hdim_q << " hdim_v: " << kargs.hdim_v
732  << " num_head_q:" << kargs.num_head_q
733  << " nhead_ratio_qk: " << kargs.nhead_ratio_qk << " scale_s:" << kargs.scale_s
734  << " stride_q:" << kargs.stride_q << " stride_k:" << kargs.stride_k
735  << " stride_v:" << kargs.stride_v << " stride_o:" << kargs.stride_o
736  << " nhead_stride_q: " << kargs.nhead_stride_q
737  << " nhead_stride_k: " << kargs.nhead_stride_k
738  << " nhead_stride_v:" << kargs.nhead_stride_v
739  << " nhead_stride_o: " << kargs.nhead_stride_o;
740  if constexpr(!kIsGroupMode)
741  {
742  std::cout << " batch_stride_q:" << kargs.batch_stride_q;
743  }
744  std::cout << " batch_stride_k:" << kargs.batch_stride_k
745  << " batch_stride_v:" << kargs.batch_stride_v;
746 
747  if constexpr(kIsGroupMode)
748  {
749  if constexpr(kSkipMinSeqlenQ)
750  {
751  std::cout << " min_seqlen_q: " << kargs.min_seqlen_q;
752  }
753 
754  std::cout << " seqstart_q_ptr:" << kargs.seqstart_q_ptr
755  << " seqstart_k_ptr: " << kargs.seqstart_k_ptr
756  << " seqlen_k_ptr:" << kargs.seqlen_k_ptr;
757  if(kargs.seqlen_k_ptr != nullptr)
758  {
759  std::cout << "{";
760  for(int i_batch = 0; i_batch < num_batches; i_batch++)
761  std::cout << kargs.seqlen_k_ptr[i_batch] << ",";
762  std::cout << "}";
763  }
764  }
765  if constexpr(kHasMask)
766  {
767  std::cout << " window_size_left: " << kargs.window_size_left
768  << " window_size_right:" << kargs.window_size_right
769  << " mask_type: " << static_cast<int>(kargs.mask_type);
770  }
771 
772  if constexpr(kIsPagedKV)
773  {
774  std::cout << " block_table_ptr: " << kargs.block_table_ptr
775  << " batch_stride_block_table:" << kargs.batch_stride_block_table
776  << " page_block_size: " << kargs.page_block_size;
777 
778  std::cout << "table value: [";
779  for(int b = 0; b < num_batches; b++)
780  {
781  std::cout << "[ ";
782  for(int i = 0; i < kargs.batch_stride_block_table; i++)
783  {
784  std::cout << kargs.block_table_ptr[b * kargs.batch_stride_block_table + i]
785  << ",";
786  }
787  std::cout << " ]";
788  }
789  std::cout << " ]";
790  }
791  std::cout << std::endl;
792  return true;
793  }();
794  (void)dummy;
795  }
796  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
797  ck_tile::index_t nhead_,
798  ck_tile::index_t seqlen_q_,
799  ck_tile::index_t hdim_v_,
800  bool has_padded_seqlen_k)
801  {
802  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
803  if(has_padded_seqlen_k)
804  {
805  // TODO: this may need tuning
806  return dim3(nhead_,
807  batch_size_,
808  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
809  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
810  }
811  else
812  {
813  // TODO: this may need tuning
814  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
815  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
816  nhead_,
817  batch_size_);
818  }
819  }
820 
821  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
822  {
823  bool has_padded_seqlen_k = false;
824 
825  if constexpr(kIsGroupMode)
826  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
827 
828  if(has_padded_seqlen_k)
829  {
830  // const index_t num_tile_m0 = seqlen_q / kM0;
831  const index_t num_tile_n1 =
832  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
833 
834  const index_t i_block = blockIdx.z;
835  const index_t i_nhead = blockIdx.x;
836  const index_t i_batch = blockIdx.y;
837 
838  const auto f = [](index_t dividend, index_t divisor) {
839  index_t quotient = dividend / divisor;
840  index_t modulus = dividend - quotient * divisor;
841  return ck_tile::make_tuple(quotient, modulus);
842  };
843 
844  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
845 
846  if constexpr(kHasMask)
847  {
848  // assume that num_tile_n1 is always 1
849  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
850  }
851  else
852  {
853  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
854  }
855  }
856  else
857  {
858  // const index_t num_tile_m0 = seqlen_q / kM0;
859  const index_t num_tile_n1 =
860  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
861 
862  const index_t i_block = blockIdx.x;
863  const index_t i_nhead = blockIdx.y;
864  const index_t i_batch = blockIdx.z;
865 
866  const auto f = [](index_t dividend, index_t divisor) {
867  index_t quotient = dividend / divisor;
868  index_t modulus = dividend - quotient * divisor;
869  return ck_tile::make_tuple(quotient, modulus);
870  };
871 
872  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
873 
874  if constexpr(kHasMask)
875  {
876  // assume that num_tile_n1 is always 1
877  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
878  }
879  else
880  {
881  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
882  }
883  }
884  }
885 
886  CK_TILE_HOST static dim3 BlockSize()
887  {
888  if(is_wave32())
889  {
890  return dim3(kBlockSize / 2);
891  }
892  else
893  {
894  return dim3(kBlockSize);
895  }
896  }
897 
899  {
900  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
901  }
902 
903  CK_TILE_DEVICE void operator()(Kargs kargs) const
904  {
905  // allocate LDS
906  __shared__ char smem_ptr[GetSmemSize()];
907 
908  // divide problem
909  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
910  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
911  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
912 
913  long_index_t batch_offset_q = 0;
914  long_index_t batch_offset_k = 0;
915  long_index_t batch_offset_v = 0;
916  long_index_t batch_offset_bias = 0;
917  long_index_t batch_offset_lse = 0;
918  long_index_t batch_offset_o = 0;
919  index_t kv_l2p_offset = 0;
920  const float sink_value =
921  kargs.sink_ptr != nullptr
922  ? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
924 
925  if constexpr(kIsGroupMode)
926  {
927  // get starting offset for each batch
928  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
929  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
930 
931  batch_offset_q = query_start * kargs.stride_q;
932  batch_offset_k = key_start * kargs.stride_k;
933  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
934  {
935  batch_offset_v = key_start * kargs.stride_v;
936  }
937  else
938  {
939  batch_offset_v = key_start;
940  }
942  {
943  batch_offset_bias = query_start * kargs.stride_bias;
944  }
945  if constexpr(kStoreLSE)
946  {
947  batch_offset_lse = query_start;
948  }
949 
950  batch_offset_o = query_start * kargs.stride_o;
951 
952  // get real # queries & # keys under group mode
953  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
954  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
955 
956  if constexpr(kSkipMinSeqlenQ)
957  {
958  if(kargs.seqlen_q <= kargs.min_seqlen_q)
959  {
960  return;
961  }
962  }
963 
964  // # of required blocks is different in each groups, terminate unnecessary blocks
965  // earlier
966  if(kargs.seqlen_q <= i_m0)
967  {
968  return;
969  }
970 
971  if(kargs.seqlen_k_ptr != nullptr)
972  {
973  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
974  }
975  else
976  {
977  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
978  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
979  }
980 
981  if constexpr(kIsPagedKV)
982  {
983  if(kargs.is_gappy)
984  {
985  // seqstart_k_ptr has different meaning in this case
986  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
987  }
988  }
989  }
990  else
991  {
992  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
993  if constexpr(kIsPagedKV)
994  {
995  return i_batch_;
996  }
997  else
998  {
999  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
1000  : i_batch_);
1001  }
1002  }();
1003 
1004  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1005  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
1006  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
1008  {
1009  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1010  }
1011  if constexpr(kStoreLSE)
1012  {
1013  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1014  }
1015 
1016  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1017 
1018  if(kargs.seqlen_k_ptr != nullptr)
1019  {
1020  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1021  }
1022  }
1023 
1024  // for simplicity, batch stride we just modify the pointer
1025  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1026  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1027  batch_offset_q;
1028  const KDataType* k_ptr =
1029  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1030  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1031  batch_offset_k;
1032  const VDataType* v_ptr =
1033  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1034  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1035  batch_offset_v;
1036  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1037  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1038  batch_offset_o;
1039 
1040  // Q/K/V DRAM and DRAM window
1041  const auto q_dram = [&]() {
1042  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1043  q_ptr,
1044  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1045  make_tuple(kargs.stride_q, 1),
1047  number<1>{});
1048  if constexpr(FmhaPipeline::kQLoadOnce)
1049  {
1050  return pad_tensor_view(
1051  q_dram_naive,
1054  }
1055  else
1056  {
1057  return pad_tensor_view(
1058  q_dram_naive,
1061  }
1062  }();
1063 
1064  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1065  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1066  data, // will update this pointer if using paged-kvcache
1067  make_tuple(height, kargs.hdim_q),
1068  make_tuple(kargs.stride_k, 1),
1070  number<1>{});
1071 
1072  return pad_tensor_view(
1073  k_dram_naive,
1076  };
1077  const auto k_dram = [&]() {
1078  if constexpr(kIsPagedKV)
1079  {
1080  return make_k_dram(nullptr, kargs.page_block_size);
1081  }
1082  else
1083  {
1084  return make_k_dram(k_ptr, kargs.seqlen_k);
1085  }
1086  }();
1087 
1088  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1089  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1090  {
1091  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1092  data, // will update this pointer if using paged-kvcache
1093  make_tuple(length, kargs.hdim_v),
1094  make_tuple(kargs.stride_v, 1),
1096  number<1>{});
1097 
1098  const auto v_dram_transposed =
1099  transform_tensor_view(v_dram_naive,
1101  make_pass_through_transform(length)),
1104 
1105  return pad_tensor_view(
1106  v_dram_transposed,
1109  }
1110  else
1111  {
1112  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1113  data, // will update this pointer if using paged-kvcache
1114  make_tuple(kargs.hdim_v, length),
1115  make_tuple(kargs.stride_v, 1),
1117  number<1>{});
1118 
1119  return pad_tensor_view(
1120  v_dram_naive,
1123  }
1124  };
1125  const auto v_dram = [&]() {
1126  if constexpr(kIsPagedKV)
1127  {
1128  return make_v_dram(nullptr, kargs.page_block_size);
1129  }
1130  else
1131  {
1132  return make_v_dram(v_ptr, kargs.seqlen_k);
1133  }
1134  }();
1135 
1136  auto q_dram_window = make_tile_window(
1137  q_dram,
1138  [&]() {
1139  if constexpr(FmhaPipeline::kQLoadOnce)
1142  else
1144  }(),
1145  {i_m0, 0});
1146 
1147  auto k_page_block_navigator =
1148  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1149  if constexpr(kIsPagedKV)
1150  {
1151  const auto* block_indices =
1152  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1153  i_batch_ * kargs.batch_stride_block_table;
1154  const index_t num_blocks =
1155  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1156 
1157  const long_index_t fixed_offset =
1158  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_k;
1159 
1160  return make_page_block_navigator<const KDataType, 0>(
1161  kargs.k_ptr,
1162  kargs.batch_stride_k, // kcache page-block stride/size
1163  fixed_offset,
1164  block_indices,
1165  num_blocks,
1166  kargs.page_block_size,
1167  k_dram,
1168  make_k_dram(nullptr,
1169  (kv_l2p_offset + kargs.seqlen_k) -
1170  (num_blocks - 1) * kargs.page_block_size));
1171  }
1172  else
1173  {
1174  return make_page_block_navigator(k_dram);
1175  }
1176  }();
1177 
1178  auto v_page_block_navigator =
1179  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1180  if constexpr(kIsPagedKV)
1181  {
1182  const auto* block_indices =
1183  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1184  i_batch_ * kargs.batch_stride_block_table;
1185  const index_t num_blocks =
1186  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1187 
1188  const long_index_t fixed_offset =
1189  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_v;
1190 
1191  return make_page_block_navigator<const VDataType, 1>(
1192  kargs.v_ptr,
1193  kargs.batch_stride_v, // vcache page-block stride/size
1194  fixed_offset,
1195  block_indices,
1196  num_blocks,
1197  kargs.page_block_size,
1198  v_dram,
1199  make_v_dram(nullptr,
1200  (kv_l2p_offset + kargs.seqlen_k) -
1201  (num_blocks - 1) * kargs.page_block_size));
1202  }
1203  else
1204  {
1205  return make_page_block_navigator(v_dram);
1206  }
1207  }();
1208 
1209  auto k_dram_window_lengths =
1211  auto v_dram_window_lengths =
1213 
1216  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1217  constexpr auto bias_dram_window_lengths =
1220  {
1221  const BiasDataType* bias_ptr =
1222  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1223  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1224  batch_offset_bias;
1225 
1226  const auto bias_dram = [&]() {
1227  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1228  bias_ptr,
1229  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1230  make_tuple(kargs.stride_bias, 1),
1232  number<1>{});
1233 
1234  return pad_tensor_view(bias_dram_naive,
1235  bias_dram_window_lengths,
1237  }();
1238 
1239  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1240  }
1241  else
1242  {
1243  return make_null_tile_window(bias_dram_window_lengths);
1244  }
1245  }();
1246 
1247  // lse
1248  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1249  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1250  if constexpr(kStoreLSE)
1251  {
1252  LSEDataType* lse_ptr =
1253  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1254  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1255 
1256  const auto lse_dram = [&]() {
1257  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1258  lse_ptr,
1259  make_tuple(kargs.seqlen_q),
1260  make_tuple(1),
1261  number<1>{},
1262  number<1>{});
1263 
1264  return pad_tensor_view(
1265  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1266  }();
1267 
1268  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1269  }
1270  else
1271  {
1272  return make_null_tile_window(lse_dram_window_lengths);
1273  }
1274  }();
1275 
1276  FmhaMask mask = [&]() {
1277  if constexpr(kHasMask)
1278  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1279  kargs.window_size_left,
1280  kargs.window_size_right,
1281  kargs.sink_size,
1282  kargs.seqlen_q,
1283  kargs.seqlen_k,
1285  else
1286  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1287  }();
1288 
1289  // WA i_batch capture structure binding before c++20
1290  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1291  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1292  {
1293  // data loading, shared by entire wg
1294  // TODO: how to use s_read?
1295  SaccDataType slope =
1296  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1297  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1298 #if CK_TILE_FMHA_FWD_FAST_EXP2
1299  slope *= ck_tile::log2e_v<>;
1300 #endif
1301  if constexpr(kHasMask)
1302  {
1303  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1304  kargs.window_size_left,
1305  kargs.window_size_right,
1306  kargs.seqlen_q,
1307  kargs.seqlen_k,
1308  kargs.mask_type);
1309  }
1310  else
1311  {
1313  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1314  }
1315  }
1316  else
1317  {
1319  }
1320  }();
1321 
1322  AttentionVariant variant;
1323  const auto variant_params = [&] {
1324  if constexpr(kHasLogitsSoftCap)
1325  {
1327  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1328  }
1329  else
1330  {
1331  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1332  }
1333  }();
1334 
1335  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1336 
1337  auto o_acc_tile = [&]() {
1338  if constexpr(kDoFp8StaticQuant)
1339  {
1340  return FmhaPipeline{}(q_dram_window,
1341  identity{}, // q_element_func
1342  k_dram_window_lengths,
1343  k_page_block_navigator,
1344  identity{}, // k_element_func
1345  v_dram_window_lengths,
1346  v_page_block_navigator,
1347  identity{}, // v_element_func
1348  bias_dram_window,
1349  identity{}, // bias_element_func
1350  lse_dram_window,
1351  identity{}, // lse_element_func
1352  identity{}, // s_acc_element_func
1353  scales<remove_cvref_t<decltype(kargs.scale_p)>>{
1354  kargs.scale_p}, // p_compute_element_func
1356  scales<remove_cvref_t<decltype(kargs.scale_o)>>{
1357  kargs.scale_o}), // o_acc_element_func
1358  mask,
1359  position_encoding,
1360  kargs.scale_s,
1361  variant,
1362  variant_params,
1363  block_indices,
1364  kv_l2p_offset,
1365  smem_ptr,
1366  sink_value);
1367  }
1368  else
1369  {
1370  return FmhaPipeline{}(q_dram_window,
1371  k_dram_window_lengths,
1372  k_page_block_navigator,
1373  v_dram_window_lengths,
1374  v_page_block_navigator,
1375  bias_dram_window,
1376  lse_dram_window,
1377  mask,
1378  position_encoding,
1379  kargs.scale_s,
1380  variant,
1381  variant_params,
1382  block_indices,
1383  kv_l2p_offset,
1384  smem_ptr,
1385  sink_value);
1386  }
1387  }();
1388 
1389  // O DRAM and O DRAM window
1390  auto o_dram = [&]() {
1391  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1392  o_ptr,
1393  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1394  make_tuple(kargs.stride_o, 1),
1396  number<1>{});
1397  return pad_tensor_view(
1398  o_dram_naive,
1401  }();
1402 
1403  auto o_dram_window =
1404  make_tile_window(o_dram,
1406  {i_m0, i_n1});
1407 
1408  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1409  }
1410 };
1411 
1412 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_composes(Ts &&... ts)
Definition: unary_element_function.hpp:51
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_pagedkv_kernel.hpp:283
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:285
ck_tile::index_t batch_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:284
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:286
Definition: fmha_fwd_pagedkv_kernel.hpp:229
const int32_t * cache_batch_idx
Definition: fmha_fwd_pagedkv_kernel.hpp:230
Definition: fmha_fwd_pagedkv_kernel.hpp:217
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_pagedkv_kernel.hpp:219
const int32_t * block_table_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:218
ck_tile::index_t page_block_size
Definition: fmha_fwd_pagedkv_kernel.hpp:220
Definition: fmha_fwd_pagedkv_kernel.hpp:185
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_pagedkv_kernel.hpp:188
const void * alibi_slope_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:187
Definition: fmha_fwd_pagedkv_kernel.hpp:180
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:181
Definition: fmha_fwd_pagedkv_kernel.hpp:245
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:249
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:248
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:246
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:253
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:251
Definition: fmha_fwd_pagedkv_kernel.hpp:173
ck_tile::index_t stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:175
const void * bias_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:174
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_pagedkv_kernel.hpp:176
Definition: fmha_fwd_pagedkv_kernel.hpp:121
ck_tile::index_t hdim_v
Definition: fmha_fwd_pagedkv_kernel.hpp:131
ck_tile::index_t seqlen_q
Definition: fmha_fwd_pagedkv_kernel.hpp:128
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:145
ck_tile::index_t stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:142
const void * k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:123
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:144
ck_tile::index_t stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:141
float scale_s
Definition: fmha_fwd_pagedkv_kernel.hpp:137
ck_tile::index_t stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:140
const void * v_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:124
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:146
const void * sink_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:126
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_pagedkv_kernel.hpp:147
ck_tile::index_t hdim_q
Definition: fmha_fwd_pagedkv_kernel.hpp:130
ck_tile::index_t seqlen_k
Definition: fmha_fwd_pagedkv_kernel.hpp:129
const void * q_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:122
ck_tile::index_t num_head_q
Definition: fmha_fwd_pagedkv_kernel.hpp:133
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_pagedkv_kernel.hpp:136
void * o_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:125
ck_tile::index_t stride_q
Definition: fmha_fwd_pagedkv_kernel.hpp:139
Definition: fmha_fwd_pagedkv_kernel.hpp:205
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_pagedkv_kernel.hpp:208
void * lse_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:206
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_pagedkv_kernel.hpp:207
Definition: fmha_fwd_pagedkv_kernel.hpp:114
Definition: fmha_fwd_pagedkv_kernel.hpp:199
float scale_p
Definition: fmha_fwd_pagedkv_kernel.hpp:200
float scale_o
Definition: fmha_fwd_pagedkv_kernel.hpp:201
Definition: fmha_fwd_pagedkv_kernel.hpp:269
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:272
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:271
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_pagedkv_kernel.hpp:274
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_pagedkv_kernel.hpp:276
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_pagedkv_kernel.hpp:270
Definition: fmha_fwd_pagedkv_kernel.hpp:151
float logits_soft_cap
Definition: fmha_fwd_pagedkv_kernel.hpp:168
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_pagedkv_kernel.hpp:154
float logits_soft_cap_rcp
Definition: fmha_fwd_pagedkv_kernel.hpp:169
Definition: fmha_fwd_pagedkv_kernel.hpp:192
ck_tile::index_t window_size_left
Definition: fmha_fwd_pagedkv_kernel.hpp:194
ck_tile::index_t sink_size
Definition: fmha_fwd_pagedkv_kernel.hpp:194
ck_tile::index_t window_size_right
Definition: fmha_fwd_pagedkv_kernel.hpp:194
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_pagedkv_kernel.hpp:195
Definition: fmha_fwd_pagedkv_kernel.hpp:212
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_pagedkv_kernel.hpp:213
Definition: fmha_fwd_pagedkv_kernel.hpp:224
bool is_gappy
Definition: fmha_fwd_pagedkv_kernel.hpp:225
Definition: fmha_fwd_pagedkv_kernel.hpp:67
Definition: fmha_fwd_pagedkv_kernel.hpp:28
static constexpr bool kHasSink
Definition: fmha_fwd_pagedkv_kernel.hpp:58
static constexpr bool kIsGroupMode
Definition: fmha_fwd_pagedkv_kernel.hpp:47
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_pagedkv_kernel.hpp:52
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_pagedkv_kernel.hpp:30
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_pagedkv_kernel.hpp:32
static constexpr bool kStoreLSE
Definition: fmha_fwd_pagedkv_kernel.hpp:54
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_pagedkv_kernel.hpp:75
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_pagedkv_kernel.hpp:49
static constexpr bool kIsPagedKV
Definition: fmha_fwd_pagedkv_kernel.hpp:57
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_pagedkv_kernel.hpp:56
static CK_TILE_HOST void PrintParameters(const Kargs &kargs, int num_batches)
Definition: fmha_fwd_pagedkv_kernel.hpp:724
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_pagedkv_kernel.hpp:31
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_pagedkv_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:41
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, const void *sink_ptr=nullptr)
Definition: fmha_fwd_pagedkv_kernel.hpp:513
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_pagedkv_kernel.hpp:886
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, const void *sink_ptr=nullptr)
Definition: fmha_fwd_pagedkv_kernel.hpp:291
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:39
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_pagedkv_kernel.hpp:821
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_pagedkv_kernel.hpp:35
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_pagedkv_kernel.hpp:64
static constexpr bool kHasMask
Definition: fmha_fwd_pagedkv_kernel.hpp:62
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k)
Definition: fmha_fwd_pagedkv_kernel.hpp:796
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_pagedkv_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_pagedkv_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:40
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_pagedkv_kernel.hpp:50
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, const void *sink_ptr=nullptr)
Definition: fmha_fwd_pagedkv_kernel.hpp:641
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_pagedkv_kernel.hpp:45
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_pagedkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:43
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, const void *sink_ptr=nullptr)
Definition: fmha_fwd_pagedkv_kernel.hpp:422
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_pagedkv_kernel.hpp:61
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_pagedkv_kernel.hpp:898
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_pagedkv_kernel.hpp:903
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_pagedkv_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_pagedkv_kernel.hpp:37
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_pagedkv_kernel.hpp:55
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_pagedkv_kernel.hpp:280
static constexpr auto BiasEnum
Definition: fmha_fwd_pagedkv_kernel.hpp:53
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: unary_element_function.hpp:58
Definition: math.hpp:28
Definition: sequence.hpp:49