/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File
fmha_fwd_splitkv_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 
14 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
15 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
16 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
17 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
18 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
19 
20 namespace ck_tile {
21 
22 template <typename FmhaPipeline_, typename EpiloguePipeline_>
 // NOTE(review): this listing was extracted from generated documentation; original source
 // lines 23, 25-26 and 33-42 are missing here. They presumably carried the kernel struct
 // name and the type aliases used below (FmhaPipeline, EpiloguePipeline, QDataType,
 // KDataType, VDataType, BiasDataType, LSEDataType, ODataType, SaccDataType, VLayout,
 // FmhaMask, AttentionVariant, BlockIndices) -- confirm against the real header.
24 {
 // compile-time launch configuration, forwarded from the pipeline's problem description
27  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
28  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
29 
30  static_assert(kBlockPerCu > 0);
31  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
32 
41 
43 
 // feature flags baked into this kernel instantiation (also encoded into GetName())
44  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
45  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
46  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
47  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
48  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
49  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
50  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
51  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
52  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
53  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
54  static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
55  static constexpr bool kMergeNumHeadGroupsSeqLenQ =
56  FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
59  static constexpr bool kHasMask = FmhaMask::IsMasking;
60 
 // NOTE(review): original line 62 (middle of this static_assert condition) is missing
 // from the extraction; the intent appears to be that kMergeNumHeadGroupsSeqLenQ is
 // incompatible with masking (and possibly other features) -- verify against the header.
61  static_assert(!kMergeNumHeadGroupsSeqLenQ ||
63  !kHasMask));
64 
65  // clang-format off
 // t2s: maps an element type to the short dtype tag used in the kernel name (GetName()).
66  template <typename T> struct t2s;
67  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
68  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
69  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
70  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
71  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
72  // clang-format on
73 
74  CK_TILE_HOST static std::string GetName()
75  {
76  // sync with generate.py
77  // clang-format off
78  using bfs = typename FmhaPipeline::BlockFmhaShape;
79  using g0br = typename bfs::Gemm0BlockWarps;
80  using g1br = typename bfs::Gemm1BlockWarps;
81  using g0wt = typename bfs::Gemm0WarpTile;
82  using g1wt = typename bfs::Gemm1WarpTile;
83  #define _SS_ std::string
84  #define _TS_ std::to_string
85  auto pn = [&] () {
86  std::string n;
87  if (kPadSeqLenQ) n += "s";
88  if (kPadSeqLenK) n += "sk";
89  if (kPadHeadDimQ) n += "d";
90  if (kPadHeadDimV) n += "dv";
91  return n.empty() ? n : std::string("p") + n; }();
92  return
93  _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
94  "_" + (kIsGroupMode ? "group" : "batch") + "_"
95  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
96  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
97  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
98  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
99  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
100  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
101  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
102  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
103  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
104  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
105  (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
106  #undef _SS_
107  #undef _TS_
108  // clang-format on
109  }
110 
111  template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
112  // arg
 // EmptyKargs<I>: zero-size placeholder base used in the std::conditional_t chains below
 // when a feature (bias/mask/quant/...) is compiled out; the index I keeps each base unique.
113  struct EmptyKargs
114  {
115  };
116 
117  // kargs use aggregate initializer, so no constructor will provided
118  // use inheritance to minimize karg size
119  // user need to use MakeKargs() function to create kargs.
120  struct CommonKargs
121  {
122  const void* q_ptr;
123  const void* k_ptr;
124  const void* v_ptr;
125  void* lse_acc_ptr;
126  void* o_acc_ptr;
127  const void* sink_ptr;
128 
 // NOTE(review): several member declarations (original lines 129-156) are missing from
 // this extraction. Judging from the aggregate initializer in MakeKargs(), they are:
 // batch, seqlen_q, seqlen_k, hdim_q, hdim_v, num_head_q, nhead_ratio_qk, num_splits,
 // then (after scale_s) stride_q/k/v/o_acc, nhead_stride_q/k/v/lse_acc/o_acc and
 // split_stride_lse_acc/o_acc -- confirm against the real header.
130 
135 
137  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
138  // if this param is larger than 1, indicate MQA/GQA case
141 
142  float scale_s;
143 
148 
154 
157  };
158 
 // NOTE(review): the struct header (original line 159, presumably
 // 'struct LogitsSoftCapKargs') is missing from this extraction.
160  {
161  LogitsSoftCapKargs() = default;
162 
 // Store the soft-cap value; a non-positive input disables capping (both fields zeroed).
163  void init_logits_soft_cap(float logits_soft_cap_)
164  {
165  if(0 < logits_soft_cap_)
166  {
167  logits_soft_cap = logits_soft_cap_;
 // NOTE(review): original line 168 is missing; it presumably set
 // logits_soft_cap_rcp (likely 1.f / logits_soft_cap_) -- confirm.
169  }
170  else
171  {
172  logits_soft_cap = 0.f;
173  logits_soft_cap_rcp = 0.f;
174  }
175  }
176 
 // NOTE(review): member declarations (original lines 177-178, presumably
 // 'float logits_soft_cap; float logits_soft_cap_rcp;') are missing here.
179  };
180 
 // NOTE(review): the struct header (original line 181) is missing; based on the
 // GroupModeKargs base list below this is presumably 'struct CommonBiasKargs'.
182  {
183  const void* bias_ptr = nullptr;
 // NOTE(review): original lines 184-185 missing; presumably stride_bias and
 // nhead_stride_bias members assigned in MakeKargs() -- confirm.
186  };
187 
 // NOTE(review): header (original line 188) missing; presumably
 // 'struct BatchModeBiasKargs : CommonBiasKargs' adding batch_stride_bias.
189  {
191  };
192 
193  struct AlibiKargs
194  {
195  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
196  const void* alibi_slope_ptr;
197  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
198  };
199 
200  struct MaskKargs
201  {
202  // ck_tile::index_t window_size_left, window_size_right;
 // NOTE(review): member lines (original 203-204) missing; MakeKargs() assigns
 // window_size_left, window_size_right, sink_size and mask_type here -- confirm.
205  };
206 
 // NOTE(review): header (original line 207) missing; presumably
 // 'struct Fp8StaticQuantKargs'.
208  {
209  float scale_p;
210  };
211 
 // NOTE(review): header (original line 212) missing; presumably
 // 'struct CommonPageBlockTableKargs' carrying block_table_ptr,
 // batch_stride_block_table and page_block_size (assigned in MakeKargs()).
213  {
217  };
218 
 // NOTE(review): header (original line 219) missing; presumably
 // 'struct GroupModePageBlockTableKargs' (adds is_gappy to the page-table kargs).
220  {
221  bool is_gappy = false;
222  };
223 
 // NOTE(review): header (original line 224) missing; presumably
 // 'struct CacheBatchIdxKargs' carrying cache_batch_idx (assigned in MakeKargs()).
225  {
227  };
228 
 // NOTE(review): the struct header (original line 229, presumably
 // 'struct BatchModeKargs') is missing from this extraction. Feature kargs are mixed
 // in via conditional bases so disabled features cost zero bytes.
230  : CommonKargs,
231  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
232  BatchModeBiasKargs,
233  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
234  AlibiKargs,
235  EmptyKargs<0>>>,
236  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
237  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
238  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
239  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
240  {
 // NOTE(review): member lines (original 241, 243, 248-249) are missing; from the
 // batch-mode MakeKargs() initializer these are seqlen_k_ptr, batch_stride_q,
 // batch_stride_lse_acc and batch_stride_o_acc -- confirm.
242 
244  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
245  // single kcache page-block
246  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
247  // single vcache page-block
250  };
251 
 // NOTE(review): the struct header (original line 252, presumably
 // 'struct GroupModeKargs') is missing from this extraction.
253  : CommonKargs,
254  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
255  CommonBiasKargs,
256  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
257  AlibiKargs,
258  EmptyKargs<0>>>,
259  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
260  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
261  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
262  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
263  {
 // NOTE(review): member lines (original 264-266) are missing; from the group-mode
 // MakeKargs() initializer these are seqstart_q_ptr, seqstart_k_ptr and seqlen_k_ptr
 // (const int32_t*) -- confirm.
267 
268  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
269  // for single kcache page-block
270  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
271  // for single vcache page-block
272  };
273 
 // Select the kernel-argument layout matching the compile-time mode.
274  using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
275 
 // NOTE(review): the declaration this brace pair belongs to (original lines 276-280)
 // is missing from the extraction -- possibly a hot-loop/scheduler hint struct.
277  {
281  };
282 
 // Build the batch-mode kernel arguments. Callers must pass strides for every tensor;
 // feature-specific fields (bias/mask/fp8/paged-kv/soft-cap) are filled only when the
 // corresponding compile-time flag is enabled. SFINAE restricts this overload to
 // batch-mode instantiations.
283  template <bool Cond = !kIsGroupMode>
284  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
285  MakeKargs(const void* q_ptr,
286  const void* k_ptr,
287  const void* v_ptr,
288  const void* bias_ptr,
289  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
290  final lse */
291  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
292  o */
293  ck_tile::index_t batch,
294  ck_tile::index_t seqlen_q,
295  ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
296  const void* seqlen_k_ptr, // only used for (paged-) kvcache
297  ck_tile::index_t hdim_q,
298  ck_tile::index_t hdim_v,
299  ck_tile::index_t num_head_q,
300  ck_tile::index_t nhead_ratio_qk,
301  ck_tile::index_t num_splits,
302  const void* block_table_ptr,
303  ck_tile::index_t batch_stride_block_table,
304  ck_tile::index_t page_block_size,
305  const void* cache_batch_idx,
306  float scale_s,
307  float scale_p,
308  float logits_soft_cap,
309  ck_tile::index_t stride_q,
310  ck_tile::index_t stride_k,
311  ck_tile::index_t stride_v,
312  ck_tile::index_t stride_bias,
313  ck_tile::index_t stride_o_acc,
314  ck_tile::index_t nhead_stride_q,
315  ck_tile::index_t nhead_stride_k,
316  ck_tile::index_t nhead_stride_v,
317  ck_tile::index_t nhead_stride_bias,
318  ck_tile::index_t nhead_stride_lse_acc,
319  ck_tile::index_t nhead_stride_o_acc,
320  ck_tile::index_t batch_stride_q,
321  ck_tile::index_t batch_stride_k,
322  ck_tile::index_t batch_stride_v,
323  ck_tile::index_t batch_stride_bias,
324  ck_tile::index_t batch_stride_lse_acc,
325  ck_tile::index_t batch_stride_o_acc,
326  ck_tile::index_t split_stride_lse_acc,
327  ck_tile::index_t split_stride_o_acc,
328  ck_tile::index_t window_size_left,
329  ck_tile::index_t window_size_right,
330  ck_tile::index_t sink_size,
331  ck_tile::index_t mask_type,
332  const void* sink_ptr = nullptr)
333  {
 // aggregate-init: inner braces fill CommonKargs, the {} placeholders fill the
 // conditional feature bases, trailing fields fill BatchModeKargs itself.
334  Kargs kargs{{q_ptr,
335  k_ptr,
336  v_ptr,
337  lse_acc_ptr,
338  o_acc_ptr,
339  sink_ptr,
340  batch,
341  seqlen_q,
342  seqlen_k,
343  hdim_q,
344  hdim_v,
345  num_head_q,
346  nhead_ratio_qk,
347  num_splits,
348 #if CK_TILE_FMHA_FWD_FAST_EXP2
 // pre-fold log2(e) into the softmax scale so the pipeline can use exp2
349  static_cast<float>(scale_s * ck_tile::log2e_v<>),
350 #else
351  scale_s,
352 #endif
353  stride_q,
354  stride_k,
355  stride_v,
356  stride_o_acc,
357  nhead_stride_q,
358  nhead_stride_k,
359  nhead_stride_v,
360  nhead_stride_lse_acc,
361  nhead_stride_o_acc,
362  split_stride_lse_acc,
363  split_stride_o_acc}, // args for common karg
364  {}, // placeholder for bias
365  {}, // placeholder for mask
366  {}, // placeholder for fp8_static_quant args
367  {}, // placeholder for paged-block table or cache_batch_idx
368  {}, // placeholder for logits_soft_cap
369  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
370  batch_stride_q,
371  batch_stride_k,
372  batch_stride_v,
373  batch_stride_lse_acc,
374  batch_stride_o_acc};
375 
 // NOTE(review): original line 376 is missing from this extraction; given the
 // matching group-mode overload it is presumably
 // 'if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)'.
377  {
378  kargs.bias_ptr = bias_ptr;
379  kargs.stride_bias = stride_bias;
380  kargs.nhead_stride_bias = nhead_stride_bias;
381  kargs.batch_stride_bias = batch_stride_bias;
382  }
383  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
384  {
385  kargs.alibi_slope_ptr = bias_ptr;
386  kargs.alibi_slope_stride = stride_bias;
387  }
388  if constexpr(kHasMask)
389  {
390  kargs.window_size_left = window_size_left;
391  kargs.window_size_right = window_size_right;
392  kargs.sink_size = sink_size;
393  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
394  }
395  if constexpr(kDoFp8StaticQuant)
396  {
397  kargs.scale_p = scale_p;
398  }
399  if constexpr(kIsPagedKV)
400  {
401  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
402  kargs.batch_stride_block_table = batch_stride_block_table;
403  kargs.page_block_size = page_block_size;
404  }
405  else
406  {
407  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
408  }
409  if constexpr(kHasLogitsSoftCap)
410  {
411  kargs.init_logits_soft_cap(logits_soft_cap);
412  }
413 
414  return kargs;
415  }
416 
 // Build the group-mode (variable-length / ragged batch) kernel arguments.
 // Sequence lengths are supplied via seqstart/seqlen pointers instead of scalars;
 // seqlen_q/seqlen_k in CommonKargs are initialized to -1 and resolved per-batch
 // inside operator(). SFINAE restricts this overload to group-mode instantiations.
417  template <bool Cond = kIsGroupMode>
418  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
419  MakeKargs(const void* q_ptr,
420  const void* k_ptr,
421  const void* v_ptr,
422  const void* bias_ptr,
423  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
424  final lse */
425  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
426  o */
427  ck_tile::index_t batch,
428  const void* seqstart_q_ptr,
429  const void* seqstart_k_ptr,
430  const void* seqlen_k_ptr,
431  ck_tile::index_t hdim_q,
432  ck_tile::index_t hdim_v,
433  ck_tile::index_t num_head_q,
434  ck_tile::index_t nhead_ratio_qk,
435  ck_tile::index_t num_splits,
436  const void* block_table_ptr,
437  ck_tile::index_t batch_stride_block_table,
438  ck_tile::index_t page_block_size,
439  bool is_gappy,
440  float scale_s,
441  float scale_p,
442  float logits_soft_cap,
443  ck_tile::index_t stride_q,
444  ck_tile::index_t stride_k,
445  ck_tile::index_t stride_v,
446  ck_tile::index_t stride_bias,
447  ck_tile::index_t stride_o_acc,
448  ck_tile::index_t nhead_stride_q,
449  ck_tile::index_t nhead_stride_k,
450  ck_tile::index_t nhead_stride_v,
451  ck_tile::index_t nhead_stride_bias,
452  ck_tile::index_t nhead_stride_lse_acc,
453  ck_tile::index_t nhead_stride_o_acc,
454  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
455  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
456  ck_tile::index_t split_stride_lse_acc,
457  ck_tile::index_t split_stride_o_acc,
458  ck_tile::index_t window_size_left,
459  ck_tile::index_t window_size_right,
460  ck_tile::index_t sink_size,
461  ck_tile::index_t mask_type,
462  const void* sink_ptr = nullptr)
463  {
464  Kargs kargs{{q_ptr,
465  k_ptr,
466  v_ptr,
467  lse_acc_ptr,
468  o_acc_ptr,
469  sink_ptr,
470  batch,
471  -1, // seqlen_q will be updated by another pointer
472  -1, // seqlen_k will be updated by another pointer
473  hdim_q,
474  hdim_v,
475  num_head_q,
476  nhead_ratio_qk,
477  num_splits,
478 #if CK_TILE_FMHA_FWD_FAST_EXP2
 // pre-fold log2(e) into the softmax scale so the pipeline can use exp2
479  static_cast<float>(scale_s * ck_tile::log2e_v<>),
480 #else
481  scale_s,
482 #endif
483  stride_q,
484  stride_k,
485  stride_v,
486  stride_o_acc,
487  nhead_stride_q,
488  nhead_stride_k,
489  nhead_stride_v,
490  nhead_stride_lse_acc,
491  nhead_stride_o_acc,
492  split_stride_lse_acc,
493  split_stride_o_acc}, // args for common karg
494  {}, // placeholder for bias
495  {}, // placeholder for mask
496  {}, // placeholder for fp8_static_quant args
497  {}, // placeholder for paged-block table
498  {}, // placeholder for logits_soft_cap
499  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
500  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
501  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
502  batch_stride_k,
503  batch_stride_v};
504 
 // NOTE(review): original line 505 is missing from this extraction; presumably
 // 'if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)'.
506  {
507  kargs.bias_ptr = bias_ptr;
508  kargs.stride_bias = stride_bias;
509  kargs.nhead_stride_bias = nhead_stride_bias;
510  }
511  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
512  {
513  kargs.alibi_slope_ptr = bias_ptr;
514  kargs.alibi_slope_stride = stride_bias;
515  }
516  if constexpr(kHasMask)
517  {
518  kargs.window_size_left = window_size_left;
519  kargs.window_size_right = window_size_right;
520  kargs.sink_size = sink_size;
521  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
522  }
523  if constexpr(kDoFp8StaticQuant)
524  {
525  kargs.scale_p = scale_p;
526  }
527  if constexpr(kIsPagedKV)
528  {
529  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
530  kargs.batch_stride_block_table = batch_stride_block_table;
531  kargs.page_block_size = page_block_size;
532  kargs.is_gappy = is_gappy;
533  }
534  if constexpr(kHasLogitsSoftCap)
535  {
536  kargs.init_logits_soft_cap(logits_soft_cap);
537  }
538  return kargs;
539  }
540 
541  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
542  ck_tile::index_t nhead_q,
543  ck_tile::index_t nhead_kv,
544  ck_tile::index_t max_seqlen_q,
545  ck_tile::index_t hdim_v,
546  ck_tile::index_t num_splits)
547  {
548  ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
549  ck_tile::index_t max_seqlen_q_ =
550  max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
551 
552  // TODO: this may need tuning
553  return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
554  ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
555  nhead_,
556  batch_size);
557  }
558 
559  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
560  {
561  const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
562 
563  const auto f = [](index_t dividend, index_t divisor) {
564  index_t quotient = dividend / divisor;
565  index_t modulus = dividend - quotient * divisor;
566  return ck_tile::make_tuple(quotient, modulus);
567  };
568 
569  const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
570  const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
571  const index_t i_nhead = blockIdx.y;
572  const index_t i_batch = blockIdx.z;
573 
574  if constexpr(kHasMask)
575  {
576  // assume that num_tile_n1 is always 1
577  return ck_tile::make_tuple(
578  (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
579  }
580  else
581  {
582  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
583  }
584  }
585 
586  CK_TILE_HOST static dim3 BlockSize()
587  {
588  if(is_wave32())
589  {
590  return dim3(kBlockSize / 2);
591  }
592  else
593  {
594  return dim3(kBlockSize);
595  }
596  }
597 
 // NOTE(review): the signature line (original 598) is missing from this extraction;
 // presumably 'CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()'.
 // LDS is shared between the main pipeline and the epilogue, so take the max.
599  {
600  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
601  }
602 
 // Kernel entry point: one workgroup computes one (m-tile, n1-tile, split) of O_acc
 // (and LSE_acc) for one head of one batch, then the epilogue stores the tile.
 //
 // NOTE(review): this listing was extracted from generated documentation and many
 // original source lines are missing (e.g. 625, 643, 699, 742, 749-751, 759, 768-769,
 // 774-776, 785, 790-791, 811, 816-819, 822-824, 832, 837-838, 916-919, 923-929,
 // 931-933, 945, 982-985, 1012, 1054) -- mostly alignment/transform/window-length
 // arguments and a few 'if constexpr(BiasEnum == ...)' conditions. The code below is
 // reproduced verbatim from the extraction; do not assume it compiles as-is.
603  CK_TILE_DEVICE void operator()(Kargs kargs) const
604  {
605  // allocate LDS
606  __shared__ char smem_ptr[GetSmemSize()];
607 
608  // divide problem
609  const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
610 
 // broadcast tile origins from lane 0 so they are uniform across the wave
611  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
612  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
613 
614  long_index_t batch_offset_q = 0;
615  long_index_t batch_offset_k = 0; // unused for paged-kvcache
616  long_index_t batch_offset_v = 0; // unused for paged-kvcache
617  long_index_t batch_offset_bias = 0;
618  long_index_t batch_offset_lse_acc = 0;
619  long_index_t batch_offset_o_acc = 0;
620  index_t kv_l2p_offset =
621  0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
 // NOTE(review): original line 625 (the ':' branch of this conditional, presumably
 // ': 0.f' for the no-sink case) is missing here.
622  const float sink_value =
623  kargs.sink_ptr != nullptr
624  ? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
626 
627  if constexpr(kIsGroupMode)
628  {
629  // get starting offset for each batch
630  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
631  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
632 
633  batch_offset_q = query_start * kargs.stride_q;
634  batch_offset_k = key_start * kargs.stride_k;
635  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
636  {
637  batch_offset_v = key_start * kargs.stride_v;
638  }
639  else
640  {
641  batch_offset_v = key_start;
642  }
 // NOTE(review): the condition (original line 643, presumably
 // 'if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)') is missing.
644  {
645  batch_offset_bias = query_start * kargs.stride_bias;
646  }
647 
648  batch_offset_lse_acc = query_start;
649  batch_offset_o_acc = query_start * kargs.stride_o_acc;
650 
651  // get real # queries & # keys under group mode
652  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
653 
654  // # of required blocks is different in each groups, terminate unnecessary blocks
655  // earlier
656  if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
657  {
658  return;
659  }
660 
661  if(kargs.seqlen_k_ptr != nullptr)
662  {
663  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
664  }
665  else
666  {
667  kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
668  }
669 
670  if constexpr(kIsPagedKV)
671  {
672  if(kargs.is_gappy)
673  {
674  // seqstart_k_ptr has different meaning in this case
675  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
676  }
677  }
678  }
679  else
680  {
 // batch mode: optionally indirect the kv-cache batch index through cache_batch_idx
681  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
682  if constexpr(kIsPagedKV)
683  {
684  return i_batch_;
685  }
686  else
687  {
688  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
689  : i_batch_);
690  }
691  }();
692 
693  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
694  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
695  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
696  batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
697  batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
698 
 // NOTE(review): the condition (original line 699, presumably the elementwise-bias
 // 'if constexpr') is missing here.
700  {
701  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
702  }
703 
704  if(kargs.seqlen_k_ptr != nullptr)
705  {
706  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
707  }
708  }
709  // for simplicity, batch stride we just modify the pointer
710  const index_t i_nhead_k =
711  (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
712 
713  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
714  static_cast<long_index_t>(i_nhead) *
715  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
716  kargs.nhead_stride_q +
717  batch_offset_q;
718  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
719  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
720  batch_offset_k;
721  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
722  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
723  batch_offset_v;
724 
725  ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
726  static_cast<long_index_t>(i_nhead) *
727  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
728  kargs.nhead_stride_o_acc +
729  batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
730 
731  // Q/K/V DRAM and DRAM window
732  const auto q_dram = [&] {
733  const auto q_dram_naive = [&] {
734  if constexpr(kMergeNumHeadGroupsSeqLenQ)
735  {
736  // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
737  // hdim_q)
738  const auto view = make_naive_tensor_view<address_space_enum::global>(
739  q_ptr,
740  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
741  make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
743  number<1>{});
744 
745  return transform_tensor_view(
746  view,
747  make_tuple(
748  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
749  make_pass_through_transform(kargs.hdim_q)),
752  }
753  else
754  {
755  return make_naive_tensor_view<address_space_enum::global>(
756  q_ptr,
757  make_tuple(kargs.seqlen_q, kargs.hdim_q),
758  make_tuple(kargs.stride_q, 1),
760  number<1>{});
761  }
762  }();
763 
764  if constexpr(FmhaPipeline::kQLoadOnce)
765  {
766  return pad_tensor_view(
767  q_dram_naive,
770  }
771  else
772  {
773  return pad_tensor_view(
774  q_dram_naive,
777  }
778  }();
779 
 // factory so the paged-kv path can rebuild a view for the (shorter) last page
780  const auto make_k_dram = [&](const KDataType* data, index_t height) {
781  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
782  data, // will update this pointer if using paged-kvcache
783  make_tuple(height, kargs.hdim_q),
784  make_tuple(kargs.stride_k, 1),
786  number<1>{});
787 
788  return pad_tensor_view(
789  k_dram_naive,
792  };
793  const auto k_dram = [&]() {
794  if constexpr(kIsPagedKV)
795  {
796  return make_k_dram(nullptr, kargs.page_block_size);
797  }
798  else
799  {
800  return make_k_dram(k_ptr, kargs.seqlen_k);
801  }
802  }();
803 
804  const auto make_v_dram = [&](const VDataType* data, index_t length) {
805  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
806  {
807  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
808  data, // will update this pointer if using paged-kvcache
809  make_tuple(length, kargs.hdim_v),
810  make_tuple(kargs.stride_v, 1),
812  number<1>{});
813 
814  const auto v_dram_transposed =
815  transform_tensor_view(v_dram_naive,
820 
821  return pad_tensor_view(
822  v_dram_transposed,
825  }
826  else
827  {
828  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
829  data, // will update this pointer if using paged-kvcache
830  make_tuple(kargs.hdim_v, length),
831  make_tuple(kargs.stride_v, 1),
833  number<1>{});
834 
835  return pad_tensor_view(
836  v_dram_naive,
839  }
840  };
841  const auto v_dram = [&]() {
842  if constexpr(kIsPagedKV)
843  {
844  return make_v_dram(nullptr, kargs.page_block_size);
845  }
846  else
847  {
848  return make_v_dram(v_ptr, kargs.seqlen_k);
849  }
850  }();
851 
 // paged-kv: navigator walks the block table; last page may be partial, hence the
 // second (tail-sized) view passed at the end
852  auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
853  if constexpr(kIsPagedKV)
854  {
855  const auto* block_indices =
856  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
857  i_batch_ * kargs.batch_stride_block_table;
858  const index_t num_blocks =
859  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
860 
861  const long_index_t fixed_offset =
862  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
863 
864  return make_page_block_navigator<const KDataType, 0>(
865  kargs.k_ptr,
866  kargs.batch_stride_k, // kcache page-block stride/size
867  fixed_offset,
868  block_indices,
869  num_blocks,
870  kargs.page_block_size,
871  k_dram,
872  make_k_dram(nullptr,
873  (kv_l2p_offset + kargs.seqlen_k) -
874  (num_blocks - 1) * kargs.page_block_size));
875  }
876  else
877  {
878  return make_page_block_navigator(k_dram);
879  }
880  }();
881 
882  auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
883  if constexpr(kIsPagedKV)
884  {
885  const auto* block_indices =
886  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
887  i_batch_ * kargs.batch_stride_block_table;
888  const index_t num_blocks =
889  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
890 
891  const long_index_t fixed_offset =
892  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
893 
894  return make_page_block_navigator<const VDataType, 1>(
895  kargs.v_ptr,
896  kargs.batch_stride_v, // vcache page-block stride/size
897  fixed_offset,
898  block_indices,
899  num_blocks,
900  kargs.page_block_size,
901  v_dram,
902  make_v_dram(nullptr,
903  (kv_l2p_offset + kargs.seqlen_k) -
904  (num_blocks - 1) * kargs.page_block_size));
905  }
906  else
907  {
908  return make_page_block_navigator(v_dram);
909  }
910  }();
911 
912  auto q_dram_window = make_tile_window(
913  q_dram,
914  [&]() {
915  if constexpr(FmhaPipeline::kQLoadOnce)
918  else
920  }(),
921  {i_m0, 0});
922 
 // NOTE(review): the window-length expressions (original lines 924 and 926) are
 // missing from this extraction.
923  auto k_dram_window_lengths =
925  auto v_dram_window_lengths =
927 
930  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
931  constexpr auto bias_dram_window_lengths =
934  {
935  const BiasDataType* bias_ptr =
936  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
937  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
938  batch_offset_bias;
939 
940  const auto bias_dram = [&]() {
941  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
942  bias_ptr,
943  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
944  make_tuple(kargs.stride_bias, 1),
946  number<1>{});
947 
948  return pad_tensor_view(
949  bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
950  }();
951 
952  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
953  }
954  else
955  {
956  return make_null_tile_window(bias_dram_window_lengths);
957  }
958  }();
959 
960  // lse acc
961  auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
962  constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
963  LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
964  static_cast<long_index_t>(i_nhead_) *
965  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
966  kargs.nhead_stride_lse_acc +
967  batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
968 
969  const auto lse_acc_dram = [&] {
970  const auto lse_acc_dram_naive = [&] {
971  if constexpr(kMergeNumHeadGroupsSeqLenQ)
972  {
973  // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
974  const auto view = make_naive_tensor_view<address_space_enum::global>(
975  lse_acc_ptr,
976  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
977  make_tuple(kargs.nhead_stride_lse_acc, 1),
978  number<1>{},
979  number<1>{});
980 
981  return transform_tensor_view(view,
983  kargs.nhead_ratio_qk, kargs.seqlen_q))),
986  }
987  else
988  {
989  return make_naive_tensor_view<address_space_enum::global>(
990  lse_acc_ptr,
991  make_tuple(kargs.seqlen_q),
992  make_tuple(1),
993  number<1>{},
994  number<1>{});
995  }
996  }();
997  return pad_tensor_view(
998  lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
999  }();
1000 
1001  return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
1002  }();
1003 
1004  FmhaMask mask = [&]() {
1005  if constexpr(kHasMask)
1006  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1007  kargs.window_size_left,
1008  kargs.window_size_right,
1009  kargs.sink_size,
1010  kargs.seqlen_q,
1011  kargs.seqlen_k,
1013  else
1014  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1015  }();
1016 
1017  // WA i_batch capture structure binding before c++20
1018  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1019  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1020  {
1021  // data loading, shared by entire wg
1022  // TODO: how to use s_read?
1023  SaccDataType slope =
1024  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1025  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1026 #if CK_TILE_FMHA_FWD_FAST_EXP2
1027  slope *= ck_tile::log2e_v<>;
1028 #endif
1029  if constexpr(kHasMask)
1030  {
1031  return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
1032  kargs.window_size_left,
1033  kargs.window_size_right,
1034  kargs.seqlen_q,
1035  kargs.seqlen_k,
1036  kargs.mask_type);
1037  }
1038  else
1039  {
1040  return Alibi<SaccDataType, true, 32>{
1041  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1042  }
1043  }
1044  else
1045  {
1046  return EmptyPositionEncoding<SaccDataType>{};
1047  }
1048  }();
1049 
1050  AttentionVariant variant;
1051  const auto variant_params = [&] {
1052  if constexpr(kHasLogitsSoftCap)
1053  {
 // NOTE(review): original line 1054 (the 'return ck_tile::...Params<FmhaMask>{'
 // opener of this braced init) is missing from this extraction.
1055  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1056  }
1057  else
1058  {
1059  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1060  }
1061  }();
1062 
1063  BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
1064 
 // run the attention pipeline for this split; fp8 static-quant adds the explicit
 // element-func arguments (scale_p applied to P)
