/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp Source File
gemm_quant_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <string>
7 
8 #include "ck_tile/core.hpp"
14 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
19 namespace detail {
20 // Helper templates for safe type extraction
21 template <typename, typename Default, typename = void>
23 {
24  using type = Default;
25 };
26 
27 template <typename T, typename Default>
28 struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
29 {
30  using type = typename T::AQLayout;
31 };
32 
33 template <typename, typename Default, typename = void>
35 {
36  using type = Default;
37 };
38 
39 template <typename T, typename Default>
40 struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
41 {
42  using type = typename T::BQLayout;
43 };
44 
45 template <typename, typename Default, typename = void>
47 {
48  using type = Default;
49 };
50 
51 template <typename T, typename Default>
52 struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
53 {
54  using type = typename T::AQDataType;
55 };
56 
57 template <typename, typename Default, typename = void>
59 {
60  using type = Default;
61 };
62 
63 template <typename T, typename Default>
64 struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
65 {
66  using type = typename T::BQDataType;
67 };
68 
69 template <typename, typename = void>
71 {
72  static constexpr bool value = false;
73 };
74 
75 template <typename T>
76 struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
77 {
78  static constexpr bool value = T::PreshuffleQuant;
79 };
80 
81 template <typename, typename = void>
83 {
84  static constexpr bool value = false;
85 };
86 
87 template <typename T>
88 struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
89 {
90  static constexpr bool value = T::PreshuffleB;
91 };
92 } // namespace detail
93 
95 {
98  index_t N_,
99  index_t K_,
100  index_t QK_A_,
101  index_t QK_B_,
102  index_t stride_A_,
103  index_t stride_B_,
104  index_t stride_C_,
105  index_t stride_AQ_,
106  index_t stride_BQ_)
107  : M(M_),
108  N(N_),
109  K(K_),
110  QK_A(QK_A_),
111  QK_B(QK_B_),
112  stride_A(stride_A_),
113  stride_B(stride_B_),
114  stride_C(stride_C_),
115  stride_AQ(stride_AQ_),
116  stride_BQ(stride_BQ_)
117  {
118  }
119 
130 };
131 
133 {
135  CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
136  const void* b_ptr_,
137  void* c_ptr_,
138  const void* aq_ptr_,
139  const void* bq_ptr_,
140  index_t k_batch_,
141  index_t M_,
142  index_t N_,
143  index_t K_,
144  index_t QK_A_,
145  index_t QK_B_,
146  index_t stride_A_,
147  index_t stride_B_,
148  index_t stride_C_,
149  index_t stride_AQ_,
150  index_t stride_BQ_)
152  M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
153  a_ptr(a_ptr_),
154  b_ptr(b_ptr_),
155  aq_ptr(aq_ptr_),
156  bq_ptr(bq_ptr_),
157  c_ptr(c_ptr_),
158  k_batch(k_batch_)
159  {
160  }
161 
162  const void* a_ptr = nullptr;
163  const void* b_ptr = nullptr;
164  const void* aq_ptr = nullptr;
165  const void* bq_ptr = nullptr;
166  void* c_ptr = nullptr;
168 };
169 
171 {
172  const void* a_ptr;
173  const void* b_ptr;
174  const void* aq_ptr;
175  const void* bq_ptr;
176  void* c_ptr;
188 };
189 
// ============================================================================
// QuantGemmKernel: tile-based GEMM kernel with quantization support,
// parameterized by a tile partitioner, a GEMM pipeline, an epilogue pipeline
// and the quantization scheme (QuantType_).
// NOTE(review): this listing was extracted from a documentation page; the
// leading numbers on each line are extraction artifacts, and several lines
// were dropped (e.g. the "struct QuantGemmKernel" line after the template
// header and the right-hand sides of the AQDataType/BQDataType aliases and
// the PreshuffleQuant initializer). Recover them from the upstream header.
// ============================================================================
190 template <typename TilePartitioner_,
191  typename GemmPipeline_,
192  typename EpiloguePipeline_,
193  QuantType QuantType_>
195 {
202 
207 
208  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
209  static constexpr bool PreshuffleQuant =
212 
217 
218  using AQDataType =
220  using BQDataType =
222 
// Compile-time indices used when addressing the per-tensor tuples below.
223  static constexpr auto I0 = number<0>(); // A Tensor
224  static constexpr auto I1 = number<1>(); // AQ Tensor
225  static constexpr auto I2 = number<2>(); // B Tensor
226  static constexpr auto I3 = number<3>(); // BQ Tensor
227  static constexpr auto I4 = number<4>(); // C Tensor
228 
229  static constexpr auto kQuantType = QuantType_;
230 
231  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
232  {
233  // clang-format off
234  return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
235  // clang-format on
236  }
237 
238  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
239  {
240  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
241  }
242 
243  CK_TILE_HOST static auto BlockSize()
244  {
245  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
246  }
247 
// Convert the host-side argument pack into the POD kernel-argument struct
// (field order must match the members of QuantGemmKernelArgs).
// NOTE(review): the declarator line (function name + hostArgs parameter)
// between the return type and the opening brace was dropped by the
// documentation extraction — recover it from the upstream header.
248  CK_TILE_HOST static constexpr QuantGemmKernelArgs
250  {
251  return QuantGemmKernelArgs{hostArgs.a_ptr,
252  hostArgs.b_ptr,
253  hostArgs.aq_ptr,
254  hostArgs.bq_ptr,
255  hostArgs.c_ptr,
256  hostArgs.M,
257  hostArgs.N,
258  hostArgs.K,
259  hostArgs.QK_A,
260  hostArgs.QK_B,
261  hostArgs.stride_A,
262  hostArgs.stride_B,
263  hostArgs.stride_C,
264  hostArgs.stride_AQ,
265  hostArgs.stride_BQ,
266  hostArgs.k_batch};
267  }
268 
// Shared-memory requirement of the kernel: the GEMM main loop and the
// epilogue reuse the same LDS allocation, so the requirement is the maximum
// of the two pipelines' sizes.
// NOTE(review): the function declaration line preceding this brace was
// dropped by the documentation extraction.
270  {
271  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
272  }
273 
274  private:
275  CK_TILE_DEVICE static constexpr index_t get_padding_size(index_t length, index_t alignment)
276  {
277  return ck_tile::integer_least_multiple(length, alignment) - length;
278  };
279  // ===================================================================
280  // Helper: Create Pre-shuffled Quantization Tensor Descriptor
281  // ===================================================================
282  template <index_t KPerBlockBQ,
283  index_t NPerBlockBQ,
284  index_t NPerBlock,
285  index_t WarpTileN,
286  index_t GetVectorSizeBQ,
287  typename BQDataType_>
288  CK_TILE_DEVICE static auto
289  MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
290  {
291  // Step 1: Calculate base BQ tensor dimensions
292  // ----------------------------------------------------------
293  // bq_x: Number of quantization groups in N dimension
294  // = N * KPerBlockBQ, where KPerBlockBQ is the number of
295  // K-dimension groups per block
296  // bq_y: Number of quantization groups in K dimension
297  // = Total K groups (QK_B) / groups per block
298  const auto bq_x = N * KPerBlockBQ;
299  const auto bq_y = QK_B / KPerBlockBQ;
300 
301  const auto bq_desc = make_naive_tensor_descriptor(
302  make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), number<GetVectorSizeBQ>{}, number<1>{});
303 
304  // Step 2: First padding transformation (block-level alignment)
305  // ----------------------------------------------------------
306  // Pad the X dimension to be a multiple of block_tile_size to ensure
307  // each thread block can process complete tiles without edge cases
308  const auto block_tile_size = NPerBlockBQ * KPerBlockBQ;
309 
310  const auto bq_pad0_desc = transform_tensor_descriptor(
311  bq_desc,
313  make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
314  make_tuple(sequence<0>{}, sequence<1>{}),
315  make_tuple(sequence<0>{}, sequence<1>{}));
316 
317  // Step 3: Unmerge transformation (wave-level decomposition)
318  // ----------------------------------------------------------
319  // Split the X dimension into [wave_tile_count_x, wave_tile_size]
320  // This separates the work into tiles that can be processed by
321  // individual warps/waves
322  const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
323  const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ;
324  const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
325 
326  const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
327  bq_pad0_desc,
329  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
330  make_tuple(sequence<0>{}, sequence<1>{}),
331  make_tuple(sequence<0>{}, sequence<1, 2>{}));
332 
333  // Step 4: Second padding transformation (warp-level alignment)
334  // ----------------------------------------------------------
335  // Pad wave_tile_size to be a multiple of warp_size (typically 32 or 64)
336  // This ensures coalesced memory accesses within each warp
337  const auto bq_pad1_desc = transform_tensor_descriptor(
338  bq_unmerge_pad0_desc,
340  make_pass_through_transform(wave_tile_count_x),
341  make_right_pad_transform(wave_tile_size,
342  get_padding_size(wave_tile_size, get_warp_size()))),
343  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
344  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
345 
346  // Step 5: Final merge transformation (prepare for indexing)
347  // ----------------------------------------------------------
348  // Merge [bq_y, wave_tile_count_x] into a single outer dimension
349  // This creates a 2D layout: [merged_outer_dim, pad_wave_size]
350  // where merged_outer_dim = bq_y * wave_tile_count_x
351  // This layout facilitates efficient block-to-data mapping
352  const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
353  const auto bq_merge_pad1_desc = transform_tensor_descriptor(
354  bq_pad1_desc,
355  make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)),
356  make_pass_through_transform(pad_wave_size)),
357  make_tuple(sequence<0, 1>{}, sequence<2>{}),
358  make_tuple(sequence<0>{}, sequence<1>{}));
359 
360  return make_tensor_view<address_space_enum::global>(bq_ptr, bq_merge_pad1_desc);
361  }
362 
363  public:
365  {
366  __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
367  const std::size_t k_id = blockIdx.z)
368  {
369  constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
370  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
371  const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
372 
373  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
374  {
376  }
377  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
378  {
379  a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A);
380  }
381 
382  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
383  {
384  b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
385  }
386  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
387  {
389  }
390 
391  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
392  {
394  }
395  else
396  {
397  splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
398  }
399  }
400 
404  };
405 
406  CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
407  const QuantGemmKernelArgs& kargs,
408  const index_t k_size,
409  const index_t i_m)
410  {
411  // Step 1: Create tensor view for A
412  const auto& a_tensor_view = [&]() {
413  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
414  {
415  return make_naive_tensor_view<address_space_enum::global>(
416  a_ptr,
417  make_tuple(kargs.M, k_size),
418  make_tuple(kargs.stride_A, 1),
419  number<GemmPipeline::GetVectorSizeA()>{},
420  number<1>{});
421  }
422  else
423  {
424  return make_naive_tensor_view<address_space_enum::global>(
425  a_ptr,
426  make_tuple(k_size, kargs.M),
427  make_tuple(kargs.stride_A, 1),
428  number<GemmPipeline::GetVectorSizeA()>{},
429  number<1>{});
430  }
431  }();
432 
433  // Step 2: Create padded view
434  const auto& a_pad_view = [&]() {
435  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
436  {
437  return pad_tensor_view(a_tensor_view,
441  }
442  else
443  {
444  return pad_tensor_view(a_tensor_view,
448  }
449  }();
450 
451  // Step 3: Create tile window
452  const auto& a_block_window = [&]() {
453  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
454  {
455  return make_tile_window(a_pad_view,
458  {i_m, 0});
459  }
460  else
461  {
462  return make_tile_window(a_pad_view,
465  {0, i_m});
466  }
467  }();
468 
469  return a_block_window;
470  }
471 
// Build the tile window over the AQ (A quantization-scale) tensor. The view
// depends on the quantization scheme: a pre-shuffled descriptor chain for the
// preshuffled path, a plain 2D view for grouped quantization, and an
// n-broadcast view for RowColQuant.
// NOTE(review): numerous lines were dropped by the documentation extraction
// (the if-constexpr condition lines before several "{", transform argument
// lines and make_tile_window dimension lines); the leading numbers are
// extraction artifacts. Recover the missing lines from the upstream header
// before reusing this listing.
472  CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
473  const QuantGemmKernelArgs& kargs,
474  const index_t i_m,
475  const index_t i_n)
476  {
477  // Step 1: Create tensor view for AQ
478  const auto& aq_tensor_view = [&]() {
480  {
481  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
482  const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
483  const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
484  const auto aq_desc =
486  make_tuple(aq_x, 1),
487  number<GemmPipeline::GetVectorSizeAQ()>{},
488  number<1>{});
489 
490  const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
491  const auto aq_pad0_desc = transform_tensor_descriptor(
492  aq_desc,
493  make_tuple(
495  make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
498 
499  const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
500  const auto wave_tile_size =
501  GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
502  const auto wave_tile_count_x =
503  ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
504 
505  const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
506  aq_pad0_desc,
507  make_tuple(
509  make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
512 
513  const auto aq_pad1_desc = transform_tensor_descriptor(
514  aq_unmerge_pad0_desc,
515  make_tuple(
517  make_pass_through_transform(wave_tile_count_x),
519  wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
522 
523  const auto pad_wave_size =
525  const auto aq_merge_pad1_desc = transform_tensor_descriptor(
526  aq_pad1_desc,
527  make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
528  make_pass_through_transform(pad_wave_size)),
531 
532  return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
533  }
534  else if constexpr((kQuantType == QuantType::AQuantGrouped ||
537  {
538  if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
539  {
540  return make_naive_tensor_view<address_space_enum::global>(
541  aq_ptr,
542  make_tuple(kargs.M, kargs.QK_A),
543  make_tuple(kargs.stride_AQ, 1),
544  number<GemmPipeline::GetVectorSizeAQ()>{},
545  number<1>{});
546  }
547  else // Column major AQ
548  {
549  return make_naive_tensor_view<address_space_enum::global>(
550  aq_ptr,
551  make_tuple(kargs.QK_A, kargs.M),
552  make_tuple(kargs.stride_AQ, 1),
553  number<GemmPipeline::GetVectorSizeAQ()>{},
554  number<1>{});
555  }
556  }
557  else if constexpr(kQuantType == QuantType::RowColQuant)
558  {
559  return make_naive_tensor_view<address_space_enum::global>(
560  aq_ptr,
561  make_tuple(kargs.M, kargs.N),
562  make_tuple(1, 0), // broadcasting over n
563  number<1>{},
564  number<1>{});
565  }
566  else
567  {
568  return nullptr;
569  }
570  }();
571 
572  // Step 2: Create tile window (no padding for AQ)
573  const auto& aq_block_window = [&]() {
575  {
576  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
578  constexpr auto block_m = TilePartitioner::MPerBlock;
579  constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
580  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
581  constexpr auto tile_window_width =
582  ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
583  constexpr auto tile_window_height = block_m / warp_m;
584  auto block_m_idx = i_m / block_m;
585  return make_tile_window(
586  aq_tensor_view,
588  {block_m_idx * tile_window_height, 0});
589  }
590  else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
591  {
593  constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
594  constexpr auto block_m = TilePartitioner::MPerBlock;
595  if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
596  {
597  return make_tile_window(aq_tensor_view,
599  {i_m, 0});
600  }
601  else // Column major AQ
602  {
603  return make_tile_window(aq_tensor_view,
605  {0, i_m});
606  }
607  }
608  else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
609  {
610  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
612  constexpr auto block_m = TilePartitioner::MPerBlock;
613  constexpr auto block_k = TilePartitioner::KPerBlock;
614  return make_tile_window(
615  aq_tensor_view,
616  make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
617  {i_m, 0});
618  }
619  else if constexpr(kQuantType == QuantType::RowColQuant)
620  {
621  return make_tile_window(aq_tensor_view,
624  {i_m, i_n});
625  }
626  else
627  {
628  return nullptr;
629  }
630  }();
631 
632  return aq_block_window;
633  }
634 
// Build the tile window over the B matrix. Handles four variants visible
// below: PermuteB (K0 x N x K1 permuted layout), PreshuffleB (flat
// [kFlatN, kFlatK] layout with no padding), packed-fp4 BDataType (K halved,
// two values per byte), and the plain row-/column-major cases.
// NOTE(review): several descriptor/transform/pad argument lines were dropped
// by the documentation extraction (visible as gaps in the fused numbering);
// the leading numbers are extraction artifacts. Recover the missing lines
// from the upstream header before reusing this listing.
635  CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
636  const QuantGemmKernelArgs& kargs,
637  const index_t k_size,
638  const index_t i_n)
639  {
640  // Step 1: Create tensor view for B
641  const auto& b_tensor_view = [&]() {
642  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
643  {
644  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
645  {
646  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
647  const index_t K0 = k_size / K1;
648  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
649  const auto b_k0_n_k1_desc =
651  make_tuple(kargs.N * K1, K1, I1),
653  number<1>{});
654  const auto b_n_k_desc = transform_tensor_descriptor(
655  b_k0_n_k1_desc,
660  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
661  }
662  else
663  {
664  return make_naive_tensor_view<address_space_enum::global>(
665  b_ptr,
666  make_tuple(k_size, kargs.N),
667  make_tuple(kargs.stride_B, 1),
668  number<GemmPipeline::GetVectorSizeB()>{},
669  number<1>{});
670  }
671  }
672  else
673  {
674  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
675  {
676  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
677  const index_t K0 = k_size / K1;
678  constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
679  const auto b_k0_n_k1_desc =
681  make_tuple(kargs.N * K1, K1, I1),
683  number<1>{});
684  const auto b_n_k_desc = transform_tensor_descriptor(
685  b_k0_n_k1_desc,
690  return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
691  }
692  else
693  {
694  if constexpr(PreshuffleB)
695  {
696  index_t kFlatK =
697  GemmPipeline::flatKPerWarp *
698  (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
699  index_t kFlatN = kargs.N * kargs.K / kFlatK;
700  return make_naive_tensor_view<address_space_enum::global>(
701  b_ptr,
702  make_tuple(kFlatN, kFlatK),
703  make_tuple(kFlatK, 1),
704  number<GemmPipeline::GetVectorSizeB()>{},
705  number<1>{});
706  }
707  else
708  {
709  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
710  return make_naive_tensor_view<address_space_enum::global>(
711  b_ptr,
712  make_tuple(kargs.N, k_size / 2),
713  make_tuple(kargs.stride_B, 1),
714  number<GemmPipeline::GetVectorSizeB()>{},
715  number<1>{});
716  else
717  return make_naive_tensor_view<address_space_enum::global>(
718  b_ptr,
719  make_tuple(kargs.N, k_size),
720  make_tuple(kargs.stride_B, 1),
721  number<GemmPipeline::GetVectorSizeB()>{},
722  number<1>{});
723  }
724  }
725  }
726  }();
727 
728  // Step 2: Create padded view (or flat view for PreshuffleB)
729  const auto& b_pad_view = [&]() {
730  if constexpr(PreshuffleB)
731  {
732  return b_tensor_view; // no padding for preshuffle
733  }
734  else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
735  {
736  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
737  return pad_tensor_view(b_tensor_view,
739  number<TilePartitioner::KPerBlock / 2>{}),
741  else
742  return pad_tensor_view(b_tensor_view,
746  }
747  else
748  {
749  return pad_tensor_view(b_tensor_view,
753  }
754  }();
755 
756  // Step 3: Create tile window
757  const auto& b_block_window = [&]() {
758  if constexpr(PreshuffleB)
759  {
760  return make_tile_window(
761  b_pad_view,
764  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
765  }
766  else
767  {
768  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
769  {
770  if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
771  return make_tile_window(
772  b_pad_view,
774  number<TilePartitioner::KPerBlock / 2>{}),
775  {i_n, 0});
776  else
777  return make_tile_window(b_pad_view,
780  {i_n, 0});
781  }
782  else
783  {
784  return make_tile_window(b_pad_view,
787  {0, i_n});
788  }
789  }
790  }();
791 
792  return b_block_window;
793  }
794 
// Build the tile window over the BQ (B quantization-scale) tensor. Variants
// visible below: RowColQuant (m-broadcast view), BQuantGrouped with
// PreshuffleQuant (delegates to MakePreshuffledQuantTensorView), plain
// BQuantGrouped (row-/column-major group views), and ABQuantGrouped.
// NOTE(review): several make_tile_window dimension lines and one comment /
// declaration line were dropped by the documentation extraction (gaps in the
// fused numbering); the leading numbers are extraction artifacts. Recover the
// missing lines from the upstream header before reusing this listing.
795  CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
796  const QuantGemmKernelArgs& kargs,
797  const index_t i_m,
798  const index_t i_n)
799  {
800  // Step 1: Create tensor view for BQ
801  const auto& bq_tensor_view = [&]() {
802  if constexpr(kQuantType == QuantType::RowColQuant)
803  {
804  return make_naive_tensor_view<address_space_enum::global>(
805  bq_ptr,
806  make_tuple(kargs.M, kargs.N),
807  make_tuple(0, 1), // broadcasting over m
808  number<1>{},
809  number<1>{});
810  }
811  else if constexpr(kQuantType == QuantType::BQuantGrouped)
812  {
813  if constexpr(PreshuffleQuant)
814  {
815  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
816  "PreshuffleQuant with BQuantGrouped currently only supports "
817  "ColumnMajor BQ layout");
819 
820  return MakePreshuffledQuantTensorView<
821  GemmPipeline::KPerBlockBQ,
822  GemmPipeline::NPerBlockBQ,
823  GemmPipeline::NPerBlock,
824  TilePartitioner::BlockGemmShape::WarpTile::at(I1),
825  GemmPipeline::GetVectorSizeBQ()>(
826  bq_ptr,
827  ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
828  QuantGroupSize::kN,
829  kargs.QK_B);
830  }
831  else
832  {
834 
835  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
836  {
837  return make_naive_tensor_view<address_space_enum::global>(
838  bq_ptr,
839  make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
840  integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
841  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
842  number<GemmPipeline::GetVectorSizeBQ()>{},
843  number<1>{});
844  }
845  else
846  {
847  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
848  return make_naive_tensor_view<address_space_enum::global>(
849  bq_ptr,
850  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
851  integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
852  make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
853  number<GemmPipeline::GetVectorSizeBQ()>{},
854  number<1>{});
855  }
856  }
857  }
858  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
859  {
860  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
862  return make_naive_tensor_view<address_space_enum::global>(
863  bq_ptr,
864  make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
865  make_tuple(kargs.stride_BQ, 1),
866  number<GemmPipeline::GetVectorSizeBQ()>{},
867  number<1>{});
868  }
869  else
870  {
871  return nullptr;
872  }
873  }();
874 
875  // Step 2: Create tile window (no padding for BQ)
876  const auto& bq_block_window = [&]() {
877  if constexpr(kQuantType == QuantType::RowColQuant)
878  {
879  return make_tile_window(bq_tensor_view,
882  {i_m, i_n});
883  }
884  else if constexpr(kQuantType == QuantType::BQuantGrouped)
885  {
887  if constexpr(PreshuffleQuant)
888  {
889  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
890 
891  // Number of N-dimension quantization groups per block
892  constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
893  ? TilePartitioner::NPerBlock / QuantGroupSize::kN
894  : QuantGroupSize::kN / TilePartitioner::NPerBlock;
895 
896  // Number of N-dimension elements per warp
897  constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
898 
899  // Determine how many warps share the same scale in N-dimension
900  constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
901  ? (warp_n / QuantGroupSize::kN)
902  : (QuantGroupSize::kN / warp_n);
903 
904  // Number of K-dimension quantization groups per block
905  constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
906 
907  // The pre-shuffled layout flattens warp_n ×
908  // bqk_per_block scales per row, Padded up to warp_size
909  // to ensure coalesced memory access.
910  constexpr auto tile_window_width =
911  ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
912 
913  // Adapts based on fine vs coarse quantization granularity:
914  // - Fine-grained (QuantGroupSize::kN < warp_n):
915  // Multiple quant groups per warp → fewer rows needed per block.
916  // height = block_n / warp_per_group
917  //
918  // - Coarse-grained (QuantGroupSize::kN >= warp_n):
919  // Each row represents one quant group.
920  // height = block_n
921  constexpr auto tile_window_height =
922  (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
923 
924  auto block_n_idx = i_n / TilePartitioner::NPerBlock;
925 
926  // For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
927  if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
928  {
929  block_n_idx = block_n_idx >> 1;
930  }
931 
932  if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
933  {
934  return make_tile_window(
935  bq_tensor_view,
937  {block_n_idx, 0});
938  }
939  else
940  {
941  return make_tile_window(
942  bq_tensor_view,
944  {block_n_idx * tile_window_height, 0});
945  }
946  }
947  else
948  {
949  constexpr auto tensor_dim =
950  (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
951  ? TilePartitioner::NPerBlock / QuantGroupSize::kN
952  : 1;
953  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
954  {
955  return make_tile_window(
956  bq_tensor_view,
959  {0, i_n / QuantGroupSize::kN});
960  }
961  else
962  {
963  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
964  return make_tile_window(
965  bq_tensor_view,
967  number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
968  {i_n / QuantGroupSize::kN, 0});
969  }
970  }
971  }
972  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
973  {
974  static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
976  return make_tile_window(
977  bq_tensor_view,
979  number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
980  {i_n / QuantGroupSize::kN, 0});
981  }
982  else
983  {
984  return nullptr;
985  }
986  }();
987 
988  return bq_block_window;
989  }
990 
// Build the tile window over the output C matrix; DstInMemOp selects the
// global-memory write operation (set by default, e.g. atomic add for split-K
// accumulation callers may request).
// NOTE(review): the function declarator line (name + c_ptr parameter) and the
// pad_tensor_view / make_tile_window dimension lines were dropped by the
// documentation extraction; the leading numbers are extraction artifacts.
// Recover the missing lines from the upstream header.
991  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
993  const QuantGemmKernelArgs& kargs,
994  const index_t i_m,
995  const index_t i_n)
996  {
997  // Step 1: Create tensor view for C
998  const auto& c_tensor_view = [&]() {
999  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1000  {
1001  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
1002  c_ptr,
1003  make_tuple(kargs.M, kargs.N),
1004  make_tuple(kargs.stride_C, 1),
1005  number<EpiloguePipeline::GetVectorSizeC()>{},
1006  number<1>{});
1007  }
1008  else
1009  {
1010  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
1011  c_ptr,
1012  make_tuple(kargs.M, kargs.N),
1013  make_tuple(1, kargs.stride_C),
1014  number<1>{},
1015  number<1>{});
1016  }
1017  }();
1018 
1019  // Step 2: Create padded view
1020  const auto& c_pad_view = [&]() {
1021  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1022  {
1023  return pad_tensor_view(c_tensor_view,
1027  }
1028  else
1029  {
1030  return pad_tensor_view(c_tensor_view,
1034  }
1035  }();
1036 
1037  // Step 3: Create tile window
1038  auto c_block_window = make_tile_window(
1039  c_pad_view,
1041  {i_m, i_n});
1042 
1043  return c_block_window;
1044  }
1045 
// Validate the kernel arguments against this instantiation's compile-time
// configuration. Returns false (optionally logging via CK_TILE_LOGGING) when:
//  - k_batch != 1 (this kernel variant rejects split-K here),
//  - a dimension is not tile-divisible and the corresponding pad flag is off,
//  - a contiguous dimension is not divisible by the tensor's vector size.
// NOTE(review): the function declaration line preceding this brace was
// dropped by the documentation extraction; the leading numbers are
// extraction artifacts.
1047  {
// Reject split-K configurations outright.
1048  if(kargs.k_batch != 1)
1049  {
1050  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1051  {
1052  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
1053  }
1054  return false;
1055  }
1056 
// A tensor: the contiguous dimension (K for row-major, M for column-major)
// must be pad-enabled or tile-divisible, and divisible by A's vector size.
1057  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
1058  {
1059  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
1060  GemmPipeline::kPadK == false)
1061  {
1062  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1063  {
1064  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
1065  "without padding!");
1066  }
1067  return false;
1068  }
1069  if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
1070  {
1071  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1072  {
1073  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
1074  }
1075  return false;
1076  }
1077  }
1078  else
1079  {
1080  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
1081  {
1082  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1083  {
1084  CK_TILE_ERROR(
1085  "Can't support M that is not a multiple of MPerBlock without padding!");
1086  }
1087  return false;
1088  }
1089  if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
1090  {
1091  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1092  {
1093  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
1094  }
1095  return false;
1096  }
1097  }
1098 
// B tensor: mirrored checks on N (row-major) or K (column-major).
1099  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
1100  {
1101  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
1102  {
1103  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1104  {
1105  CK_TILE_ERROR(
1106  "Can't support N that is not a multiple of NPerBlock without padding!");
1107  }
1108  return false;
1109  }
1110  if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
1111  {
1112  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1113  {
1114  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
1115  }
1116  return false;
1117  }
1118  }
1119  else
1120  {
1121  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
1122  GemmPipeline::kPadK == false)
1123  {
1124  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1125  {
1126  CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
1127  "without padding!");
1128  }
1129  return false;
1130  }
1131  if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
1132  {
1133  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1134  {
1135  CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
1136  }
1137  return false;
1138  }
1139  }
1140 
// C tensor: checks on N (row-major) or M (column-major) against the
// epilogue's vector store size.
1141  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1142  {
1143  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
1144  {
1145  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1146  {
1147  CK_TILE_ERROR(
1148  "Can't support N that is not a multiple of NPerBlock without padding!");
1149  }
1150  return false;
1151  }
1152  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
1153  {
1154  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1155  {
1156  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
1157  }
1158  return false;
1159  }
1160  }
1161  else
1162  {
1163  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
1164  {
1165  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1166  {
1167  CK_TILE_ERROR(
1168  "Can't support M that is not a multiple of MPerBlock without padding!");
1169  }
1170  return false;
1171  }
1172  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
1173  {
1174  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
1175  {
1176  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
1177  }
1178  return false;
1179  }
1180  }
1181  return true;
1182  }
1183 
1199  CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
1200  const BDataType* b_ptr,
1201  const AQDataType* aq_ptr,
1202  const BQDataType* bq_ptr,
1203  CDataType* c_ptr,
1204  void* smem_ptr,
1205  const QuantGemmKernelArgs& kargs,
1206  const SplitKBatchOffset& splitk_batch_offset,
1207  const index_t block_idx_m,
1208  const index_t block_idx_n)
1209  {
1210  // Create block windows using specialized methods
1211  const auto& a_block_window =
1212  MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
1213  const auto& b_block_window =
1214  MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
1215  const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
1216  const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
1217 
1218  const index_t num_loop =
1219  amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1220 
1221  // Run GEMM cooperatively by whole workgroup.
1222  const auto& c_block_tile = [&]() {
1223  if constexpr(kQuantType == QuantType::AQuantGrouped)
1224  {
1225  index_t m = 0;
1226  if constexpr(PreshuffleQuant)
1227  {
1228  m = kargs.M;
1229  }
1230  return GemmPipeline{}.template operator()(
1231  a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
1232  }
1233  else if constexpr(kQuantType == QuantType::BQuantGrouped)
1234  {
1235  index_t n = 0;
1236  if constexpr(PreshuffleQuant)
1237  {
1238  n = kargs.N;
1239  }
1240  return GemmPipeline{}.template operator()(
1241  a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
1242  }
1243  else if constexpr(kQuantType == QuantType::ABQuantGrouped)
1244  {
1245  index_t m = 0;
1246  index_t n = 0;
1247  if constexpr(PreshuffleQuant)
1248  {
1249  m = kargs.M;
1250  n = kargs.N;
1251  }
1252  return GemmPipeline{}.template operator()(a_block_window,
1253  b_block_window,
1254  aq_block_window,
1255  bq_block_window,
1256  num_loop,
1257  smem_ptr,
1258  m,
1259  n);
1260  }
1261  else if constexpr(kQuantType == QuantType::RowColQuant ||
1263  {
1264  return GemmPipeline{}.template operator()(
1265  a_block_window, b_block_window, num_loop, smem_ptr);
1266  }
1267  }();
1268 
1269  const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
1270 
1271  // Run Epilogue Pipeline with k_batch dispatch
1272  if(k_batch == 1)
1273  {
1274  auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
1275  c_ptr, kargs, block_idx_m, block_idx_n);
1276 
1277  if constexpr(kQuantType == QuantType::ABQuantGrouped ||
1280  {
1281  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
1282  }
1283  else if constexpr(kQuantType == QuantType::RowColQuant)
1284  {
1285  EpiloguePipeline{}(c_block_window,
1286  c_block_tile,
1287  c_block_window,
1288  smem_ptr,
1289  aq_block_window,
1290  bq_block_window);
1291  }
1292  else if constexpr(kQuantType == QuantType::TensorQuant)
1293  {
1294  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1295  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1296  EpiloguePipeline{}(
1297  c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
1298  }
1299  }
1300  else
1301  {
1302  auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
1303  c_ptr, kargs, block_idx_m, block_idx_n);
1304 
1305  if constexpr(kQuantType == QuantType::ABQuantGrouped ||
1308  {
1309  EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
1310  }
1311  else if constexpr(kQuantType == QuantType::RowColQuant)
1312  {
1313  EpiloguePipeline{}(c_block_window,
1314  c_block_tile,
1315  c_block_window,
1316  smem_ptr,
1317  aq_block_window,
1318  bq_block_window);
1319  }
1320  else if constexpr(kQuantType == QuantType::TensorQuant)
1321  {
1322  const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1323  const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1324  EpiloguePipeline{}(
1325  c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
1326  }
1327  }
1328  }
1329 
1331  {
1332  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1333  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1334  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1335  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1336  const SplitKBatchOffset splitk_batch_offset(kargs);
1337 
1338  // Apply splitk offset to input pointers
1339  const ADataType* a_ptr =
1340  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
1341  const BDataType* b_ptr =
1342  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
1343  const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
1344  const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
1345  CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
1346 
1347  // allocate LDS
1348  __shared__ char smem_ptr[GetSmemSize()];
1349 
1350  RunGemm(
1351  a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
1352  }
1353 };
1354 
1355 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1660
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
constexpr CK_TILE_HOST_DEVICE auto integer_least_multiple(X x, Y y)
Definition: math.hpp:152
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1691
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1698
QuantType
Definition: tile_gemm_quant_traits.hpp:12
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
unsigned int uint32_t
Definition: stdint.h:126
Definition: gemm_quant_kernel.hpp:133
void * c_ptr
Definition: gemm_quant_kernel.hpp:166
const void * aq_ptr
Definition: gemm_quant_kernel.hpp:164
const void * bq_ptr
Definition: gemm_quant_kernel.hpp:165
const void * b_ptr
Definition: gemm_quant_kernel.hpp:163
CK_TILE_HOST QuantGemmHostArgs()=default
index_t k_batch
Definition: gemm_quant_kernel.hpp:167
const void * a_ptr
Definition: gemm_quant_kernel.hpp:162
CK_TILE_HOST QuantGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, const void *aq_ptr_, const void *bq_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_, index_t stride_A_, index_t stride_B_, index_t stride_C_, index_t stride_AQ_, index_t stride_BQ_)
Definition: gemm_quant_kernel.hpp:135
Definition: gemm_quant_kernel.hpp:365
__device__ SplitKBatchOffset(const QuantGemmKernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition: gemm_quant_kernel.hpp:366
index_t a_k_split_offset
Definition: gemm_quant_kernel.hpp:401
index_t b_k_split_offset
Definition: gemm_quant_kernel.hpp:402
index_t splitted_k
Definition: gemm_quant_kernel.hpp:403
Definition: gemm_quant_kernel.hpp:171
index_t k_batch
Definition: gemm_quant_kernel.hpp:187
index_t stride_BQ
Definition: gemm_quant_kernel.hpp:186
const void * b_ptr
Definition: gemm_quant_kernel.hpp:173
void * c_ptr
Definition: gemm_quant_kernel.hpp:176
const void * aq_ptr
Definition: gemm_quant_kernel.hpp:174
index_t stride_A
Definition: gemm_quant_kernel.hpp:182
index_t M
Definition: gemm_quant_kernel.hpp:177
const void * a_ptr
Definition: gemm_quant_kernel.hpp:172
const void * bq_ptr
Definition: gemm_quant_kernel.hpp:175
index_t QK_B
Definition: gemm_quant_kernel.hpp:181
index_t K
Definition: gemm_quant_kernel.hpp:179
index_t QK_A
Definition: gemm_quant_kernel.hpp:180
index_t stride_AQ
Definition: gemm_quant_kernel.hpp:185
index_t N
Definition: gemm_quant_kernel.hpp:178
index_t stride_C
Definition: gemm_quant_kernel.hpp:184
index_t stride_B
Definition: gemm_quant_kernel.hpp:183
Definition: gemm_quant_kernel.hpp:195
static constexpr auto I4
Definition: gemm_quant_kernel.hpp:227
static constexpr auto I3
Definition: gemm_quant_kernel.hpp:226
static constexpr bool PreshuffleB
Definition: gemm_quant_kernel.hpp:211
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: gemm_quant_kernel.hpp:238
static CK_TILE_DEVICE void RunGemm(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr, const QuantGemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: gemm_quant_kernel.hpp:1199
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_quant_kernel.hpp:197
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_quant_kernel.hpp:198
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_quant_kernel.hpp:196
remove_cvref_t< typename EpiloguePipeline::AccDataType > AccDataType
Definition: gemm_quant_kernel.hpp:216
static CK_TILE_DEVICE auto MakeCBlockWindow(CDataType *c_ptr, const QuantGemmKernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: gemm_quant_kernel.hpp:992
static constexpr auto I0
Definition: gemm_quant_kernel.hpp:223
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
Definition: gemm_quant_kernel.hpp:1330
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: gemm_quant_kernel.hpp:215
static constexpr index_t kBlockSize
Definition: gemm_quant_kernel.hpp:208
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_quant_kernel.hpp:200
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: gemm_quant_kernel.hpp:201
static CK_TILE_DEVICE auto MakeABlockWindow(const ADataType *a_ptr, const QuantGemmKernelArgs &kargs, const index_t k_size, const index_t i_m)
Definition: gemm_quant_kernel.hpp:406
static CK_TILE_DEVICE auto MakeBQBlockWindow(const BQDataType *bq_ptr, const QuantGemmKernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: gemm_quant_kernel.hpp:795
static constexpr auto I1
Definition: gemm_quant_kernel.hpp:224
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition: gemm_quant_kernel.hpp:199
static constexpr bool PreshuffleQuant
Definition: gemm_quant_kernel.hpp:209
static CK_TILE_DEVICE auto MakeBBlockWindow(const BDataType *b_ptr, const QuantGemmKernelArgs &kargs, const index_t k_size, const index_t i_n)
Definition: gemm_quant_kernel.hpp:635
static CK_TILE_HOST bool IsSupportedArgument(const QuantGemmKernelArgs &kargs)
Definition: gemm_quant_kernel.hpp:1046
remove_cvref_t< typename detail::get_aq_data_type_or< GemmPipeline, AccDataType >::type > AQDataType
Definition: gemm_quant_kernel.hpp:219
remove_cvref_t< typename detail::get_bq_data_type_or< GemmPipeline, AccDataType >::type > BQDataType
Definition: gemm_quant_kernel.hpp:221
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_quant_kernel.hpp:214
static constexpr auto I2
Definition: gemm_quant_kernel.hpp:225
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: gemm_quant_kernel.hpp:269
static constexpr CK_TILE_HOST QuantGemmKernelArgs MakeKernelArgs(const QuantGemmHostArgs &hostArgs)
Definition: gemm_quant_kernel.hpp:249
static CK_TILE_DEVICE auto MakeAQBlockWindow(const AQDataType *aq_ptr, const QuantGemmKernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: gemm_quant_kernel.hpp:472
static CK_TILE_HOST const std::string GetName()
Definition: gemm_quant_kernel.hpp:231
remove_cvref_t< typename detail::get_bq_layout_or< GemmPipeline, typename GemmPipeline::BLayout >::type > BQLayout
Definition: gemm_quant_kernel.hpp:206
static CK_TILE_HOST auto BlockSize()
Definition: gemm_quant_kernel.hpp:243
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Definition: gemm_quant_kernel.hpp:213
remove_cvref_t< typename detail::get_aq_layout_or< GemmPipeline, typename GemmPipeline::ALayout >::type > AQLayout
Definition: gemm_quant_kernel.hpp:204
static constexpr auto kQuantType
Definition: gemm_quant_kernel.hpp:229
Definition: gemm_quant_kernel.hpp:95
index_t stride_AQ
Definition: gemm_quant_kernel.hpp:128
index_t N
Definition: gemm_quant_kernel.hpp:121
index_t K
Definition: gemm_quant_kernel.hpp:122
index_t stride_BQ
Definition: gemm_quant_kernel.hpp:129
index_t stride_C
Definition: gemm_quant_kernel.hpp:127
index_t stride_B
Definition: gemm_quant_kernel.hpp:126
index_t stride_A
Definition: gemm_quant_kernel.hpp:125
CK_TILE_HOST QuantGemmProblem(index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_, index_t stride_A_, index_t stride_B_, index_t stride_C_, index_t stride_AQ_, index_t stride_BQ_)
Definition: gemm_quant_kernel.hpp:97
index_t QK_A
Definition: gemm_quant_kernel.hpp:123
index_t QK_B
Definition: gemm_quant_kernel.hpp:124
CK_TILE_HOST QuantGemmProblem()=default
index_t M
Definition: gemm_quant_kernel.hpp:120
Definition: integral_constant.hpp:13
Definition: gemm_quant_kernel.hpp:47
Default type
Definition: gemm_quant_kernel.hpp:48
typename T::AQLayout type
Definition: gemm_quant_kernel.hpp:30
Definition: gemm_quant_kernel.hpp:23
Default type
Definition: gemm_quant_kernel.hpp:24
Definition: gemm_quant_kernel.hpp:59
Default type
Definition: gemm_quant_kernel.hpp:60
typename T::BQLayout type
Definition: gemm_quant_kernel.hpp:42
Definition: gemm_quant_kernel.hpp:35
Default type
Definition: gemm_quant_kernel.hpp:36
Definition: gemm_quant_kernel.hpp:83
static constexpr bool value
Definition: gemm_quant_kernel.hpp:84
Definition: gemm_quant_kernel.hpp:71
static constexpr bool value
Definition: gemm_quant_kernel.hpp:72
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145