/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
12 
13 #include <string>
14 #include <type_traits>
15 #include <utility>
16 #include <variant>
17 
18 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
19 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
20 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
21 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
22 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
23 
24 namespace ck_tile {
25 
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 
33  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
47 
49 
50  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
51  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
56  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
57  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
58  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
59  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
61  static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
62  static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
63  static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
66  static constexpr bool kHasMask = FmhaMask::IsMasking;
67 
68  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
69  template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
70  // arg
72  {
73  };
74 
75  // kargs use aggregate initialization, so no constructor is provided
76  // use inheritance to minimize karg size
77  // users need to call the MakeKargs() function to create kargs.
79  {
83  };
84 
86  {
90  };
91 
97 
99  {
100  const void* q_ptr;
101  const void* k_ptr;
102  const void* v_ptr;
103  void* o_ptr;
104  const void* sink_ptr;
105 
110 
112  // for MQA/GQA, nhead_q and nhead_k can differ. This parameter is nhead_q / nhead_k
113  // if this parameter is larger than 1, it indicates the MQA/GQA case
115 
119 
120  float scale_s;
121 
126 
131  };
132 
134  {
136 
137  void init_logits_soft_cap(float logits_soft_cap_)
138  {
139  if(0 < logits_soft_cap_)
140  {
141  logits_soft_cap = logits_soft_cap_;
143  }
144  else
145  {
146  logits_soft_cap = 0.f;
147  logits_soft_cap_rcp = 0.f;
148  }
149  }
150 
153  };
154 
156  {
157  const void* bias_ptr = nullptr;
160  };
161 
163  {
165  };
166 
168  {
169  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
170  const void* alibi_slope_ptr;
171  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
172  };
173 
175  {
176  // ck_tile::index_t window_size_left, window_size_right;
179  };
180 
182  {
183  void* lse_ptr = nullptr;
186  };
187 
189  {
190  const void* q_descale_ptr = nullptr;
191  const void* k_descale_ptr = nullptr;
192  const void* v_descale_ptr = nullptr;
193  };
194 
196  {
197  template <typename T>
199  {
200  T val;
201  const T* ptr;
202  };
203 
207  };
208 
210  {
211  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
212  {
213  float p_undrop = 1.0 - p_drop;
216  rp_undrop = 1.0 / p_undrop;
217 
218  this->drop_seed.val = seed;
219  this->drop_offset.val = offset;
220  this->is_drop_seed_offset_from_host = true;
221  }
222 
223  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
224  {
225  float p_undrop = 1.0 - p_drop;
228  rp_undrop = 1.0 / p_undrop;
229 
230  this->drop_seed.ptr = seed_ptr;
231  this->drop_offset.ptr = offset_ptr;
232  this->is_drop_seed_offset_from_host = false;
233  }
234 
235  float rp_undrop = 1;
237  bool is_store_randval = false;
238  void* rand_val_ptr = nullptr;
239 
242  };
243 
245  {
247  };
248 
251  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
252  FmhaFwdBatchModeBiasKargs,
253  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
254  FmhaFwdAlibiKargs,
255  FmhaFwdEmptyKargs<0>>>,
256  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
257  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
258  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
259  FmhaFwdCommonQScaleKargs,
260  FmhaFwdEmptyKargs<3>>,
261  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
262  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
263  {
268  };
269 
272  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
273  FmhaFwdCommonBiasKargs,
274  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
275  FmhaFwdAlibiKargs,
276  FmhaFwdEmptyKargs<0>>>,
277  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
278  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
279  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
280  FmhaFwdCommonQScaleKargs,
281  FmhaFwdEmptyKargs<3>>,
282  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
283  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
284  {
288  };
289 
290  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
291 
293  {
297  };
298 
299  template <bool Cond = !kIsGroupMode>
300  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
301  MakeKargs(const void* q_ptr,
302  const void* k_ptr,
303  const void* v_ptr,
304  const void* bias_ptr,
305  const void* q_descale_ptr,
306  const void* k_descale_ptr,
307  const void* v_descale_ptr,
308  void* rand_val_ptr,
309  void* lse_ptr,
310  void* o_ptr,
311  ck_tile::index_t seqlen_q,
312  ck_tile::index_t hdim_q,
313  ck_tile::index_t hdim_v,
314  ck_tile::index_t num_head_q,
315  ck_tile::index_t nhead_ratio_qk,
316  int32_t num_total_pages,
317  ck_tile::index_t page_block_size,
318  const PageBlockTableKargs& page_table,
319  float scale_s,
320  [[maybe_unused]] float scale_p,
321  [[maybe_unused]] float scale_o,
322  float logits_soft_cap,
323  ck_tile::index_t stride_q,
324  ck_tile::index_t stride_k,
325  ck_tile::index_t stride_v,
326  ck_tile::index_t stride_bias,
327  ck_tile::index_t stride_randval,
328  ck_tile::index_t stride_o,
329  ck_tile::index_t nhead_stride_q,
330  ck_tile::index_t nhead_stride_k,
331  ck_tile::index_t nhead_stride_v,
332  ck_tile::index_t nhead_stride_bias,
333  ck_tile::index_t nhead_stride_randval,
334  ck_tile::index_t nhead_stride_lse,
335  ck_tile::index_t nhead_stride_o,
336  ck_tile::index_t batch_stride_q,
337  ck_tile::index_t batch_stride_k,
338  ck_tile::index_t batch_stride_v,
339  ck_tile::index_t batch_stride_bias,
340  ck_tile::index_t batch_stride_randval,
341  ck_tile::index_t batch_stride_lse,
342  ck_tile::index_t batch_stride_o,
343  ck_tile::index_t window_size_left,
344  ck_tile::index_t window_size_right,
345  ck_tile::index_t sink_size,
346  ck_tile::index_t mask_type,
347  float p_drop,
348  bool s_randval,
349  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
350  drop_seed_offset,
351  const void* sink_ptr = nullptr)
352  {
353  Kargs kargs{{q_ptr,
354  k_ptr,
355  v_ptr,
356  o_ptr,
357  sink_ptr,
358  seqlen_q,
359  -1,
360  hdim_q,
361  hdim_v,
362  num_head_q,
363  nhead_ratio_qk,
364  num_total_pages,
365  page_block_size,
366  page_table,
367 #if CK_TILE_FMHA_FWD_FAST_EXP2
368  static_cast<float>(scale_s * ck_tile::log2e_v<>),
369 #else
370  scale_s,
371 #endif
372  stride_q,
373  stride_k,
374  stride_v,
375  stride_o,
376  nhead_stride_q,
377  nhead_stride_k,
378  nhead_stride_v,
379  nhead_stride_o}, // args for common karg
380  {}, // placeholder for bias
381  {}, // placeholder for mask
382  {}, // placeholder for lse
383  {}, // placeholder for qscale
384  {}, // placeholder for dropout
385  {}, // placeholder for logits_soft_cap
386  batch_stride_q,
387  batch_stride_k,
388  batch_stride_v,
389  batch_stride_o};
390 
392  {
393  kargs.bias_ptr = bias_ptr;
394  kargs.stride_bias = stride_bias;
395  kargs.nhead_stride_bias = nhead_stride_bias;
396  kargs.batch_stride_bias = batch_stride_bias;
397  }
398  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
399  {
400  kargs.alibi_slope_ptr = bias_ptr;
401  kargs.alibi_slope_stride = stride_bias;
402  }
403  if constexpr(kHasMask)
404  {
405  kargs.window_size_left = window_size_left;
406  kargs.window_size_right = window_size_right;
407  kargs.sink_size = sink_size;
408  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
409  }
410  if constexpr(kStoreLSE)
411  {
412  kargs.lse_ptr = lse_ptr;
413  kargs.nhead_stride_lse = nhead_stride_lse;
414  kargs.batch_stride_lse = batch_stride_lse;
415  }
417  {
418  kargs.q_descale_ptr = q_descale_ptr;
419  kargs.k_descale_ptr = k_descale_ptr;
420  kargs.v_descale_ptr = v_descale_ptr;
421  }
422  if constexpr(kHasDropout)
423  {
424  if(drop_seed_offset.index() == 0) // seed & offset come from host
425  {
426  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
427  kargs.init_dropout(p_drop, seed, offset);
428  }
429  else // seed & offset come from device
430  {
431  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
432  kargs.init_dropout(p_drop,
433  reinterpret_cast<const uint64_t*>(seed_ptr),
434  reinterpret_cast<const uint64_t*>(offset_ptr));
435  }
436 
437  kargs.rand_val_ptr = rand_val_ptr;
438  kargs.stride_randval = stride_randval;
439  kargs.nhead_stride_randval = nhead_stride_randval;
440  kargs.batch_stride_randval = batch_stride_randval;
441  kargs.is_store_randval = s_randval;
442  }
443  if constexpr(kHasLogitsSoftCap)
444  {
445  kargs.init_logits_soft_cap(logits_soft_cap);
446  }
447 
448  return kargs;
449  }
450 
451  template <bool Cond = kIsGroupMode>
452  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
453  MakeKargs(const void* q_ptr,
454  const void* k_ptr,
455  const void* v_ptr,
456  const void* bias_ptr,
457  const void* q_descale_ptr,
458  const void* k_descale_ptr,
459  const void* v_descale_ptr,
460  void* rand_val_ptr,
461  void* lse_ptr,
462  void* o_ptr,
463  const void* seqstart_q_ptr,
464  ck_tile::index_t hdim_q,
465  ck_tile::index_t hdim_v,
466  ck_tile::index_t num_head_q,
467  ck_tile::index_t nhead_ratio_qk,
468  int32_t num_total_pages,
469  ck_tile::index_t page_block_size,
470  const PageBlockTableKargs& page_table,
471  float scale_s,
472  [[maybe_unused]] float scale_p,
473  [[maybe_unused]] float scale_o,
474  float logits_soft_cap,
475  ck_tile::index_t stride_q,
476  ck_tile::index_t stride_k,
477  ck_tile::index_t stride_v,
478  ck_tile::index_t stride_bias,
479  ck_tile::index_t stride_randval,
480  ck_tile::index_t stride_o,
481  ck_tile::index_t nhead_stride_q,
482  ck_tile::index_t nhead_stride_k,
483  ck_tile::index_t nhead_stride_v,
484  ck_tile::index_t nhead_stride_bias,
485  ck_tile::index_t nhead_stride_randval,
486  ck_tile::index_t nhead_stride_lse,
487  ck_tile::index_t nhead_stride_o,
488  ck_tile::index_t batch_stride_k,
489  ck_tile::index_t batch_stride_v,
490  ck_tile::index_t window_size_left,
491  ck_tile::index_t window_size_right,
492  ck_tile::index_t sink_size,
493  ck_tile::index_t mask_type,
494  float p_drop,
495  bool s_randval,
496  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
497  drop_seed_offset,
498  const void* sink_ptr = nullptr)
499  {
500  Kargs kargs{{q_ptr,
501  k_ptr,
502  v_ptr,
503  o_ptr,
504  sink_ptr,
505  -1, // seqlen will be updated by another pointer
506  -1, //
507  hdim_q,
508  hdim_v,
509  num_head_q,
510  nhead_ratio_qk,
511  num_total_pages,
512  page_block_size,
513  page_table,
514 #if CK_TILE_FMHA_FWD_FAST_EXP2
515  static_cast<float>(scale_s * ck_tile::log2e_v<>),
516 #else
517  scale_s,
518 #endif
519  stride_q,
520  stride_k,
521  stride_v,
522  stride_o,
523  nhead_stride_q,
524  nhead_stride_k,
525  nhead_stride_v,
526  nhead_stride_o}, // args for common karg
527  {}, // placeholder for bias
528  {}, // placeholder for mask
529  {}, // placeholder for lse
530  {}, // placeholder for qscale
531  {}, // placeholder for dropout
532  {}, // placeholder for logits_soft_cap
533  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
534  batch_stride_k,
535  batch_stride_v};
536 
538  {
539  kargs.bias_ptr = bias_ptr;
540  kargs.stride_bias = stride_bias;
541  kargs.nhead_stride_bias = nhead_stride_bias;
542  }
543  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
544  {
545  kargs.alibi_slope_ptr = bias_ptr;
546  kargs.alibi_slope_stride = stride_bias;
547  }
548  if constexpr(kHasMask)
549  {
550  kargs.window_size_left = window_size_left;
551  kargs.window_size_right = window_size_right;
552  kargs.sink_size = sink_size;
553  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
554  }
555  if constexpr(kStoreLSE)
556  {
557  kargs.lse_ptr = lse_ptr;
558  kargs.nhead_stride_lse = nhead_stride_lse;
559  }
561  {
562  kargs.q_descale_ptr = q_descale_ptr;
563  kargs.k_descale_ptr = k_descale_ptr;
564  kargs.v_descale_ptr = v_descale_ptr;
565  }
566  if constexpr(kHasDropout)
567  {
568  if(drop_seed_offset.index() == 0) // seed & offset come from host
569  {
570  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
571  kargs.init_dropout(p_drop, seed, offset);
572  }
573  else // seed & offset come from device
574  {
575  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
576  kargs.init_dropout(p_drop,
577  reinterpret_cast<const uint64_t*>(seed_ptr),
578  reinterpret_cast<const uint64_t*>(offset_ptr));
579  }
580 
581  kargs.rand_val_ptr = rand_val_ptr;
582  kargs.stride_randval = stride_randval;
583  kargs.nhead_stride_randval = nhead_stride_randval;
584  kargs.is_store_randval = s_randval;
585  }
586  if constexpr(kHasLogitsSoftCap)
587  {
588  kargs.init_logits_soft_cap(logits_soft_cap);
589  }
590 
591  return kargs;
592  }
593 
594  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
595  ck_tile::index_t nhead_,
596  ck_tile::index_t seqlen_q_,
597  ck_tile::index_t hdim_v_)
598  {
599  if constexpr(kIsGroupMode)
600  {
601  // TODO: this may need tuning
602  return dim3(nhead_,
603  batch_size_,
604  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
605  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
606  }
607  else
608  {
609  // TODO: this may need tuning
610  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
611  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
612  nhead_,
613  batch_size_);
614  }
615  }
616 
617  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
618  {
619  if constexpr(kIsGroupMode)
620  {
621  // const index_t num_tile_m0 = seqlen_q / kM0;
622  const index_t num_tile_n1 =
623  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
624 
625  const index_t i_block = blockIdx.z;
626  const index_t i_nhead = blockIdx.x;
627  const index_t i_batch = blockIdx.y;
628 
629  const auto f = [](index_t dividend, index_t divisor) {
630  index_t quotient = dividend / divisor;
631  index_t modulus = dividend - quotient * divisor;
632  return ck_tile::make_tuple(quotient, modulus);
633  };
634 
635  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
636  if constexpr(kHasMask)
637  {
638  // assume that num_tile_n1 is always 1
639  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
640  }
641  else
642  {
643  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
644  }
645  }
646  else
647  {
648  // const index_t num_tile_m0 = seqlen_q / kM0;
649  const index_t num_tile_n1 =
650  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
651 
652  const index_t i_block = blockIdx.x;
653  const index_t i_nhead = blockIdx.y;
654  const index_t i_batch = blockIdx.z;
655 
656  const auto f = [](index_t dividend, index_t divisor) {
657  index_t quotient = dividend / divisor;
658  index_t modulus = dividend - quotient * divisor;
659  return ck_tile::make_tuple(quotient, modulus);
660  };
661 
662  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
663 
664  if constexpr(kHasMask)
665  {
666  // assume that num_tile_n1 is always 1
667  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
668  }
669  else
670  {
671  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
672  }
673  }
674  }
675 
676  CK_TILE_HOST static dim3 BlockSize()
677  {
678  if(is_wave32())
679  {
680  return dim3(kBlockSize / 2);
681  }
682  else
683  {
684  return dim3(kBlockSize);
685  }
686  }
687 
689  {
690  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
691  }
692 
693  CK_TILE_DEVICE void operator()(Kargs kargs) const
694  {
695  // allocate LDS
696  __shared__ char smem_ptr[GetSmemSize()];
697 
698  // divide problem
699  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
700 
701  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
702  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
703 
704  long_index_t batch_offset_q = 0;
705  long_index_t batch_offset_bias = 0;
706  long_index_t batch_offset_randval = 0;
707  long_index_t batch_offset_lse = 0;
708  long_index_t batch_offset_o = 0;
709  const float sink_value =
710  kargs.sink_ptr != nullptr
711  ? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
713  const index_t seqlen_k = [&]() {
714  if constexpr(kKVLookupTable ==
716  {
717  const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
718  const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
719  const int32_t num_page_blocks = page_end - page_start;
720  const int32_t last_page_len = [&]() {
721  if constexpr(kPageBlockSize == 1)
722  return static_cast<int32_t>(kPageBlockSize);
723  else
724  return kargs.page_table.kv_last_page_lens[i_batch];
725  }();
726  return num_page_blocks > 0
727  ? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
728  last_page_len)
729  : 0;
730  }
731  else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
732  {
733  if(kargs.page_table.seqlen_k_ptr != nullptr)
734  return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
735  else
736  return kargs.seqlen_k;
737  }
738  }();
739  const int32_t* page_idx = [&]() {
740  if constexpr(kKVLookupTable ==
742  {
743  return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
744  }
745  else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
746  {
747  return kargs.page_table.block_table_ptr +
748  static_cast<long_index_t>(i_batch) *
749  kargs.page_table.batch_stride_block_table;
750  }
751  }();
752 
753  if constexpr(kIsGroupMode)
754  {
755  // get starting offset for each batch
756  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
757 
758  batch_offset_q = query_start * kargs.stride_q;
759 
761  {
762  batch_offset_bias = query_start * kargs.stride_bias;
763  }
764  if constexpr(kStoreLSE)
765  {
766  batch_offset_lse = query_start;
767  }
768  if constexpr(kHasDropout)
769  {
770  batch_offset_randval = query_start * kargs.stride_randval;
771  }
772  batch_offset_o = query_start * kargs.stride_o;
773 
774  // get real # queries & # keys under group mode
775  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
776 
777  // the # of required blocks differs between groups, so terminate unnecessary blocks
778  // early
779  if(kargs.seqlen_q <= i_m0)
780  {
781  return;
782  }
783 
784  kargs.seqlen_k = seqlen_k;
785  }
786  else
787  {
788  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
789 
791  {
792  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
793  }
794  if constexpr(kStoreLSE)
795  {
796  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
797  }
798  if constexpr(kHasDropout)
799  {
800  batch_offset_randval =
801  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
802  }
803  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
804 
805  kargs.seqlen_k = seqlen_k;
806  }
807 
808  // for simplicity, batch offsets are applied by modifying the pointers directly
809  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
810  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
811  batch_offset_q;
812  const KDataType* k_ptr =
813  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
814  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
815  const VDataType* v_ptr =
816  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
817  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
818  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
819  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
820  batch_offset_o;
821 
822  // Q/K/V DRAM and DRAM window
823  const auto q_dram = [&]() {
824  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
825  q_ptr,
826  make_tuple(kargs.seqlen_q, kargs.hdim_q),
827  make_tuple(kargs.stride_q, 1),
829  number<1>{});
830  if constexpr(FmhaPipeline::kQLoadOnce)
831  {
832  return pad_tensor_view(
833  q_dram_naive,
836  }
837  else
838  {
839  return pad_tensor_view(
840  q_dram_naive,
843  }
844  }();
845  const auto k_dram = [&]() {
846  if constexpr(kKVMemoryLayout ==
848  {
849  // Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
850  // Logical View for Pipeline: (TotalSeqK, D)
851 
852  // Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
853  // PageBlockSize, kVectorSize)
854  // Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
855  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
856  k_ptr,
857  make_tuple(kargs.num_total_pages,
858  kargs.hdim_q / kVectorSize,
859  kargs.page_block_size,
860  kVectorSize),
861  make_tuple(
862  kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
864  number<1>{});
865 
866  // Merge to (TotalSeqK, D) in a single transform:
867  // physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
868  auto k_dram_2d = transform_tensor_view(
869  k_dram_naive,
870  make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
871  kargs.page_block_size)), // TotalSeqK
873  make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
874  static_cast<int32_t>(kVectorSize)))), // D
877 
878  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
879  return pad_tensor_view(
880  k_dram_2d,
883  }
884  else
885  {
886  // Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
887  // Logical View for Pipeline: (TotalSeqK, D)
888  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
889  k_ptr,
890  make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
891  make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
893  number<1>{});
894 
895  // Merge to (TotalSeqK, D) in a single transform:
896  // physical (Page, S, D) -> logical (TotalSeqK, D)
897  auto k_dram_2d = transform_tensor_view(
898  k_dram_naive,
900  make_tuple(kargs.num_total_pages, kargs.page_block_size)),
901  make_pass_through_transform(kargs.hdim_q)),
904 
905  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
906  return pad_tensor_view(
907  k_dram_2d,
910  }
911  }();
912  const auto v_dram = [&]() {
913  if constexpr(kKVMemoryLayout ==
915  {
916  // Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
917  // Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
918 
919  // Define the naive physical view with 4D shape: (NumPages,
920  // PageBlockSize/kVectorSize, HeadDim, kVectorSize)
921  // Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
922  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
923  v_ptr,
924  make_tuple(kargs.num_total_pages,
925  kargs.page_block_size / kVectorSize,
926  kargs.hdim_v,
927  kVectorSize),
928  make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
930  number<1>{});
931 
932  // Merge to (D, TotalSeqK) in a single transform:
933  // physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
934  auto v_dram_final = transform_tensor_view(
935  v_dram_naive,
936  make_tuple(make_pass_through_transform(kargs.hdim_v), // D
937  make_merge_transform(make_tuple(kargs.num_total_pages,
938  kargs.page_block_size / kVectorSize,
939  kVectorSize))), // TotalSeqK
942 
943  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
944  return pad_tensor_view(
945  v_dram_final,
948  }
949  else
950  {
951  // Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
952  // Logical View for Pipeline: (D, TotalSeqK)
953  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
954  v_ptr,
955  make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
956  make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
958  number<1>{});
959 
960  // Merge to (D, TotalSeqK) in a single transform:
961  // physical (Page, S, D) -> logical (D, TotalSeqK)
962  auto v_dram_final = transform_tensor_view(
963  v_dram_naive,
966  make_tuple(kargs.num_total_pages, kargs.page_block_size))),
969 
970  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
971  return pad_tensor_view(
972  v_dram_final,
975  }
976  }();
977  auto q_dram_window = make_tile_window(
978  q_dram,
979  [&]() {
980  if constexpr(FmhaPipeline::kQLoadOnce)
983  else
985  }(),
986  {i_m0, 0});
987 
988  auto k_dram_window = make_tile_window(
990 
991  auto v_dram_window =
992  make_tile_window(v_dram,
994  {i_n1, 0});
997  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
998  constexpr auto bias_dram_window_lengths =
1001  {
1002  const BiasDataType* bias_ptr =
1003  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1004  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1005  batch_offset_bias;
1006 
1007  const auto bias_dram = [&]() {
1008  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1009  bias_ptr,
1010  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1011  make_tuple(kargs.stride_bias, 1),
1013  number<1>{});
1014 
1015  return pad_tensor_view(bias_dram_naive,
1016  bias_dram_window_lengths,
1018  }();
1019 
1020  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1021  }
1022  else
1023  {
1024  return make_null_tile_window(bias_dram_window_lengths);
1025  }
1026  }();
1027 
1028  // lse
1029  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1030  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1031  if constexpr(kStoreLSE)
1032  {
1033  LSEDataType* lse_ptr =
1034  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1035  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1036 
1037  const auto lse_dram = [&]() {
1038  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1039  lse_ptr,
1040  make_tuple(kargs.seqlen_q),
1041  make_tuple(1),
1042  number<1>{},
1043  number<1>{});
1044 
1045  return pad_tensor_view(
1046  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1047  }();
1048 
1049  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1050  }
1051  else
1052  {
1053  return make_null_tile_window(lse_dram_window_lengths);
1054  }
1055  }();
1056 
1057  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1058  if constexpr(kHasDropout)
1059  {
1060  return BlockDropout{i_batch_,
1061  i_nhead_,
1062  kargs.num_head_q,
1063  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1064  : *kargs.drop_seed.ptr,
1065  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1066  : *kargs.drop_offset.ptr,
1067  kargs.rp_undrop,
1068  kargs.p_undrop_in_uint8_t,
1069  kargs.is_store_randval};
1070  }
1071  else
1072  {
1073  return NullBlockDropout{};
1074  };
1075  }();
1076 
1077  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1078  constexpr auto randval_dram_window_lengths =
1080  if constexpr(kHasDropout)
1081  {
1082  RandValOutputDataType* rand_val_ptr =
1083  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1084  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1085  batch_offset_randval;
1086 
1087  const auto randval_dram = [&]() {
1088  const auto randval_dram_naive =
1089  make_naive_tensor_view<address_space_enum::global>(
1090  rand_val_ptr,
1091  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1092  make_tuple(kargs.stride_randval, 1),
1094  number<1>{});
1095 
1096  return pad_tensor_view(randval_dram_naive,
1097  randval_dram_window_lengths,
1099  }();
1100 
1101  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1102  }
1103  else
1104  {
1105  return make_null_tile_window(randval_dram_window_lengths);
1106  }
1107  }();
1108 
1109  FmhaMask mask = [&]() {
1110  if constexpr(kHasMask)
1111  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1112  kargs.window_size_left,
1113  kargs.window_size_right,
1114  kargs.sink_size,
1115  kargs.seqlen_q,
1116  kargs.seqlen_k,
1118  else
1119  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1120  }();
1121 
1122  // WA: structured bindings cannot be captured by a lambda before C++20, so i_batch/i_nhead are copied via init-captures
1123  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1124  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1125  {
1126  // data loading, shared by entire wg
1127  // TODO: how to use s_read?
1128  SaccDataType slope =
1129  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1130  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1131 #if CK_TILE_FMHA_FWD_FAST_EXP2
1132  slope *= ck_tile::log2e_v<>;
1133 #endif
1134  if constexpr(kHasMask)
1135  {
1136  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1137  kargs.window_size_left,
1138  kargs.window_size_right,
1139  kargs.seqlen_q,
1140  kargs.seqlen_k,
1141  kargs.mask_type);
1142  }
1143  else
1144  {
1146  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1147  }
1148  }
1149  else
1150  {
1152  }
1153  }();
1154 
1155  AttentionVariant variant;
1156  const auto variant_params = [&] {
1157  const float scale_s = [&] {
1159  {
1160  float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
1161  float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
1162 
1163  return kargs.scale_s * q_descale * k_descale;
1164  }
1165  else
1166  {
1167  return kargs.scale_s;
1168  }
1169  }();
1170 
1171  if constexpr(kHasLogitsSoftCap)
1172  {
1174  mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1175  }
1176  else
1177  {
1178  return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
1179  }
1180  }();
1181 
1182  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1183 
1184  const index_t stride_k_for_pipeline =
1186  ? kVectorSize
1187  : kargs.stride_k;
1188  const index_t stride_v_for_pipeline =
1190  ? kargs.hdim_v
1191  : kargs.stride_v;
1192 
1193  auto o_acc_tile = [&] {
1195  {
1196  // TODO - move global load of descale to pipeline
1197  float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
1198 
1199  float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
1200  float scale_o = v_descale / scale_p;
1201 
1202  auto o_acc_element_func = [&]() {
1203  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1205  scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
1206  else
1207  return scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
1208  }();
1209 
1210  return FmhaPipeline{}(
1211  q_dram_window,
1212  identity{}, // q_element_func
1213  k_dram_window,
1214  identity{}, // k_element_func
1215  v_dram_window,
1216  identity{}, // v_element_func
1217  bias_dram_window,
1218  identity{}, // bias_element_func
1219  randval_dram_window,
1220  lse_dram_window,
1221  identity{}, // lse_element_func
1222  identity{}, // s_acc_element_func
1223  scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
1224  o_acc_element_func, // o_acc_element_func
1225  mask,
1226  position_encoding,
1227  variant_params.sm_scale,
1228  variant,
1229  variant_params,
1230  block_indices,
1231  smem_ptr,
1232  page_idx,
1233  stride_k_for_pipeline,
1234  stride_v_for_pipeline,
1235  kargs.batch_stride_k,
1236  kargs.batch_stride_v,
1237  dropout,
1238  sink_value);
1239  }
1240  else
1241  {
1242  return FmhaPipeline{}(q_dram_window,
1243  k_dram_window,
1244  v_dram_window,
1245  bias_dram_window,
1246  randval_dram_window,
1247  lse_dram_window,
1248  mask,
1249  position_encoding,
1250  variant_params.sm_scale,
1251  variant,
1252  variant_params,
1253  block_indices,
1254  smem_ptr,
1255  page_idx,
1256  stride_k_for_pipeline,
1257  stride_v_for_pipeline,
1258  kargs.batch_stride_k,
1259  kargs.batch_stride_v,
1260  dropout,
1261  sink_value);
1262  }
1263  }();
1264 
1265  // O DRAM and O DRAM window
1266  auto o_dram = [&]() {
1267  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1268  o_ptr,
1269  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1270  make_tuple(kargs.stride_o, 1),
1272  number<1>{});
1273 
1274  return pad_tensor_view(
1275  o_dram_naive,
1278  }();
1279 
1280  auto o_dram_window =
1281  make_tile_window(o_dram,
1283  {i_m0, i_n1});
1284 
1285  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1286  }
1287 };
1288 
1289 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_composes(Ts &&... ts)
Definition: unary_element_function.hpp:51
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1691
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_dropout.hpp:53
const float rp_undrop
Definition: block_dropout.hpp:371
Definition: block_position_encoding.hpp:137
Definition: fmha_batch_prefill_kernel.hpp:293
ck_tile::index_t kv_head_idx
Definition: fmha_batch_prefill_kernel.hpp:296
ck_tile::index_t qo_head_idx
Definition: fmha_batch_prefill_kernel.hpp:295
ck_tile::index_t batch_idx
Definition: fmha_batch_prefill_kernel.hpp:294
Definition: fmha_batch_prefill_kernel.hpp:168
ck_tile::index_t alibi_slope_stride
Definition: fmha_batch_prefill_kernel.hpp:171
const void * alibi_slope_ptr
Definition: fmha_batch_prefill_kernel.hpp:170
ck_tile::index_t batch_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:164
ck_tile::index_t batch_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:246
ck_tile::index_t batch_stride_o
Definition: fmha_batch_prefill_kernel.hpp:267
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:266
ck_tile::index_t batch_stride_q
Definition: fmha_batch_prefill_kernel.hpp:264
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:265
ck_tile::index_t nhead_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:159
ck_tile::index_t stride_bias
Definition: fmha_batch_prefill_kernel.hpp:158
const void * bias_ptr
Definition: fmha_batch_prefill_kernel.hpp:157
ck_tile::index_t stride_randval
Definition: fmha_batch_prefill_kernel.hpp:240
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_batch_prefill_kernel.hpp:223
ck_tile::index_t nhead_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:241
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_batch_prefill_kernel.hpp:211
void * rand_val_ptr
Definition: fmha_batch_prefill_kernel.hpp:238
float rp_undrop
Definition: fmha_batch_prefill_kernel.hpp:235
bool is_store_randval
Definition: fmha_batch_prefill_kernel.hpp:237
uint8_t p_undrop_in_uint8_t
Definition: fmha_batch_prefill_kernel.hpp:236
ck_tile::index_t page_block_size
Definition: fmha_batch_prefill_kernel.hpp:117
ck_tile::index_t stride_q
Definition: fmha_batch_prefill_kernel.hpp:122
ck_tile::index_t stride_v
Definition: fmha_batch_prefill_kernel.hpp:124
int32_t num_total_pages
Definition: fmha_batch_prefill_kernel.hpp:116
float scale_s
Definition: fmha_batch_prefill_kernel.hpp:120
PageBlockTableKargs page_table
Definition: fmha_batch_prefill_kernel.hpp:118
ck_tile::index_t seqlen_q
Definition: fmha_batch_prefill_kernel.hpp:106
ck_tile::index_t stride_k
Definition: fmha_batch_prefill_kernel.hpp:123
ck_tile::index_t nhead_stride_o
Definition: fmha_batch_prefill_kernel.hpp:130
ck_tile::index_t nhead_stride_k
Definition: fmha_batch_prefill_kernel.hpp:128
ck_tile::index_t nhead_ratio_qk
Definition: fmha_batch_prefill_kernel.hpp:114
ck_tile::index_t nhead_stride_v
Definition: fmha_batch_prefill_kernel.hpp:129
ck_tile::index_t nhead_stride_q
Definition: fmha_batch_prefill_kernel.hpp:127
const void * sink_ptr
Definition: fmha_batch_prefill_kernel.hpp:104
const void * v_ptr
Definition: fmha_batch_prefill_kernel.hpp:102
void * o_ptr
Definition: fmha_batch_prefill_kernel.hpp:103
ck_tile::index_t seqlen_k
Definition: fmha_batch_prefill_kernel.hpp:107
ck_tile::index_t stride_o
Definition: fmha_batch_prefill_kernel.hpp:125
ck_tile::index_t hdim_v
Definition: fmha_batch_prefill_kernel.hpp:109
ck_tile::index_t num_head_q
Definition: fmha_batch_prefill_kernel.hpp:111
const void * k_ptr
Definition: fmha_batch_prefill_kernel.hpp:101
ck_tile::index_t hdim_q
Definition: fmha_batch_prefill_kernel.hpp:108
const void * q_ptr
Definition: fmha_batch_prefill_kernel.hpp:100
ck_tile::index_t batch_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:185
ck_tile::index_t nhead_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:184
void * lse_ptr
Definition: fmha_batch_prefill_kernel.hpp:183
const void * v_descale_ptr
Definition: fmha_batch_prefill_kernel.hpp:192
const void * k_descale_ptr
Definition: fmha_batch_prefill_kernel.hpp:191
const void * q_descale_ptr
Definition: fmha_batch_prefill_kernel.hpp:190
bool is_drop_seed_offset_from_host
Definition: fmha_batch_prefill_kernel.hpp:206
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_batch_prefill_kernel.hpp:204
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_batch_prefill_kernel.hpp:205
Definition: fmha_batch_prefill_kernel.hpp:72
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:287
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:286
const int32_t * seqstart_q_ptr
Definition: fmha_batch_prefill_kernel.hpp:285
float logits_soft_cap_rcp
Definition: fmha_batch_prefill_kernel.hpp:152
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_batch_prefill_kernel.hpp:137
float logits_soft_cap
Definition: fmha_batch_prefill_kernel.hpp:151
Definition: fmha_batch_prefill_kernel.hpp:175
ck_tile::index_t sink_size
Definition: fmha_batch_prefill_kernel.hpp:177
ck_tile::index_t window_size_right
Definition: fmha_batch_prefill_kernel.hpp:177
ck_tile::index_t window_size_left
Definition: fmha_batch_prefill_kernel.hpp:177
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_batch_prefill_kernel.hpp:178
const int32_t * kv_page_indices
Definition: fmha_batch_prefill_kernel.hpp:81
const int32_t * kv_indptr
Definition: fmha_batch_prefill_kernel.hpp:80
const int32_t * kv_last_page_lens
Definition: fmha_batch_prefill_kernel.hpp:82
const int32_t * block_table_ptr
Definition: fmha_batch_prefill_kernel.hpp:87
const int32_t * seqlen_k_ptr
Definition: fmha_batch_prefill_kernel.hpp:89
ck_tile::index_t batch_stride_block_table
Definition: fmha_batch_prefill_kernel.hpp:88
Definition: fmha_batch_prefill_kernel.hpp:28
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_batch_prefill_kernel.hpp:617
static constexpr bool kIsGroupMode
Definition: fmha_batch_prefill_kernel.hpp:50
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_batch_prefill_kernel.hpp:33
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_batch_prefill_kernel.hpp:39
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_batch_prefill_kernel.hpp:29
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_batch_prefill_kernel.hpp:38
static constexpr bool kPadSeqLenQ
Definition: fmha_batch_prefill_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_batch_prefill_kernel.hpp:43
static constexpr bool kPadHeadDimV
Definition: fmha_batch_prefill_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_batch_prefill_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_batch_prefill_kernel.hpp:37
static constexpr auto kKVMemoryLayout
Definition: fmha_batch_prefill_kernel.hpp:60
static constexpr bool kHasMask
Definition: fmha_batch_prefill_kernel.hpp:66
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_batch_prefill_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_batch_prefill_kernel.hpp:65
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_batch_prefill_kernel.hpp:676
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, ck_tile::index_t page_block_size, const PageBlockTableKargs &page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, const void *sink_ptr=nullptr)
Definition: fmha_batch_prefill_kernel.hpp:453
static constexpr bool kPadSeqLenK
Definition: fmha_batch_prefill_kernel.hpp:52
static constexpr auto QScaleEnum
Definition: fmha_batch_prefill_kernel.hpp:59
static constexpr bool kHasLogitsSoftCap
Definition: fmha_batch_prefill_kernel.hpp:55
static constexpr bool kHasDropout
Definition: fmha_batch_prefill_kernel.hpp:58
static constexpr bool kStoreLSE
Definition: fmha_batch_prefill_kernel.hpp:57
static constexpr auto kKVLookupTable
Definition: fmha_batch_prefill_kernel.hpp:61
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_batch_prefill_kernel.hpp:46
static constexpr index_t kPageBlockSize
Definition: fmha_batch_prefill_kernel.hpp:62
ck_tile::remove_cvref_t< typename FmhaPipeline::PDataType > PDataType
Definition: fmha_batch_prefill_kernel.hpp:40
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_batch_prefill_kernel.hpp:688
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, ck_tile::index_t page_block_size, const PageBlockTableKargs &page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, const void *sink_ptr=nullptr)
Definition: fmha_batch_prefill_kernel.hpp:301
std::conditional_t< kKVLookupTable==BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, SglangPageTableKargs, VllmPageTableKargs > PageBlockTableKargs
Definition: fmha_batch_prefill_kernel.hpp:96
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_batch_prefill_kernel.hpp:31
static constexpr auto BiasEnum
Definition: fmha_batch_prefill_kernel.hpp:56
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_batch_prefill_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_batch_prefill_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_batch_prefill_kernel.hpp:64
static constexpr index_t kVectorSize
Definition: fmha_batch_prefill_kernel.hpp:63
static constexpr bool kUseAsyncCopy
Definition: fmha_batch_prefill_kernel.hpp:68
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_batch_prefill_kernel.hpp:35
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_batch_prefill_kernel.hpp:594
static constexpr bool kPadHeadDimQ
Definition: fmha_batch_prefill_kernel.hpp:53
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_batch_prefill_kernel.hpp:693
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_batch_prefill_kernel.hpp:290
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_batch_prefill_kernel.hpp:30
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1393
Definition: unary_element_function.hpp:58
Definition: math.hpp:28
Definition: sequence.hpp:49