/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
11 
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15 #include <variant>
16 
17 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
18 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
19 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
20 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
21 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
22 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
23 
24 namespace ck_tile {
25 
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 
33  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
47 
49 
50  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
51  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
56  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
57  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
58  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
59  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
61  static constexpr bool kHasSink = FmhaPipeline::kHasSink;
62 
65  static constexpr bool kHasMask = FmhaMask::IsMasking;
66 
67  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
68 
69  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
70 #if defined(__gfx950__)
71  static constexpr bool kIsAvailable = true;
72 #else
73  static constexpr bool kIsAvailable = !kUseTrLoad;
74 #endif
75  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
76 
 77  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a template
 78  // arg
80  {
81  };
82 
 83  // kargs use aggregate initializers, so no constructor will be provided
 84  // use inheritance to minimize karg size
 85  // users need to use the MakeKargs() function to create kargs.
87  {
88  const void* q_ptr;
89  const void* k_ptr;
90  const void* v_ptr;
91  void* o_ptr;
92  const void* sink_ptr;
93 
98 
100  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
101  // if this param is larger than 1, indicate MQA/GQA case
103  float scale_s;
104 
109 
114  };
115 
117  {
119 
120  void init_logits_soft_cap(float logits_soft_cap_)
121  {
122  if(0 < logits_soft_cap_)
123  {
124  logits_soft_cap = logits_soft_cap_;
126  }
127  else
128  {
129  logits_soft_cap = 0.f;
130  logits_soft_cap_rcp = 0.f;
131  }
132  }
133 
136  };
137 
139  {
140  const void* bias_ptr = nullptr;
143  };
144 
146  {
148  };
149 
151  {
152  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
153  const void* alibi_slope_ptr;
154  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
155  };
156 
158  {
159  // ck_tile::index_t window_size_left, window_size_right;
162  };
163 
165  {
166  const void* q_descale_ptr = nullptr;
167  const void* k_descale_ptr = nullptr;
168  const void* v_descale_ptr = nullptr;
169  };
170 
172  {
176 
179  };
180 
182  {
186  };
187 
189  {
192  };
193 
195  {
196  void* lse_ptr = nullptr;
199  };
200 
202  {
203  template <typename T>
205  {
206  T val;
207  const T* ptr;
208  };
209 
213  };
214 
216  {
 // Initialize dropout state with a philox seed/offset pair supplied by the host.
 // rp_undrop caches 1/(1 - p_drop) so kept values can be rescaled by a multiply.
 217  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
 218  {
 219  float p_undrop = 1.0 - p_drop;
 // NOTE(review): the doc extraction dropped original lines 220-221 here
 // (presumably additional keep-probability setup) -- verify against the
 // repository source before relying on this block.
 222  rp_undrop = 1.0 / p_undrop;
 223 
 224  this->drop_seed.val = seed;
 225  this->drop_offset.val = offset;
 // Marks the seed/offset union as holding by-value (host) data.
 226  this->is_drop_seed_offset_from_host = true;
 227  }
228 
 // Initialize dropout state with philox seed/offset read from device memory at
 // kernel launch time (pointers, not values).
 229  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
 230  {
 231  float p_undrop = 1.0 - p_drop;
 // NOTE(review): the doc extraction dropped original lines 232-233 here
 // (mirror of the gap in the by-value overload) -- verify against the
 // repository source before relying on this block.
 234  rp_undrop = 1.0 / p_undrop;
 235 
 236  this->drop_seed.ptr = seed_ptr;
 237  this->drop_offset.ptr = offset_ptr;
 // Marks the seed/offset union as holding device pointers.
 238  this->is_drop_seed_offset_from_host = false;
 239  }
240 
241  float rp_undrop = 1;
243  bool is_store_randval = false;
244  void* rand_val_ptr = nullptr;
245 
248  };
249 
251  {
253  };
254 
256  {
258  };
259 
262  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
263  FmhaFwdBatchModeBiasKargs,
264  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
265  FmhaFwdAlibiKargs,
266  FmhaFwdEmptyKargs<0>>>,
267  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
268  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
270  QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
271  FmhaFwdCommonQScaleKargs,
272  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
273  FmhaFwdBatchBlockScaleKargs,
274  FmhaFwdEmptyKargs<3>>>,
275  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
276  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
277  {
282 
283  // Optional cumulative sequence length pointers for batch mode
284  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
285  const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
286  const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
287  };
288 
291  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
292  FmhaFwdCommonBiasKargs,
293  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
294  FmhaFwdAlibiKargs,
295  FmhaFwdEmptyKargs<0>>>,
296  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
297  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
299  QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
300  FmhaFwdCommonQScaleKargs,
301  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
302  FmhaFwdGroupBlockScaleKargs,
303  FmhaFwdEmptyKargs<3>>>,
304  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
305  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
306  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
307  {
312 
313  // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
314  const int32_t* cu_seqlen_q_ptr = nullptr;
315  const int32_t* cu_seqlen_k_ptr = nullptr;
316  };
317 
318  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
319 
321  {
325  };
326 
327  template <bool Cond = !kIsGroupMode>
328  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
329  MakeKargsImpl(const void* q_ptr,
330  const void* k_ptr,
331  const void* v_ptr,
332  const void* bias_ptr,
333  const void* q_descale_ptr,
334  const void* k_descale_ptr,
335  const void* v_descale_ptr,
336  void* rand_val_ptr,
337  void* lse_ptr,
338  void* o_ptr,
339  ck_tile::index_t seqlen_q,
340  ck_tile::index_t seqlen_k,
341  ck_tile::index_t hdim_q,
342  ck_tile::index_t hdim_v,
343  ck_tile::index_t num_head_q,
344  ck_tile::index_t nhead_ratio_qk,
345  float scale_s,
346  float logits_soft_cap,
347  ck_tile::index_t stride_q,
348  ck_tile::index_t stride_k,
349  ck_tile::index_t stride_v,
350  ck_tile::index_t stride_bias,
351  ck_tile::index_t stride_randval,
352  ck_tile::index_t stride_o,
353  ck_tile::index_t nhead_stride_q,
354  ck_tile::index_t nhead_stride_k,
355  ck_tile::index_t nhead_stride_v,
356  ck_tile::index_t nhead_stride_bias,
357  ck_tile::index_t nhead_stride_randval,
358  ck_tile::index_t nhead_stride_lse,
359  ck_tile::index_t nhead_stride_o,
360  ck_tile::index_t nhead_stride_q_descale,
361  ck_tile::index_t nhead_stride_k_descale,
362  ck_tile::index_t nhead_stride_v_descale,
363  ck_tile::index_t batch_stride_q,
364  ck_tile::index_t batch_stride_k,
365  ck_tile::index_t batch_stride_v,
366  ck_tile::index_t batch_stride_bias,
367  ck_tile::index_t batch_stride_randval,
368  ck_tile::index_t batch_stride_lse,
369  ck_tile::index_t batch_stride_o,
370  ck_tile::index_t batch_stride_q_descale,
371  ck_tile::index_t batch_stride_k_descale,
372  ck_tile::index_t batch_stride_v_descale,
373  ck_tile::index_t window_size_left,
374  ck_tile::index_t window_size_right,
375  ck_tile::index_t sink_size,
376  ck_tile::index_t mask_type,
377  float p_drop,
378  bool s_randval,
379  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
380  drop_seed_offset,
381  ck_tile::index_t block_scale_size_q,
382  ck_tile::index_t block_scale_size_kv,
383  const void* cu_seqlen_q_ptr = nullptr,
384  const void* cu_seqlen_k_ptr = nullptr,
385  const void* sink_ptr = nullptr)
386  {
387  Kargs kargs{{q_ptr,
388  k_ptr,
389  v_ptr,
390  o_ptr,
391  sink_ptr,
392  seqlen_q,
393  seqlen_k,
394  hdim_q,
395  hdim_v,
396  num_head_q,
397  nhead_ratio_qk,
398 #if CK_TILE_FMHA_FWD_FAST_EXP2
399  static_cast<float>(scale_s * ck_tile::log2e_v<>),
400 #else
401  scale_s,
402 #endif
403  stride_q,
404  stride_k,
405  stride_v,
406  stride_o,
407  nhead_stride_q,
408  nhead_stride_k,
409  nhead_stride_v,
410  nhead_stride_o}, // args for common karg
411  {}, // placeholder for bias
412  {}, // placeholder for mask
413  {}, // placeholder for lse
414  {}, // placeholder for qscale
415  {}, // placeholder for dropout
416  {}, // placeholder for logits_soft_cap
417  batch_stride_q,
418  batch_stride_k,
419  batch_stride_v,
420  batch_stride_o};
421 
423  {
424  kargs.bias_ptr = bias_ptr;
425  kargs.stride_bias = stride_bias;
426  kargs.nhead_stride_bias = nhead_stride_bias;
427  kargs.batch_stride_bias = batch_stride_bias;
428  }
429  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
430  {
431  kargs.alibi_slope_ptr = bias_ptr;
432  kargs.alibi_slope_stride = stride_bias;
433  }
434  if constexpr(kHasMask)
435  {
436  kargs.window_size_left = window_size_left;
437  kargs.window_size_right = window_size_right;
438  kargs.sink_size = sink_size;
439  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
440  }
441  if constexpr(kStoreLSE)
442  {
443  kargs.lse_ptr = lse_ptr;
444  kargs.nhead_stride_lse = nhead_stride_lse;
445  kargs.batch_stride_lse = batch_stride_lse;
446  }
448  {
449  kargs.q_descale_ptr = q_descale_ptr;
450  kargs.k_descale_ptr = k_descale_ptr;
451  kargs.v_descale_ptr = v_descale_ptr;
452  }
454  {
455  kargs.q_descale_ptr = q_descale_ptr;
456  kargs.k_descale_ptr = k_descale_ptr;
457  kargs.v_descale_ptr = v_descale_ptr;
458 
459  kargs.nhead_stride_q_descale = nhead_stride_q_descale;
460  kargs.nhead_stride_k_descale = nhead_stride_k_descale;
461  kargs.nhead_stride_v_descale = nhead_stride_v_descale;
462 
463  kargs.batch_stride_q_descale = batch_stride_q_descale;
464  kargs.batch_stride_k_descale = batch_stride_k_descale;
465  kargs.batch_stride_v_descale = batch_stride_v_descale;
466 
467  kargs.block_scale_size_q = block_scale_size_q;
468  kargs.block_scale_size_kv = block_scale_size_kv;
469  }
470  if constexpr(kHasDropout)
471  {
472  if(drop_seed_offset.index() == 0) // seed & offset come from host
473  {
474  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
475  kargs.init_dropout(p_drop, seed, offset);
476  }
477  else // seed & offset come from device
478  {
479  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
480  kargs.init_dropout(p_drop,
481  reinterpret_cast<const uint64_t*>(seed_ptr),
482  reinterpret_cast<const uint64_t*>(offset_ptr));
483  }
484 
485  kargs.rand_val_ptr = rand_val_ptr;
486  kargs.stride_randval = stride_randval;
487  kargs.nhead_stride_randval = nhead_stride_randval;
488  kargs.batch_stride_randval = batch_stride_randval;
489  kargs.is_store_randval = s_randval;
490  }
491  if constexpr(kHasLogitsSoftCap)
492  {
493  kargs.init_logits_soft_cap(logits_soft_cap);
494  }
495 
496  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
497  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
498  return kargs;
499  }
500 
501  // std::variant<> can't take in a list initializer, overload for backward compatibility
502  template <bool Cond = !kIsGroupMode>
503  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
504  MakeKargs(const void* q_ptr,
505  const void* k_ptr,
506  const void* v_ptr,
507  const void* bias_ptr,
508  const void* q_descale_ptr,
509  const void* k_descale_ptr,
510  const void* v_descale_ptr,
511  void* rand_val_ptr,
512  void* lse_ptr,
513  void* o_ptr,
514  ck_tile::index_t seqlen_q,
515  ck_tile::index_t seqlen_k,
516  ck_tile::index_t hdim_q,
517  ck_tile::index_t hdim_v,
518  ck_tile::index_t num_head_q,
519  ck_tile::index_t nhead_ratio_qk,
520  float scale_s,
521  float logits_soft_cap,
522  ck_tile::index_t stride_q,
523  ck_tile::index_t stride_k,
524  ck_tile::index_t stride_v,
525  ck_tile::index_t stride_bias,
526  ck_tile::index_t stride_randval,
527  ck_tile::index_t stride_o,
528  ck_tile::index_t nhead_stride_q,
529  ck_tile::index_t nhead_stride_k,
530  ck_tile::index_t nhead_stride_v,
531  ck_tile::index_t nhead_stride_bias,
532  ck_tile::index_t nhead_stride_randval,
533  ck_tile::index_t nhead_stride_lse,
534  ck_tile::index_t nhead_stride_o,
535  ck_tile::index_t nhead_stride_q_descale,
536  ck_tile::index_t nhead_stride_k_descale,
537  ck_tile::index_t nhead_stride_v_descale,
538  ck_tile::index_t batch_stride_q,
539  ck_tile::index_t batch_stride_k,
540  ck_tile::index_t batch_stride_v,
541  ck_tile::index_t batch_stride_bias,
542  ck_tile::index_t batch_stride_randval,
543  ck_tile::index_t batch_stride_lse,
544  ck_tile::index_t batch_stride_o,
545  ck_tile::index_t batch_stride_q_descale,
546  ck_tile::index_t batch_stride_k_descale,
547  ck_tile::index_t batch_stride_v_descale,
548  ck_tile::index_t window_size_left,
549  ck_tile::index_t window_size_right,
550  ck_tile::index_t sink_size,
551  ck_tile::index_t mask_type,
552  float p_drop,
553  bool s_randval,
554  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
555  ck_tile::index_t block_scale_size_q,
556  ck_tile::index_t block_scale_size_kv,
557  const void* cu_seqlen_q_ptr = nullptr,
558  const void* cu_seqlen_k_ptr = nullptr,
559  const void* sink_ptr = nullptr)
560  {
561  return MakeKargsImpl(
562  q_ptr,
563  k_ptr,
564  v_ptr,
565  bias_ptr,
566  q_descale_ptr,
567  k_descale_ptr,
568  v_descale_ptr,
569  rand_val_ptr,
570  lse_ptr,
571  o_ptr,
572  seqlen_q,
573  seqlen_k,
574  hdim_q,
575  hdim_v,
576  num_head_q,
577  nhead_ratio_qk,
578  scale_s,
579  logits_soft_cap,
580  stride_q,
581  stride_k,
582  stride_v,
583  stride_bias,
584  stride_randval,
585  stride_o,
586  nhead_stride_q,
587  nhead_stride_k,
588  nhead_stride_v,
589  nhead_stride_bias,
590  nhead_stride_randval,
591  nhead_stride_lse,
592  nhead_stride_o,
593  nhead_stride_q_descale,
594  nhead_stride_k_descale,
595  nhead_stride_v_descale,
596  batch_stride_q,
597  batch_stride_k,
598  batch_stride_v,
599  batch_stride_bias,
600  batch_stride_randval,
601  batch_stride_lse,
602  batch_stride_o,
603  batch_stride_q_descale,
604  batch_stride_k_descale,
605  batch_stride_v_descale,
606  window_size_left,
607  window_size_right,
608  sink_size,
609  mask_type,
610  p_drop,
611  s_randval,
612  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
613  block_scale_size_q,
614  block_scale_size_kv,
615  cu_seqlen_q_ptr,
616  cu_seqlen_k_ptr,
617  sink_ptr);
618  }
619 
620  // std::variant<> can't take in a list initializer, overload for backward compatibility
621  template <bool Cond = !kIsGroupMode>
622  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
623  MakeKargs(const void* q_ptr,
624  const void* k_ptr,
625  const void* v_ptr,
626  const void* bias_ptr,
627  const void* q_descale_ptr,
628  const void* k_descale_ptr,
629  const void* v_descale_ptr,
630  void* rand_val_ptr,
631  void* lse_ptr,
632  void* o_ptr,
633  ck_tile::index_t seqlen_q,
634  ck_tile::index_t seqlen_k,
635  ck_tile::index_t hdim_q,
636  ck_tile::index_t hdim_v,
637  ck_tile::index_t num_head_q,
638  ck_tile::index_t nhead_ratio_qk,
639  float scale_s,
640  float logits_soft_cap,
641  ck_tile::index_t stride_q,
642  ck_tile::index_t stride_k,
643  ck_tile::index_t stride_v,
644  ck_tile::index_t stride_bias,
645  ck_tile::index_t stride_randval,
646  ck_tile::index_t stride_o,
647  ck_tile::index_t nhead_stride_q,
648  ck_tile::index_t nhead_stride_k,
649  ck_tile::index_t nhead_stride_v,
650  ck_tile::index_t nhead_stride_bias,
651  ck_tile::index_t nhead_stride_randval,
652  ck_tile::index_t nhead_stride_lse,
653  ck_tile::index_t nhead_stride_o,
654  ck_tile::index_t nhead_stride_q_descale,
655  ck_tile::index_t nhead_stride_k_descale,
656  ck_tile::index_t nhead_stride_v_descale,
657  ck_tile::index_t batch_stride_q,
658  ck_tile::index_t batch_stride_k,
659  ck_tile::index_t batch_stride_v,
660  ck_tile::index_t batch_stride_bias,
661  ck_tile::index_t batch_stride_randval,
662  ck_tile::index_t batch_stride_lse,
663  ck_tile::index_t batch_stride_o,
664  ck_tile::index_t batch_stride_q_descale,
665  ck_tile::index_t batch_stride_k_descale,
666  ck_tile::index_t batch_stride_v_descale,
667  ck_tile::index_t window_size_left,
668  ck_tile::index_t window_size_right,
669  ck_tile::index_t sink_size,
670  ck_tile::index_t mask_type,
671  float p_drop,
672  bool s_randval,
673  const std::tuple<const void*, const void*>& drop_seed_offset,
674  ck_tile::index_t block_scale_size_q,
675  ck_tile::index_t block_scale_size_kv,
676  const void* cu_seqlen_q_ptr = nullptr,
677  const void* cu_seqlen_k_ptr = nullptr,
678  const void* sink_ptr = nullptr)
679  {
680  return MakeKargsImpl(
681  q_ptr,
682  k_ptr,
683  v_ptr,
684  bias_ptr,
685  q_descale_ptr,
686  k_descale_ptr,
687  v_descale_ptr,
688  rand_val_ptr,
689  lse_ptr,
690  o_ptr,
691  seqlen_q,
692  seqlen_k,
693  hdim_q,
694  hdim_v,
695  num_head_q,
696  nhead_ratio_qk,
697  scale_s,
698  logits_soft_cap,
699  stride_q,
700  stride_k,
701  stride_v,
702  stride_bias,
703  stride_randval,
704  stride_o,
705  nhead_stride_q,
706  nhead_stride_k,
707  nhead_stride_v,
708  nhead_stride_bias,
709  nhead_stride_randval,
710  nhead_stride_lse,
711  nhead_stride_o,
712  nhead_stride_q_descale,
713  nhead_stride_k_descale,
714  nhead_stride_v_descale,
715  batch_stride_q,
716  batch_stride_k,
717  batch_stride_v,
718  batch_stride_bias,
719  batch_stride_randval,
720  batch_stride_lse,
721  batch_stride_o,
722  batch_stride_q_descale,
723  batch_stride_k_descale,
724  batch_stride_v_descale,
725  window_size_left,
726  window_size_right,
727  sink_size,
728  mask_type,
729  p_drop,
730  s_randval,
731  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
732  block_scale_size_q,
733  block_scale_size_kv,
734  cu_seqlen_q_ptr,
735  cu_seqlen_k_ptr,
736  sink_ptr);
737  }
738 
739  template <bool Cond = kIsGroupMode>
740  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
741  MakeKargsImpl(const void* q_ptr,
742  const void* k_ptr,
743  const void* v_ptr,
744  const void* bias_ptr,
745  const void* q_descale_ptr,
746  const void* k_descale_ptr,
747  const void* v_descale_ptr,
748  void* rand_val_ptr,
749  void* lse_ptr,
750  void* o_ptr,
751  const void* seqstart_q_ptr,
752  const void* seqstart_k_ptr,
753  const void* seqlen_q_ptr,
754  const void* seqlen_k_ptr,
755  const void* block_scale_seqstart_q_ptr,
756  const void* block_scale_seqstart_k_ptr,
757  ck_tile::index_t hdim_q,
758  ck_tile::index_t hdim_v,
759  ck_tile::index_t num_head_q,
760  ck_tile::index_t nhead_ratio_qk,
761  float scale_s,
762  float logits_soft_cap,
763  ck_tile::index_t stride_q,
764  ck_tile::index_t stride_k,
765  ck_tile::index_t stride_v,
766  ck_tile::index_t stride_bias,
767  ck_tile::index_t stride_randval,
768  ck_tile::index_t stride_o,
769  ck_tile::index_t nhead_stride_q,
770  ck_tile::index_t nhead_stride_k,
771  ck_tile::index_t nhead_stride_v,
772  ck_tile::index_t nhead_stride_bias,
773  ck_tile::index_t nhead_stride_randval,
774  ck_tile::index_t nhead_stride_lse,
775  ck_tile::index_t nhead_stride_o,
776  ck_tile::index_t nhead_stride_q_descale,
777  ck_tile::index_t nhead_stride_k_descale,
778  ck_tile::index_t nhead_stride_v_descale,
779  ck_tile::index_t window_size_left,
780  ck_tile::index_t window_size_right,
781  ck_tile::index_t sink_size,
782  ck_tile::index_t mask_type,
783  ck_tile::index_t min_seqlen_q,
784  float p_drop,
785  bool s_randval,
786  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
787  drop_seed_offset,
788  ck_tile::index_t block_scale_size_q,
789  ck_tile::index_t block_scale_size_kv,
790  const void* cu_seqlen_q_ptr = nullptr,
791  const void* cu_seqlen_k_ptr = nullptr,
792  const void* sink_ptr = nullptr)
793  {
794  Kargs kargs{{q_ptr,
795  k_ptr,
796  v_ptr,
797  o_ptr,
798  sink_ptr,
799  -1, // seqlen will be updated by another pointer
800  -1, //
801  hdim_q,
802  hdim_v,
803  num_head_q,
804  nhead_ratio_qk,
805 #if CK_TILE_FMHA_FWD_FAST_EXP2
806  static_cast<float>(scale_s * ck_tile::log2e_v<>),
807 #else
808  scale_s,
809 #endif
810  stride_q,
811  stride_k,
812  stride_v,
813  stride_o,
814  nhead_stride_q,
815  nhead_stride_k,
816  nhead_stride_v,
817  nhead_stride_o}, // args for common karg
818  {}, // placeholder for bias
819  {}, // placeholder for mask
820  {}, // placeholder for lse
821  {}, // placeholder for qscale
822  {}, // placeholder for dropout
823  {}, // placeholder for logits_soft_cap
824  {}, // placeholder for min_seqlen_q
825  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
826  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
827  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
828  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
829 
831  {
832  kargs.bias_ptr = bias_ptr;
833  kargs.stride_bias = stride_bias;
834  kargs.nhead_stride_bias = nhead_stride_bias;
835  }
836  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
837  {
838  kargs.alibi_slope_ptr = bias_ptr;
839  kargs.alibi_slope_stride = stride_bias;
840  }
841  if constexpr(kHasMask)
842  {
843  kargs.window_size_left = window_size_left;
844  kargs.window_size_right = window_size_right;
845  kargs.sink_size = sink_size;
846  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
847  }
848  if constexpr(kStoreLSE)
849  {
850  kargs.lse_ptr = lse_ptr;
851  kargs.nhead_stride_lse = nhead_stride_lse;
852  }
854  {
855  kargs.q_descale_ptr = q_descale_ptr;
856  kargs.k_descale_ptr = k_descale_ptr;
857  kargs.v_descale_ptr = v_descale_ptr;
858  }
860  {
861  kargs.q_descale_ptr = q_descale_ptr;
862  kargs.k_descale_ptr = k_descale_ptr;
863  kargs.v_descale_ptr = v_descale_ptr;
864 
865  kargs.nhead_stride_q_descale = nhead_stride_q_descale;
866  kargs.nhead_stride_k_descale = nhead_stride_k_descale;
867  kargs.nhead_stride_v_descale = nhead_stride_v_descale;
868 
869  kargs.block_scale_size_q = block_scale_size_q;
870  kargs.block_scale_size_kv = block_scale_size_kv;
871 
872  kargs.block_scale_seqstart_q_ptr =
873  reinterpret_cast<const int32_t*>(block_scale_seqstart_q_ptr);
874  kargs.block_scale_seqstart_k_ptr =
875  reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
876  }
877  if constexpr(kHasDropout)
878  {
879  if(drop_seed_offset.index() == 0) // seed & offset come from host
880  {
881  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
882  kargs.init_dropout(p_drop, seed, offset);
883  }
884  else // seed & offset come from device
885  {
886  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
887  kargs.init_dropout(p_drop,
888  reinterpret_cast<const uint64_t*>(seed_ptr),
889  reinterpret_cast<const uint64_t*>(offset_ptr));
890  }
891 
892  kargs.rand_val_ptr = rand_val_ptr;
893  kargs.stride_randval = stride_randval;
894  kargs.nhead_stride_randval = nhead_stride_randval;
895  kargs.is_store_randval = s_randval;
896  }
897  if constexpr(kHasLogitsSoftCap)
898  {
899  kargs.init_logits_soft_cap(logits_soft_cap);
900  }
901  if constexpr(kSkipMinSeqlenQ)
902  {
903  kargs.min_seqlen_q = min_seqlen_q;
904  }
905 
906  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
907  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
908  return kargs;
909  }
910 
911  // std::variant<> can't take in a list initializer, overload for backward compatibility
912  template <bool Cond = kIsGroupMode>
913  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
914  MakeKargs(const void* q_ptr,
915  const void* k_ptr,
916  const void* v_ptr,
917  const void* bias_ptr,
918  const void* q_descale_ptr,
919  const void* k_descale_ptr,
920  const void* v_descale_ptr,
921  void* rand_val_ptr,
922  void* lse_ptr,
923  void* o_ptr,
924  const void* seqstart_q_ptr,
925  const void* seqstart_k_ptr,
926  const void* seqlen_q_ptr,
927  const void* seqlen_k_ptr,
928  const void* block_scale_seqstart_q_ptr,
929  const void* block_scale_seqstart_k_ptr,
930  ck_tile::index_t hdim_q,
931  ck_tile::index_t hdim_v,
932  ck_tile::index_t num_head_q,
933  ck_tile::index_t nhead_ratio_qk,
934  float scale_s,
935  float logits_soft_cap,
936  ck_tile::index_t stride_q,
937  ck_tile::index_t stride_k,
938  ck_tile::index_t stride_v,
939  ck_tile::index_t stride_bias,
940  ck_tile::index_t stride_randval,
941  ck_tile::index_t stride_o,
942  ck_tile::index_t nhead_stride_q,
943  ck_tile::index_t nhead_stride_k,
944  ck_tile::index_t nhead_stride_v,
945  ck_tile::index_t nhead_stride_bias,
946  ck_tile::index_t nhead_stride_randval,
947  ck_tile::index_t nhead_stride_lse,
948  ck_tile::index_t nhead_stride_o,
949  ck_tile::index_t nhead_stride_q_descale,
950  ck_tile::index_t nhead_stride_k_descale,
951  ck_tile::index_t nhead_stride_v_descale,
952  ck_tile::index_t window_size_left,
953  ck_tile::index_t window_size_right,
954  ck_tile::index_t sink_size,
955  ck_tile::index_t mask_type,
956  ck_tile::index_t min_seqlen_q,
957  float p_drop,
958  bool s_randval,
959  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
960  ck_tile::index_t block_scale_size_q,
961  ck_tile::index_t block_scale_size_kv,
962  const void* cu_seqlen_q_ptr = nullptr,
963  const void* cu_seqlen_k_ptr = nullptr,
964  const void* sink_ptr = nullptr)
965  {
966  return MakeKargsImpl(
967  q_ptr,
968  k_ptr,
969  v_ptr,
970  bias_ptr,
971  q_descale_ptr,
972  k_descale_ptr,
973  v_descale_ptr,
974  rand_val_ptr,
975  lse_ptr,
976  o_ptr,
977  seqstart_q_ptr,
978  seqstart_k_ptr,
979  seqlen_q_ptr,
980  seqlen_k_ptr,
981  block_scale_seqstart_q_ptr,
982  block_scale_seqstart_k_ptr,
983  hdim_q,
984  hdim_v,
985  num_head_q,
986  nhead_ratio_qk,
987  scale_s,
988  logits_soft_cap,
989  stride_q,
990  stride_k,
991  stride_v,
992  stride_bias,
993  stride_randval,
994  stride_o,
995  nhead_stride_q,
996  nhead_stride_k,
997  nhead_stride_v,
998  nhead_stride_bias,
999  nhead_stride_randval,
1000  nhead_stride_lse,
1001  nhead_stride_o,
1002  nhead_stride_q_descale,
1003  nhead_stride_k_descale,
1004  nhead_stride_v_descale,
1005  window_size_left,
1006  window_size_right,
1007  sink_size,
1008  mask_type,
1009  min_seqlen_q,
1010  p_drop,
1011  s_randval,
1012  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
1013  block_scale_size_q,
1014  block_scale_size_kv,
1015  cu_seqlen_q_ptr,
1016  cu_seqlen_k_ptr,
1017  sink_ptr);
1018  }
1019 
1020  // std::variant<> can't take in a list initializer, overload for backward compatibility
1021  template <bool Cond = kIsGroupMode>
1022  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1023  MakeKargs(const void* q_ptr,
1024  const void* k_ptr,
1025  const void* v_ptr,
1026  const void* bias_ptr,
1027  const void* q_descale_ptr,
1028  const void* k_descale_ptr,
1029  const void* v_descale_ptr,
1030  void* rand_val_ptr,
1031  void* lse_ptr,
1032  void* o_ptr,
1033  const void* seqstart_q_ptr,
1034  const void* seqstart_k_ptr,
1035  const void* seqlen_q_ptr,
1036  const void* seqlen_k_ptr,
1037  const void* block_scale_seqstart_q_ptr,
1038  const void* block_scale_seqstart_k_ptr,
1039  ck_tile::index_t hdim_q,
1040  ck_tile::index_t hdim_v,
1041  ck_tile::index_t num_head_q,
1042  ck_tile::index_t nhead_ratio_qk,
1043  float scale_s,
1044  float logits_soft_cap,
1045  ck_tile::index_t stride_q,
1046  ck_tile::index_t stride_k,
1047  ck_tile::index_t stride_v,
1048  ck_tile::index_t stride_bias,
1049  ck_tile::index_t stride_randval,
1050  ck_tile::index_t stride_o,
1051  ck_tile::index_t nhead_stride_q,
1052  ck_tile::index_t nhead_stride_k,
1053  ck_tile::index_t nhead_stride_v,
1054  ck_tile::index_t nhead_stride_bias,
1055  ck_tile::index_t nhead_stride_randval,
1056  ck_tile::index_t nhead_stride_lse,
1057  ck_tile::index_t nhead_stride_o,
1058  ck_tile::index_t nhead_stride_q_descale,
1059  ck_tile::index_t nhead_stride_k_descale,
1060  ck_tile::index_t nhead_stride_v_descale,
1061  ck_tile::index_t window_size_left,
1062  ck_tile::index_t window_size_right,
1063  ck_tile::index_t sink_size,
1064  ck_tile::index_t mask_type,
1065  ck_tile::index_t min_seqlen_q,
1066  float p_drop,
1067  bool s_randval,
1068  const std::tuple<const void*, const void*>& drop_seed_offset,
1069  ck_tile::index_t block_scale_size_q,
1070  ck_tile::index_t block_scale_size_kv,
1071  const void* cu_seqlen_q_ptr = nullptr,
1072  const void* cu_seqlen_k_ptr = nullptr,
1073  const void* sink_ptr = nullptr)
1074  {
1075  return MakeKargsImpl(
1076  q_ptr,
1077  k_ptr,
1078  v_ptr,
1079  bias_ptr,
1080  q_descale_ptr,
1081  k_descale_ptr,
1082  v_descale_ptr,
1083  rand_val_ptr,
1084  lse_ptr,
1085  o_ptr,
1086  seqstart_q_ptr,
1087  seqstart_k_ptr,
1088  seqlen_q_ptr,
1089  seqlen_k_ptr,
1090  block_scale_seqstart_q_ptr,
1091  block_scale_seqstart_k_ptr,
1092  hdim_q,
1093  hdim_v,
1094  num_head_q,
1095  nhead_ratio_qk,
1096  scale_s,
1097  logits_soft_cap,
1098  stride_q,
1099  stride_k,
1100  stride_v,
1101  stride_bias,
1102  stride_randval,
1103  stride_o,
1104  nhead_stride_q,
1105  nhead_stride_k,
1106  nhead_stride_v,
1107  nhead_stride_bias,
1108  nhead_stride_randval,
1109  nhead_stride_lse,
1110  nhead_stride_o,
1111  nhead_stride_q_descale,
1112  nhead_stride_k_descale,
1113  nhead_stride_v_descale,
1114  window_size_left,
1115  window_size_right,
1116  sink_size,
1117  mask_type,
1118  min_seqlen_q,
1119  p_drop,
1120  s_randval,
1121  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
1122  block_scale_size_q,
1123  block_scale_size_kv,
1124  cu_seqlen_q_ptr,
1125  cu_seqlen_k_ptr,
1126  sink_ptr);
1127  }
1128 
1129  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
1130  ck_tile::index_t nhead_,
1131  ck_tile::index_t seqlen_q_,
1132  ck_tile::index_t hdim_v_,
1133  bool has_padded_seqlen_k = false)
1134  {
1135  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
1136  if(has_padded_seqlen_k)
1137  {
1138  // TODO: this may need tuning
1139  return dim3(nhead_,
1140  batch_size_,
1141  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1142  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
1143  }
1144  else
1145  {
1146  // TODO: this may need tuning
1147  return dim3(nhead_,
1148  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1149  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
1150  batch_size_);
1151  }
1152  }
1153 
1154  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
1155  {
1156  bool has_padded_seqlen_k = false;
1157 
1158  if constexpr(kIsGroupMode)
1159  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
1160 
1161  if(has_padded_seqlen_k)
1162  {
1163  // const index_t num_tile_m0 = seqlen_q / kM0;
1164  const index_t num_tile_n1 =
1165  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1166 
1167  const index_t i_block = blockIdx.z;
1168  const index_t i_nhead = blockIdx.x;
1169  const index_t i_batch = blockIdx.y;
1170 
1171  const auto f = [](index_t dividend, index_t divisor) {
1172  index_t quotient = dividend / divisor;
1173  index_t modulus = dividend - quotient * divisor;
1174  return ck_tile::make_tuple(quotient, modulus);
1175  };
1176 
1177  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1178 
1179  if constexpr(kHasMask)
1180  {
1181  // assume that num_tile_n1 is always 1
1182  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1183  }
1184  else
1185  {
1186  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1187  }
1188  }
1189  else
1190  {
1191  // const index_t num_tile_m0 = seqlen_q / kM0;
1192  const index_t num_tile_n1 =
1193  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1194 
1195  const index_t i_block = blockIdx.y; // blockIdx.x
1196  const index_t i_nhead = blockIdx.x; // blockIdx.y
1197  const index_t i_batch = blockIdx.z;
1198 
1199  const auto f = [](index_t dividend, index_t divisor) {
1200  index_t quotient = dividend / divisor;
1201  index_t modulus = dividend - quotient * divisor;
1202  return ck_tile::make_tuple(quotient, modulus);
1203  };
1204 
1205  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1206 
1207  if constexpr(kHasMask)
1208  {
1209  // assume that num_tile_n1 is always 1
1210  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1211  }
1212  else
1213  {
1214  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1215  }
1216  }
1217  }
1218 
1219  CK_TILE_HOST static dim3 BlockSize()
1220  {
1221  if(is_wave32())
1222  {
1223  return dim3(kBlockSize / 2);
1224  }
1225  else
1226  {
1227  return dim3(kBlockSize);
1228  }
1229  }
1230 
    {
        // LDS is time-shared between the attention pipeline and the epilogue,
        // so reserve the larger of the two footprints.
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
1235 
1236  CK_TILE_DEVICE void operator()(Kargs kargs) const
1237  {
1238  if constexpr(kIsAvailable)
1239  run_(std::move(kargs));
1240  }
1241 
1242  CK_TILE_DEVICE void run_(Kargs kargs) const
1243  {
1244  if constexpr(kPipelineName != "qr_async_trload")
1245  {
1246  // allocate LDS
1247  __shared__ char smem_ptr[GetSmemSize()];
1248  // divide problem
1249  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1250  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1251  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1252 
1253  long_index_t batch_offset_q = 0;
1254  long_index_t batch_offset_k = 0;
1255  long_index_t batch_offset_v = 0;
1256  long_index_t batch_offset_bias = 0;
1257  long_index_t batch_offset_randval = 0;
1258  long_index_t batch_offset_lse = 0;
1259  long_index_t batch_offset_o = 0;
1260  long_index_t batch_offset_q_descale = 0;
1261  long_index_t batch_offset_k_descale = 0;
1262  long_index_t batch_offset_v_descale = 0;
1263  const float sink_value =
1264  kargs.sink_ptr != nullptr
1265  ? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
1267 
1268  if constexpr(kIsGroupMode)
1269  {
1270  // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
1271  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1272  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1273 
1274  // DRAM base offsets use physical starts
1275  batch_offset_q = query_start * kargs.stride_q;
1276  batch_offset_k = key_start * kargs.stride_k;
1277  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1278  {
1279  batch_offset_v = key_start * kargs.stride_v;
1280  }
1281  else
1282  {
1283  batch_offset_v = key_start;
1284  }
1286  {
1287  batch_offset_bias = query_start * kargs.stride_bias;
1288  }
1289  if constexpr(kStoreLSE)
1290  {
1291  // LSE follows the physical layout to stay consistent with other tensors
1292  batch_offset_lse = query_start;
1293  }
1294  if constexpr(kHasDropout)
1295  {
1296  batch_offset_randval = query_start * kargs.stride_randval;
1297  }
1299  {
1300  const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch];
1301  const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch];
1302  batch_offset_q_descale = bquery_start;
1303  batch_offset_k_descale = bkey_start;
1304  batch_offset_v_descale = bkey_start;
1305  }
1306  batch_offset_o = query_start * kargs.stride_o;
1307 
1308  // real logical lengths (exclude PAD)
1309  // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
1310  if(kargs.seqlen_q_ptr != nullptr)
1311  {
1312  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1313  }
1314  else if(kargs.cu_seqlen_q_ptr != nullptr)
1315  {
1316  kargs.seqlen_q =
1317  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1318  }
1319  else
1320  {
1321  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1322  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1323  }
1324 
1325  if constexpr(kSkipMinSeqlenQ)
1326  {
1327  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1328  {
1329  return;
1330  }
1331  }
1332 
1333  // terminate unnecessary blocks earlier
1334  if(kargs.seqlen_q <= i_m0)
1335  {
1336  return;
1337  }
1338 
1339  if(kargs.seqlen_k_ptr != nullptr)
1340  {
1341  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1342  }
1343  else if(kargs.cu_seqlen_k_ptr != nullptr)
1344  {
1345  kargs.seqlen_k =
1346  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1347  }
1348  else
1349  {
1350  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1351  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1352  }
1353  }
1354  else
1355  {
1356  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1357  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1358  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1360  {
1361  batch_offset_bias =
1362  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1363  }
1364  if constexpr(kStoreLSE)
1365  {
1366  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1367  }
1368  if constexpr(kHasDropout)
1369  {
1370  batch_offset_randval =
1371  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1372  }
1374  {
1375  batch_offset_q_descale =
1376  static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
1377  batch_offset_k_descale =
1378  static_cast<long_index_t>(i_batch) * kargs.batch_stride_k_descale;
1379  batch_offset_v_descale =
1380  static_cast<long_index_t>(i_batch) * kargs.batch_stride_v_descale;
1381  }
1382  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1383 
1384  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1385  if(kargs.cu_seqlen_q_ptr != nullptr)
1386  {
1387  kargs.seqlen_q =
1388  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1389  }
1390  if(kargs.cu_seqlen_k_ptr != nullptr)
1391  {
1392  kargs.seqlen_k =
1393  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1394  }
1395  }
1396 
1397  // for simplicity, batch stride we just modify the pointer
1398  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1399  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1400  batch_offset_q;
1401  const KDataType* k_ptr =
1402  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1403  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1404  batch_offset_k;
1405  const VDataType* v_ptr =
1406  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1407  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1408  batch_offset_v;
1409  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1410  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1411  batch_offset_o;
1412 
1413  // Q/K/V DRAM and DRAM window
1414  const auto q_dram = [&]() {
1415  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1416  q_ptr,
1417  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1418  make_tuple(kargs.stride_q, 1),
1420  number<1>{});
1421  if constexpr(FmhaPipeline::kQLoadOnce)
1422  {
1423  return pad_tensor_view(q_dram_naive,
1427  }
1428  else
1429  {
1430  return pad_tensor_view(
1431  q_dram_naive,
1434  }
1435  }();
1436  const auto k_dram = [&]() {
1437  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1438  k_ptr,
1439  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1440  make_tuple(kargs.stride_k, 1),
1442  number<1>{});
1443 
1444  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1445  return pad_tensor_view(
1446  k_dram_naive,
1449  }();
1450  const auto v_dram = [&]() {
1451  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1452  {
1453  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1454  v_ptr,
1455  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1456  make_tuple(kargs.stride_v, 1),
1458  number<1>{});
1459 
1460  const auto v_dram_transposed = transform_tensor_view(
1461  v_dram_naive,
1463  make_pass_through_transform(kargs.seqlen_k)),
1466 
1467  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1468  return pad_tensor_view(
1469  v_dram_transposed,
1472  }
1473  else
1474  {
1475  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1476  v_ptr,
1477  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1478  make_tuple(kargs.stride_v, 1),
1480  number<1>{});
1481 
1482  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1483  return pad_tensor_view(
1484  v_dram_naive,
1487  }
1488  }();
1489 
1490  auto q_dram_window = make_tile_window(
1491  q_dram,
1492  [&]() {
1493  if constexpr(FmhaPipeline::kQLoadOnce)
1496  else
1498  }(),
1499  {i_m0, 0});
1500 
1501  auto k_dram_window = make_tile_window(
1502  k_dram,
1504  {0, 0});
1505 
1506  auto v_dram_window = make_tile_window(
1507  v_dram,
1509  {i_n1, 0});
1512  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1513  constexpr auto bias_dram_window_lengths =
1516  {
1517  const BiasDataType* bias_ptr =
1518  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1519  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1520  batch_offset_bias;
1521 
1522  const auto bias_dram = [&]() {
1523  const auto bias_dram_naive =
1524  make_naive_tensor_view<address_space_enum::global>(
1525  bias_ptr,
1526  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1527  make_tuple(kargs.stride_bias, 1),
1529  number<1>{});
1530 
1531  return pad_tensor_view(bias_dram_naive,
1532  bias_dram_window_lengths,
1534  }();
1535 
1536  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1537  }
1538  else
1539  {
1540  return make_null_tile_window(bias_dram_window_lengths);
1541  }
1542  }();
1543 
1544  // lse
1545  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1546  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1547  if constexpr(kStoreLSE)
1548  {
1549  LSEDataType* lse_ptr =
1550  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1551  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1552  batch_offset_lse;
1553 
1554  const auto lse_dram = [&]() {
1555  const auto lse_dram_naive =
1556  make_naive_tensor_view<address_space_enum::global>(
1557  lse_ptr,
1558  make_tuple(kargs.seqlen_q),
1559  make_tuple(1),
1560  number<1>{},
1561  number<1>{});
1562 
1563  return pad_tensor_view(
1564  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1565  }();
1566 
1567  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1568  }
1569  else
1570  {
1571  return make_null_tile_window(lse_dram_window_lengths);
1572  }
1573  }();
1574 
1575  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1576  if constexpr(kHasDropout)
1577  {
1578  return BlockDropout{i_batch_,
1579  i_nhead_,
1580  kargs.num_head_q,
1581  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1582  : *kargs.drop_seed.ptr,
1583  kargs.is_drop_seed_offset_from_host
1584  ? kargs.drop_offset.val
1585  : *kargs.drop_offset.ptr,
1586  kargs.rp_undrop,
1587  kargs.p_undrop_in_uint8_t,
1588  kargs.is_store_randval};
1589  }
1590  else
1591  {
1592  return NullBlockDropout{};
1593  };
1594  }();
1595 
1596  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1597  constexpr auto randval_dram_window_lengths =
1599  if constexpr(kHasDropout)
1600  {
1601  RandValOutputDataType* rand_val_ptr =
1602  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1603  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1604  batch_offset_randval;
1605 
1606  const auto randval_dram = [&]() {
1607  const auto randval_dram_naive =
1608  make_naive_tensor_view<address_space_enum::global>(
1609  rand_val_ptr,
1610  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1611  make_tuple(kargs.stride_randval, 1),
1613  number<1>{});
1614 
1615  return pad_tensor_view(randval_dram_naive,
1616  randval_dram_window_lengths,
1618  }();
1619 
1620  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1621  }
1622  else
1623  {
1624  return make_null_tile_window(randval_dram_window_lengths);
1625  }
1626  }();
1627 
1628  FmhaMask mask = [&]() {
1629  if constexpr(kHasMask)
1630  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1631  kargs.window_size_left,
1632  kargs.window_size_right,
1633  kargs.sink_size,
1634  kargs.seqlen_q,
1635  kargs.seqlen_k,
1637  else
1638  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1639  }();
1640 
1641  // WA i_batch capture structure binding before c++20
1642  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1643  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1644  {
1645  // data loading, shared by entire wg
1646  // TODO: how to use s_read?
1647  SaccDataType slope =
1648  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1649  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1650 #if CK_TILE_FMHA_FWD_FAST_EXP2
1651  slope *= ck_tile::log2e_v<>;
1652 #endif
1653  if constexpr(kHasMask)
1654  {
1655  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1656  kargs.window_size_left,
1657  kargs.window_size_right,
1658  kargs.seqlen_q,
1659  kargs.seqlen_k,
1660  kargs.mask_type);
1661  }
1662  else
1663  {
1665  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1666  }
1667  }
1668  else
1669  {
1671  }
1672  }();
1673 
1674  AttentionVariant variant;
1675  const auto variant_params = [&] {
1676  const float scale_s = [&] {
1678  {
1679  float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
1680  float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
1681 
1682  return kargs.scale_s * q_descale * k_descale;
1683  }
1684  else
1685  {
1686  return kargs.scale_s;
1687  }
1688  }();
1689 
1690  if constexpr(kHasLogitsSoftCap)
1691  {
1693  mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1694  }
1695  else
1696  {
1697  return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
1698  }
1699  }();
1700 
1701  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1702 
1703  auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
1705  {
1706  // TODO - move global load of descale to pipeline
1707  float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
1708 
1709  float scale_p =
1710  ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
1711  float scale_o = v_descale / scale_p;
1712 
1713  auto o_acc_element_func = [&]() {
1714  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1715  return make_composes(
1717  ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
1718  else
1719  return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
1720  }();
1721  return FmhaPipeline{}(q_dram_window,
1722  identity{}, // q_element_func
1723  k_dram_window,
1724  identity{}, // k_element_func
1725  v_dram_window,
1726  identity{}, // v_element_func
1727  bias_dram_window,
1728  identity{}, // bias_element_func
1729  randval_dram_window,
1730  lse_dram_window,
1731  identity{}, // lse_element_func
1732  identity{}, // s_acc_element_func
1733  scales<remove_cvref_t<decltype(scale_p)>>{
1734  scale_p}, // p_compute_element_func
1735  o_acc_element_func, // o_acc_element_func
1736  mask,
1737  position_encoding,
1738  variant_params.sm_scale,
1739  variant,
1740  variant_params,
1741  block_indices,
1742  smem_ptr,
1743  dropout,
1744  nullptr,
1745  nullptr,
1746  1,
1747  sink_value);
1748  }
1750  {
1751  const float* q_descale_ptr =
1752  reinterpret_cast<const float*>(kargs.q_descale_ptr) +
1753  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
1754  batch_offset_q_descale;
1755  const float* k_descale_ptr =
1756  reinterpret_cast<const float*>(kargs.k_descale_ptr) +
1757  static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
1758  kargs.nhead_stride_k_descale +
1759  batch_offset_k_descale;
1760  const float* v_descale_ptr =
1761  reinterpret_cast<const float*>(kargs.v_descale_ptr) +
1762  static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
1763  kargs.nhead_stride_v_descale +
1764  batch_offset_v_descale;
1765 
1766  size_t idx = i_m0 / kargs.block_scale_size_q;
1767  float q_descale = q_descale_ptr[idx];
1768  // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8
1769  // Both P and rowsum are scaled by 2^shift, canceling in normalization
1770  // No additional scaling needed in p_compute_element_func or o_acc_element_func
1771 
1772  return FmhaPipeline{}(
1773  q_dram_window,
1774  identity{}, // q_element_func
1775  k_dram_window,
1776  identity{}, // k_element_func
1777  v_dram_window,
1778  identity{}, // v_element_func
1779  bias_dram_window,
1780  identity{}, // bias_element_func
1781  randval_dram_window,
1782  lse_dram_window,
1783  identity{}, // lse_element_func
1784  scales<float>(q_descale), // s_acc_element_func
1785  identity{}, // p_compute_element_func - No scaling (done in exp2)
1786  identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum)
1787  mask,
1788  position_encoding,
1789  kargs.scale_s,
1790  variant,
1791  variant_params,
1792  block_indices,
1793  smem_ptr,
1794  dropout,
1795  k_descale_ptr,
1796  v_descale_ptr,
1797  kargs.block_scale_size_kv,
1798  sink_value);
1799  }
1800  else
1801  {
1802  return FmhaPipeline{}(q_dram_window,
1803  k_dram_window,
1804  v_dram_window,
1805  bias_dram_window,
1806  randval_dram_window,
1807  lse_dram_window,
1808  mask,
1809  position_encoding,
1810  variant_params.sm_scale,
1811  variant,
1812  variant_params,
1813  block_indices,
1814  smem_ptr,
1815  dropout,
1816  sink_value);
1817  }
1818  }();
1819 
1820  // O DRAM and O DRAM window
1821  auto o_dram = [&]() {
1822  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1823  o_ptr,
1824  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1825  make_tuple(kargs.stride_o, 1),
1827  number<1>{});
1828 
1829  return pad_tensor_view(
1830  o_dram_naive,
1833  }();
1834 
1835  auto o_dram_window = make_tile_window(
1836  o_dram,
1838  {i_m0, i_n1});
1839 
1840  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1841  }
1842  else
1843  {
1844  // TODO: Refine the logical here.
1845  // In Decode case
1846  // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache
1847  // 2. limit the LDS usage, as we want higher occupancy
1848  // In Prefill case
1849  // 1. we expect KV data reused by different ThreadGroups, use cache
1850  // 2. use more LDS, as we want better memory latency hiding
1851  // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the
1852  // cache
1853  constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
1854  // divide problem
1855  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1856  const float sink_value =
1857  kargs.sink_ptr != nullptr
1858  ? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
1860 
1861  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1862  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1863 
1864  long_index_t batch_offset_q = 0;
1865  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1866  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1867  long_index_t batch_offset_bias = 0;
1868  long_index_t batch_offset_lse = 0;
1869  long_index_t batch_offset_o = 0;
1870  // index_t kv_l2p_offset =
1871  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1872  // paged-kvcache
1873 
1874  if constexpr(kIsGroupMode)
1875  {
1876  // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
1877  // physical starts
1878  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1879  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1880 
1881  batch_offset_q = query_start * kargs.stride_q;
1882  batch_offset_k = key_start * kargs.stride_k;
1883  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1884  {
1885  batch_offset_v = key_start * kargs.stride_v;
1886  }
1887  else
1888  {
1889  // col-major V: offset along seqlen dimension is scalar index
1890  batch_offset_v = key_start;
1891  }
1893  {
1894  batch_offset_bias = query_start * kargs.stride_bias;
1895  }
1896 
1897  // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
1898  batch_offset_lse = query_start;
1899  batch_offset_o = query_start * kargs.stride_o;
1900 
1901  // get real # queries & # keys under group mode
1902  if(kargs.seqlen_q_ptr != nullptr)
1903  {
1904  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1905  }
1906  else if(kargs.cu_seqlen_q_ptr != nullptr)
1907  {
1908  kargs.seqlen_q =
1909  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1910  }
1911  else
1912  {
1913  kargs.seqlen_q =
1914  kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1915  }
1916 
1917  // # of required blocks is different in each groups, terminate unnecessary blocks
1918  // earlier
1919  if(kargs.seqlen_q <= i_m0)
1920  {
1921  return;
1922  }
1923 
1924  if(kargs.seqlen_k_ptr != nullptr)
1925  {
1926  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1927  }
1928  else if(kargs.cu_seqlen_k_ptr != nullptr)
1929  {
1930  kargs.seqlen_k =
1931  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1932  }
1933  else
1934  {
1935  kargs.seqlen_k =
1936  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1937  }
1938  }
1939  else
1940  {
1941  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1942  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1943  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1944  if constexpr(kStoreLSE)
1945  {
1946  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1947  }
1948  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1949 
1951  {
1952  batch_offset_bias =
1953  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1954  }
1955 
1956  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1957  if(kargs.cu_seqlen_q_ptr != nullptr)
1958  {
1959  kargs.seqlen_q =
1960  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1961  }
1962  if(kargs.cu_seqlen_k_ptr != nullptr)
1963  {
1964  kargs.seqlen_k =
1965  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1966  }
1967  }
1968 
1969  // for simplicity, batch stride we just modify the pointer
1970  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1971 
1972  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1973  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1974  batch_offset_q;
1975  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1976  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1977  batch_offset_k;
1978  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1979  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1980  batch_offset_v;
1981 
1982  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1983  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1984  batch_offset_o;
1985 
1986  // Q/K/V DRAM and DRAM window
1987  const auto q_dram = [&] {
1988  const auto q_dram_naive = [&] {
1989  {
1990  return make_naive_tensor_view<address_space_enum::global,
1991  memory_operation_enum::set,
1993  q_ptr,
1994  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1995  make_tuple(kargs.stride_q, 1),
1997  number<1>{});
1998  }
1999  }();
2000 
2001  if constexpr(FmhaPipeline::kQLoadOnce)
2002  {
2003  const auto seqlen_q = kargs.seqlen_q;
2004  const auto q_dram_pad = pad_tensor_view(
2005  q_dram_naive,
2008 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2009  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
2010  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
2011 
2012  if constexpr(XorLengthFold > 1)
2013  {
2014  const auto q_dram_unmerged = transform_tensor_view(
2015  q_dram_pad,
2016  make_tuple(
2018  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
2022 
2023  const auto q_dram_merged = transform_tensor_view(
2024  q_dram_unmerged,
2025  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
2027  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2030 
2031  const auto q_dram_unmerged_xor = transform_tensor_view(
2032  q_dram_merged,
2033  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
2039 
2040  const auto q_dram_permuted = transform_tensor_view(
2041  q_dram_unmerged_xor,
2042  make_tuple(
2044  make_tuple(seqlen_q / XorLengthFold,
2049 
2050  const auto q_dram_tmp = transform_tensor_view(
2051  q_dram_permuted,
2052  make_tuple(
2053  make_pass_through_transform(seqlen_q / XorLengthFold),
2056  number<FmhaPipeline::kQKHeaddim /
2057  FmhaPipeline::kAlignmentQ>{})),
2061 
2062  return transform_tensor_view(
2063  q_dram_tmp,
2064  make_tuple(
2066  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
2072  }
2073  else
2074 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2075  {
2076  const auto q_dram_unmerged = transform_tensor_view(
2077  q_dram_pad,
2078  make_tuple(
2079  make_pass_through_transform(seqlen_q),
2085 
2086  const auto q_dram_permuted = transform_tensor_view(
2087  q_dram_unmerged,
2088  make_tuple(
2089  make_xor_transform(make_tuple(seqlen_q,
2090  number<FmhaPipeline::kQKHeaddim /
2091  FmhaPipeline::kAlignmentQ>{})),
2095 
2096  return transform_tensor_view(
2097  q_dram_permuted,
2098  make_tuple(
2099  make_pass_through_transform(seqlen_q),
2105  }
2106  }
2107  else
2108  {
2109  return pad_tensor_view(
2110  q_dram_naive,
2113  }
2114  }();
2115 
2116  const auto make_k_dram = [&](const KDataType* data, index_t height) {
2117  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2118  data, // will update this pointer if using paged-kvcache
2119  make_tuple(height, kargs.hdim_q),
2120  make_tuple(kargs.stride_k, 1),
2122  number<1>{});
2123 
2124  const auto k_dram_pad = pad_tensor_view(
2125  k_dram_naive,
2128 
2129  constexpr auto kDramTileK =
2130  FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
2131 
2132 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2133  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
2134  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
2135 
2136  if constexpr(XorLengthFold > 1)
2137  {
2138  const auto k_dram_unmerged = transform_tensor_view(
2139  k_dram_pad,
2141  make_tuple(height / XorLengthFold, XorLengthFold)),
2145 
2146  const auto k_dram_merged = transform_tensor_view(
2147  k_dram_unmerged,
2148  make_tuple(make_pass_through_transform(height / XorLengthFold),
2150  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2153 
2154  const auto k_dram_unmerged_xor = transform_tensor_view(
2155  k_dram_merged,
2156  make_tuple(make_pass_through_transform(height / XorLengthFold),
2162 
2163  const auto k_dram_permuted = transform_tensor_view(
2164  k_dram_unmerged_xor,
2165  make_tuple(
2167  make_tuple(height / XorLengthFold,
2172 
2173  const auto k_dram_tmp = transform_tensor_view(
2174  k_dram_permuted,
2175  make_tuple(
2176  make_pass_through_transform(height / XorLengthFold),
2179  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
2183 
2184  return transform_tensor_view(
2185  k_dram_tmp,
2186  make_tuple(
2188  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
2194  }
2195  else
2196 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2197  {
2198  const auto k_dram_unmerged = transform_tensor_view(
2199  k_dram_pad,
2202  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
2203  FmhaPipeline::kAlignmentK>{},
2204  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
2208 
2209  const auto k_dram_permuted = transform_tensor_view(
2210  k_dram_unmerged,
2211  make_tuple(
2215  number<FmhaPipeline::kQKHeaddim / kDramTileK /
2216  FmhaPipeline::kAlignmentK>{}),
2220 
2221  return transform_tensor_view(
2222  k_dram_permuted,
2225  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
2226  FmhaPipeline::kAlignmentK>{},
2227  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
2231  }
2232  };
2233  const auto k_dram = [&]() {
2234  {
2235  return make_k_dram(k_ptr, kargs.seqlen_k);
2236  }
2237  }();
2238 
2239  const auto make_v_dram = [&](const VDataType* data, index_t length) {
2240  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2241  data, // will update this pointer if using paged-kvcache
2242  make_tuple(length, kargs.hdim_v),
2243  make_tuple(kargs.stride_v, 1),
2245  number<1>{});
2246 
2247  // TODO: Add kVHeadDim
2248  constexpr index_t XorGroupSize =
2249  FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
2250 
2251  const auto v_dram_pad = pad_tensor_view(
2252  v_dram_naive,
2255 
2256 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2257  constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
2258  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
2259 
2260  if constexpr(XorLengthFold > 1)
2261  {
2262  const auto v_dram_unmerged = transform_tensor_view(
2263  v_dram_pad,
2265  make_tuple(length / XorLengthFold, XorLengthFold)),
2269 
2270  const auto v_dram_merged = transform_tensor_view(
2271  v_dram_unmerged,
2272  make_tuple(make_pass_through_transform(length / XorLengthFold),
2274  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2277 
2278  const auto v_dram_unmerged_xor = transform_tensor_view(
2279  v_dram_merged,
2280  make_tuple(
2281  make_pass_through_transform(length / XorLengthFold),
2283  number<XorGroupSize>{}))),
2286 
2287  const auto v_dram_permuted = transform_tensor_view(
2288  v_dram_unmerged_xor,
2289  make_tuple(
2290  make_xor_transform(make_tuple(length / XorLengthFold,
2295 
2296  const auto v_dram_tmp = transform_tensor_view(
2297  v_dram_permuted,
2298  make_tuple(make_pass_through_transform(length / XorLengthFold),
2301  number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
2305 
2306  return transform_tensor_view(
2307  v_dram_tmp,
2309  make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
2312  number<XorGroupSize>{}))),
2315  }
2316  else
2317 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2318  {
2319  const auto v_dram_unmerged = transform_tensor_view(
2320  v_dram_pad,
2324  number<XorGroupSize>{}))),
2327 
2328  const auto v_dram_permuted = transform_tensor_view(
2329  v_dram_unmerged,
2335 
2336  return transform_tensor_view(
2337  v_dram_permuted,
2341  number<XorGroupSize>{}))),
2344  }
2345  };
2346 
2347  const auto v_dram = [&]() {
2348  {
2349  return make_v_dram(v_ptr, kargs.seqlen_k);
2350  }
2351  }();
2352 
2353  auto q_dram_window = make_tile_window(
2354  q_dram,
2355  [&]() {
2356  if constexpr(FmhaPipeline::kQLoadOnce)
2359  else
2361  }(),
2362  {i_m0, 0});
2363 
2364  auto k_dram_window = make_tile_window(
2365  k_dram,
2367  {0, 0});
2368 
2369  auto v_dram_window = make_tile_window(
2370  v_dram,
2372  {0, 0});
2373 
2376  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
2377  constexpr auto bias_dram_window_lengths =
2380  {
2381  const BiasDataType* bias_ptr =
2382  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
2383  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
2384  batch_offset_bias;
2385 
2386  const auto bias_dram = [&]() {
2387  const auto bias_dram_naive =
2388  make_naive_tensor_view<address_space_enum::global>(
2389  bias_ptr,
2390  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
2391  make_tuple(kargs.stride_bias, 1),
2393  number<1>{});
2394 
2395  return pad_tensor_view(bias_dram_naive,
2396  bias_dram_window_lengths,
2398  }();
2399 
2400  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
2401  }
2402  else
2403  {
2404  return make_null_tile_window(bias_dram_window_lengths);
2405  }
2406  }();
2407 
2408  // lse acc
2409  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
2410  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
2411  if constexpr(kStoreLSE)
2412  {
2413  LSEDataType* lse_ptr =
2414  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
2415  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
2416  batch_offset_lse;
2417 
2418  const auto lse_dram = [&] {
2419  const auto lse_dram_naive = [&] {
2420  {
2421  return make_naive_tensor_view<address_space_enum::global>(
2422  lse_ptr,
2423  make_tuple(kargs.seqlen_q),
2424  make_tuple(1),
2425  number<1>{},
2426  number<1>{});
2427  }
2428  }();
2429  return pad_tensor_view(
2430  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
2431  }();
2432 
2433  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
2434  }
2435  else
2436  {
2437  return make_null_tile_window(lse_dram_window_lengths);
2438  }
2439  }();
2440 
2441  FmhaMask mask = [&]() {
2442  if constexpr(kHasMask)
2443  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
2444  kargs.window_size_left,
2445  kargs.window_size_right,
2446  kargs.sink_size,
2447  kargs.seqlen_q,
2448  kargs.seqlen_k,
2450  else
2451  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
2452  }();
2453 
2454  // WA i_batch capture structure binding before c++20
2455  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
2456  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
2457  {
2458  // data loading, shared by entire wg
2459  // TODO: how to use s_read?
2460  SaccDataType slope =
2461  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
2462  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
2463 #if CK_TILE_FMHA_FWD_FAST_EXP2
2464  slope *= ck_tile::log2e_v<>;
2465 #endif
2466  if constexpr(kHasMask)
2467  {
2468  return make_alibi_from_lr_mask<SaccDataType, true, 32>(
2469  slope,
2470  kargs.window_size_left,
2471  kargs.window_size_right,
2472  kargs.seqlen_q,
2473  kargs.seqlen_k,
2474  kargs.mask_type);
2475  }
2476  else
2477  {
2479  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
2480  }
2481  }
2482  else
2483  {
2485  }
2486  }();
2487 
2488  auto o_acc_tile = [&]() {
2489  if constexpr(PrefillCase)
2490  {
2491  // allocate double lds
2492  // add __restrict__ here to avoid aliasing
2493  __shared__ char smem_ptrk0
2494  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2495  true>()];
2496  __shared__ char smem_ptrk1
2497  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2498  true>()];
2499  __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
2500  typename FmhaPipeline::Problem>()];
2501  __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
2502  typename FmhaPipeline::Problem>()];
2503 
2504  return FmhaPipeline{}(q_dram_window,
2505  k_dram_window,
2506  v_dram_window,
2507  bias_dram_window,
2508  lse_dram_window,
2509  mask,
2510  position_encoding,
2511  kargs.scale_s,
2512  sink_value,
2513  smem_ptrk0,
2514  smem_ptrk1,
2515  smem_ptrv0,
2516  smem_ptrv1);
2517  }
2518  else
2519  {
2520  __shared__ char smem_ptr[GetSmemSize()];
2521  return FmhaPipeline{}(q_dram_window,
2522  k_dram_window,
2523  v_dram_window,
2524  bias_dram_window,
2525  lse_dram_window,
2526  mask,
2527  position_encoding,
2528  kargs.scale_s,
2529  smem_ptr,
2530  sink_value);
2531  }
2532  }();
2533 
2534  // Oacc DRAM and Oacc DRAM window
2535  auto o_dram = [&] {
2536  const auto o_dram_naive = [&] {
2537  {
2538  return make_naive_tensor_view<address_space_enum::global>(
2539  o_ptr,
2540  make_tuple(kargs.seqlen_q, kargs.hdim_v),
2541  make_tuple(kargs.stride_o, 1),
2543  number<1>{});
2544  }
2545  }();
2546 
2547  return pad_tensor_view(
2548  o_dram_naive,
2551  }();
2552 
2553  auto o_dram_window = make_tile_window(
2554  o_dram,
2556  {i_m0, i_n1});
2557 
2558  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
2559  }
2560  }
2561 };
2562 
2563 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_composes(Ts &&... ts)
Definition: unary_element_function.hpp:51
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1698
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1685
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_xor_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1738
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:486
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_dropout.hpp:53
const float rp_undrop
Definition: block_dropout.hpp:371
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_kernel.hpp:321
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_kernel.hpp:324
ck_tile::index_t batch_idx
Definition: fmha_fwd_kernel.hpp:322
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_kernel.hpp:323
Definition: fmha_fwd_kernel.hpp:151
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_kernel.hpp:154
const void * alibi_slope_ptr
Definition: fmha_fwd_kernel.hpp:153
Definition: fmha_fwd_kernel.hpp:182
ck_tile::index_t batch_stride_q_descale
Definition: fmha_fwd_kernel.hpp:183
ck_tile::index_t batch_stride_v_descale
Definition: fmha_fwd_kernel.hpp:185
ck_tile::index_t batch_stride_k_descale
Definition: fmha_fwd_kernel.hpp:184
Definition: fmha_fwd_kernel.hpp:146
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_kernel.hpp:147
Definition: fmha_fwd_kernel.hpp:251
ck_tile::index_t batch_stride_randval
Definition: fmha_fwd_kernel.hpp:252
Definition: fmha_fwd_kernel.hpp:277
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_kernel.hpp:281
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:286
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_kernel.hpp:278
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_kernel.hpp:279
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:285
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_kernel.hpp:280
Definition: fmha_fwd_kernel.hpp:139
const void * bias_ptr
Definition: fmha_fwd_kernel.hpp:140
ck_tile::index_t stride_bias
Definition: fmha_fwd_kernel.hpp:141
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_kernel.hpp:142
Definition: fmha_fwd_kernel.hpp:172
ck_tile::index_t nhead_stride_v_descale
Definition: fmha_fwd_kernel.hpp:175
ck_tile::index_t block_scale_size_q
Definition: fmha_fwd_kernel.hpp:177
ck_tile::index_t nhead_stride_k_descale
Definition: fmha_fwd_kernel.hpp:174
ck_tile::index_t block_scale_size_kv
Definition: fmha_fwd_kernel.hpp:178
ck_tile::index_t nhead_stride_q_descale
Definition: fmha_fwd_kernel.hpp:173
Definition: fmha_fwd_kernel.hpp:216
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_fwd_kernel.hpp:229
float rp_undrop
Definition: fmha_fwd_kernel.hpp:241
ck_tile::index_t stride_randval
Definition: fmha_fwd_kernel.hpp:246
ck_tile::index_t nhead_stride_randval
Definition: fmha_fwd_kernel.hpp:247
void * rand_val_ptr
Definition: fmha_fwd_kernel.hpp:244
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_fwd_kernel.hpp:217
bool is_store_randval
Definition: fmha_fwd_kernel.hpp:243
uint8_t p_undrop_in_uint8_t
Definition: fmha_fwd_kernel.hpp:242
Definition: fmha_fwd_kernel.hpp:87
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_kernel.hpp:111
float scale_s
Definition: fmha_fwd_kernel.hpp:103
ck_tile::index_t seqlen_k
Definition: fmha_fwd_kernel.hpp:95
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_kernel.hpp:113
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_kernel.hpp:102
ck_tile::index_t num_head_q
Definition: fmha_fwd_kernel.hpp:99
ck_tile::index_t hdim_q
Definition: fmha_fwd_kernel.hpp:96
const void * v_ptr
Definition: fmha_fwd_kernel.hpp:90
void * o_ptr
Definition: fmha_fwd_kernel.hpp:91
const void * k_ptr
Definition: fmha_fwd_kernel.hpp:89
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_kernel.hpp:110
ck_tile::index_t stride_k
Definition: fmha_fwd_kernel.hpp:106
ck_tile::index_t stride_o
Definition: fmha_fwd_kernel.hpp:108
const void * sink_ptr
Definition: fmha_fwd_kernel.hpp:92
ck_tile::index_t stride_v
Definition: fmha_fwd_kernel.hpp:107
ck_tile::index_t hdim_v
Definition: fmha_fwd_kernel.hpp:97
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_kernel.hpp:112
const void * q_ptr
Definition: fmha_fwd_kernel.hpp:88
ck_tile::index_t seqlen_q
Definition: fmha_fwd_kernel.hpp:94
ck_tile::index_t stride_q
Definition: fmha_fwd_kernel.hpp:105
Definition: fmha_fwd_kernel.hpp:195
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_kernel.hpp:198
void * lse_ptr
Definition: fmha_fwd_kernel.hpp:196
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_kernel.hpp:197
Definition: fmha_fwd_kernel.hpp:165
const void * v_descale_ptr
Definition: fmha_fwd_kernel.hpp:168
const void * k_descale_ptr
Definition: fmha_fwd_kernel.hpp:167
const void * q_descale_ptr
Definition: fmha_fwd_kernel.hpp:166
Definition: fmha_fwd_kernel.hpp:202
bool is_drop_seed_offset_from_host
Definition: fmha_fwd_kernel.hpp:212
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_fwd_kernel.hpp:210
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_fwd_kernel.hpp:211
Definition: fmha_fwd_kernel.hpp:80
Definition: fmha_fwd_kernel.hpp:189
const int32_t * block_scale_seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:190
const int32_t * block_scale_seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:191
Definition: fmha_fwd_kernel.hpp:307
const int32_t * seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:310
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:308
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:311
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:315
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:314
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:309
Definition: fmha_fwd_kernel.hpp:117
float logits_soft_cap
Definition: fmha_fwd_kernel.hpp:134
float logits_soft_cap_rcp
Definition: fmha_fwd_kernel.hpp:135
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_kernel.hpp:120
Definition: fmha_fwd_kernel.hpp:158
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_kernel.hpp:161
ck_tile::index_t window_size_right
Definition: fmha_fwd_kernel.hpp:160
ck_tile::index_t sink_size
Definition: fmha_fwd_kernel.hpp:160
ck_tile::index_t window_size_left
Definition: fmha_fwd_kernel.hpp:160
Definition: fmha_fwd_kernel.hpp:256
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_kernel.hpp:257
Definition: fmha_fwd_kernel.hpp:28
static constexpr bool kHasDropout
Definition: fmha_fwd_kernel.hpp:58
static constexpr bool kIsAvailable
Definition: fmha_fwd_kernel.hpp:73
static constexpr bool kStoreLSE
Definition: fmha_fwd_kernel.hpp:57
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_kernel.hpp:38
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *block_scale_seqstart_q_ptr, const void *block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:1023
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_kernel.hpp:318
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_kernel.hpp:33
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_kernel.hpp:48
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_kernel.hpp:41
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition: fmha_fwd_kernel.hpp:1129
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_kernel.hpp:39
static constexpr auto QScaleEnum
Definition: fmha_fwd_kernel.hpp:59
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_kernel.hpp:35
static constexpr auto BiasEnum
Definition: fmha_fwd_kernel.hpp:56
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *block_scale_seqstart_q_ptr, const void *block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:914
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_q_descale, ck_tile::index_t batch_stride_k_descale, ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:329
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_kernel.hpp:54
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_q_descale, ck_tile::index_t batch_stride_k_descale, ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:623
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_kernel.hpp:1154
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_kernel.hpp:60
static constexpr std::string_view kPipelineName
Definition: fmha_fwd_kernel.hpp:75
ck_tile::remove_cvref_t< typename FmhaPipeline::PDataType > PDataType
Definition: fmha_fwd_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_kernel.hpp:37
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_q_descale, ck_tile::index_t batch_stride_k_descale, ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:504
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_kernel.hpp:1219
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1242
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_kernel.hpp:63
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_kernel.hpp:1231
static constexpr bool kUseTrLoad
Definition: fmha_fwd_kernel.hpp:69
static constexpr bool kHasMask
Definition: fmha_fwd_kernel.hpp:65
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_kernel.hpp:67
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_kernel.hpp:29
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *block_scale_seqstart_q_ptr, const void *block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:741
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_kernel.hpp:46
static constexpr bool kHasSink
Definition: fmha_fwd_kernel.hpp:61
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_kernel.hpp:64
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_fwd_kernel.hpp:43
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_kernel.hpp:30
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_kernel.hpp:52
static constexpr bool kIsGroupMode
Definition: fmha_fwd_kernel.hpp:50
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1236
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1393
Definition: unary_element_function.hpp:58
Definition: math.hpp:28
Definition: sequence.hpp:49
const T * ptr
Definition: fmha_fwd_kernel.hpp:207