1065  auto o_acc_tile = [&, i_split_ = i_split]() {
1066  if constexpr(kDoFp8StaticQuant)
1067  {
1068  return FmhaPipeline{}(q_dram_window,
1069  identity{}, // q_element_func
1070  k_dram_window_lengths,
1071  k_page_block_navigator,
1072  identity{}, // k_element_func
1073  v_dram_window_lengths,
1074  v_page_block_navigator,
1075  identity{}, // v_element_func
1076  bias_dram_window,
1077  identity{}, // bias_element_func
1078  lse_acc_dram_window,
1079  identity{}, // lse_element_func
1080  identity{}, // s_acc_element_func
1081  scales<remove_cvref_t<decltype(kargs.scale_p)>>{
1082  kargs.scale_p}, // p_compute_element_func
1083  identity{}, // o_acc_element_func
1084  kargs.num_splits,
1085  i_split_,
1086  mask,
1087  position_encoding,
1088  kargs.scale_s,
1089  variant,
1090  variant_params,
1091  block_indices,
1092  kv_l2p_offset,
1093  smem_ptr,
1094  sink_value);
1095  }
1096  else
1097  {
1098  return FmhaPipeline{}(q_dram_window,
1099  k_dram_window_lengths,
1100  k_page_block_navigator,
1101  v_dram_window_lengths,
1102  v_page_block_navigator,
1103  bias_dram_window,
1104  lse_acc_dram_window,
1105  kargs.num_splits,
1106  i_split_,
1107  mask,
1108  position_encoding,
1109  kargs.scale_s,
1110  variant,
1111  variant_params,
1112  block_indices,
1113  kv_l2p_offset,
1114  smem_ptr,
1115  sink_value);
1116  }
1117  }();
1118 
1119  // Oacc DRAM and Oacc DRAM window
1120  auto o_acc_dram = [&] {
1121  const auto o_acc_dram_naive = [&] {
1122  if constexpr(kMergeNumHeadGroupsSeqLenQ)
1123  {
1124  // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
1125  // hdim_v)
1126  const auto view = make_naive_tensor_view<address_space_enum::global>(
1127  o_acc_ptr,
1128  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
1129  make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
1130  number<FmhaPipeline::kAlignmentOacc>{},
1131  number<1>{});
1132 
1133  return transform_tensor_view(
1134  view,
1135  make_tuple(
1136  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
1137  make_pass_through_transform(kargs.hdim_v)),
1138  make_tuple(sequence<0, 1>{}, sequence<2>{}),
1139  make_tuple(sequence<0>{}, sequence<1>{}));
1140  }
1141  else
1142  {
1143  return make_naive_tensor_view<address_space_enum::global>(
1144  o_acc_ptr,
1145  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1146  make_tuple(kargs.stride_o_acc, 1),
1147  number<FmhaPipeline::kAlignmentOacc>{},
1148  number<1>{});
1149  }
1150  }();
1151 
1152  return pad_tensor_view(
1153  o_acc_dram_naive,
1154  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1155  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1156  }();
1157 
1158  auto o_acc_dram_window =
1159  make_tile_window(o_acc_dram,
1160  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1161  {i_m0, i_n1});
1162 
 // epilogue writes the accumulated O tile for this split back to global memory
1163  EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr);
1164  }
1165 };
1166 
1167 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1691
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
scales(Scale) -> scales< Scale >
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_attention_bias_enum.hpp:19
Definition: fmha_fwd_splitkv_kernel.hpp:194
const void * alibi_slope_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:196
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_splitkv_kernel.hpp:197
Definition: fmha_fwd_splitkv_kernel.hpp:189
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:190
Definition: fmha_fwd_splitkv_kernel.hpp:240
ck_tile::index_t batch_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:248
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:246
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:243
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:244
ck_tile::index_t batch_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:249
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:241
Definition: fmha_fwd_splitkv_kernel.hpp:277
ck_tile::index_t batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:278
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:280
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:279
Definition: fmha_fwd_splitkv_kernel.hpp:225
const int32_t * cache_batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:226
Definition: fmha_fwd_splitkv_kernel.hpp:182
const void * bias_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:183
ck_tile::index_t stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:184
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:185
Definition: fmha_fwd_splitkv_kernel.hpp:121
ck_tile::index_t split_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:156
ck_tile::index_t nhead_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:153
const void * k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:123
ck_tile::index_t num_splits
Definition: fmha_fwd_splitkv_kernel.hpp:140
void * lse_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:125
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:149
void * o_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:126
ck_tile::index_t nhead_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:152
const void * sink_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:127
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:151
ck_tile::index_t hdim_q
Definition: fmha_fwd_splitkv_kernel.hpp:133
const void * v_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:124
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:150
ck_tile::index_t split_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:155
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_splitkv_kernel.hpp:139
ck_tile::index_t stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:144
ck_tile::index_t seqlen_k
Definition: fmha_fwd_splitkv_kernel.hpp:132
ck_tile::index_t batch
Definition: fmha_fwd_splitkv_kernel.hpp:129
const void * q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:122
ck_tile::index_t stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:146
ck_tile::index_t stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:145
ck_tile::index_t num_head_q
Definition: fmha_fwd_splitkv_kernel.hpp:136
ck_tile::index_t seqlen_q
Definition: fmha_fwd_splitkv_kernel.hpp:131
ck_tile::index_t stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:147
float scale_s
Definition: fmha_fwd_splitkv_kernel.hpp:142
ck_tile::index_t hdim_v
Definition: fmha_fwd_splitkv_kernel.hpp:134
Definition: fmha_fwd_splitkv_kernel.hpp:213
ck_tile::index_t page_block_size
Definition: fmha_fwd_splitkv_kernel.hpp:216
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_splitkv_kernel.hpp:215
const int32_t * block_table_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:214
Definition: fmha_fwd_splitkv_kernel.hpp:114
Definition: fmha_fwd_splitkv_kernel.hpp:208
float scale_p
Definition: fmha_fwd_splitkv_kernel.hpp:209
Definition: fmha_fwd_splitkv_kernel.hpp:263
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:266
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:268
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:264
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:265
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:270
Definition: fmha_fwd_splitkv_kernel.hpp:220
bool is_gappy
Definition: fmha_fwd_splitkv_kernel.hpp:221
Definition: fmha_fwd_splitkv_kernel.hpp:160
float logits_soft_cap_rcp
Definition: fmha_fwd_splitkv_kernel.hpp:178
float logits_soft_cap
Definition: fmha_fwd_splitkv_kernel.hpp:177
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_splitkv_kernel.hpp:163
Definition: fmha_fwd_splitkv_kernel.hpp:201
ck_tile::index_t sink_size
Definition: fmha_fwd_splitkv_kernel.hpp:203
ck_tile::index_t window_size_right
Definition: fmha_fwd_splitkv_kernel.hpp:203
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_splitkv_kernel.hpp:204
ck_tile::index_t window_size_left
Definition: fmha_fwd_splitkv_kernel.hpp:203
Definition: fmha_fwd_splitkv_kernel.hpp:66
Definition: fmha_fwd_splitkv_kernel.hpp:24
static constexpr auto BiasEnum
Definition: fmha_fwd_splitkv_kernel.hpp:50
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_splitkv_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_splitkv_kernel.hpp:36
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_splitkv_kernel.hpp:598
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_splitkv_kernel.hpp:26
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_splitkv_kernel.hpp:42
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_splitkv_kernel.hpp:47
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition: fmha_fwd_splitkv_kernel.hpp:274
static constexpr bool kHasSink
Definition: fmha_fwd_splitkv_kernel.hpp:54
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
Definition: fmha_fwd_splitkv_kernel.hpp:541
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_splitkv_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_splitkv_kernel.hpp:34
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_splitkv_kernel.hpp:40
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_splitkv_kernel.hpp:58
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_splitkv_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_splitkv_kernel.hpp:74
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:38
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_splitkv_kernel.hpp:586
static constexpr bool kHasMask
Definition: fmha_fwd_splitkv_kernel.hpp:59
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_splitkv_kernel.hpp:559
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_splitkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_splitkv_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_splitkv_kernel.hpp:31
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, const void *sink_ptr=nullptr)
Definition: fmha_fwd_splitkv_kernel.hpp:419
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_splitkv_kernel.hpp:37
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:39
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, const void *sink_ptr=nullptr)
Definition: fmha_fwd_splitkv_kernel.hpp:285
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_splitkv_kernel.hpp:57
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_splitkv_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_splitkv_kernel.hpp:33
static constexpr bool kStoreLSE
Definition: fmha_fwd_splitkv_kernel.hpp:51
static constexpr bool kIsPagedKV
Definition: fmha_fwd_splitkv_kernel.hpp:53
static constexpr bool kIsGroupMode
Definition: fmha_fwd_splitkv_kernel.hpp:44
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_splitkv_kernel.hpp:603
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_splitkv_kernel.hpp:49
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_splitkv_kernel.hpp:25
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: numeric.hpp:18
Definition: sequence.hpp:49