/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp Source File
gemm_group_quant_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
11 CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
12 {
13  using I1 = number<1>;
14  constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
15 
16  constexpr index_t BlockSize = Problem::kBlockSize;
17 
18  // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
19  constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
20  constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
21 
22  // Define vector load candidates in descending order of priority
23  constexpr std::array<index_t, 5> candidates{
24  PackedSize * 32 / sizeof(DataType),
25  PackedSize * 16 / sizeof(DataType),
26  PackedSize * 8 / sizeof(DataType),
27  PackedSize * 4 / sizeof(DataType),
28  PackedSize * 2 / sizeof(DataType),
29  };
30 
31  for(const auto vec_size : candidates)
32  {
33  if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
34  continue;
35  bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
36  (elements_per_thread % vec_size == 0) && vec_size != candidates[4];
37  if(is_valid)
38  {
39  return vec_size;
40  }
41  }
42  return PackedSize; // Absolute fallback
43 }
44 
45 // AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
46 // threads. Post mfma scales are shuffled across threads in the warp and applied to
47 // accum registers.
// Tile-distribution helper for the AQ (group-quant scale of A) tensor; see
// the comment block above for the load/shuffle/apply flow.
// NOTE(review): this listing is a Doxygen extraction with lines missing:
// original line 56 (the struct's name), line 73 and line 113 (the member
// function signatures -- this file's own index names them
// make_2d_static_tile_distribution and
// make_2d_static_tile_distribution_transposed), and lines 84-89, 104-109,
// 126-131 (the tile_distribution_encoding payloads). Recover them from the
// real header before editing.
48 template <typename BlockGemmShape,
49  typename WarpGemm,
50  index_t BlockSize,
51  index_t YPerTile,
52  index_t XPerTile,
53  index_t KPerBlockAQ,
54  index_t VecSize,
55  bool PreshuffleQuant>
57 {
58  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
59  static constexpr index_t warp_size = get_warp_size();
60  static constexpr index_t num_warps = BlockSize / get_warp_size();
61 
// Warp counts along M/N/K taken from the block-gemm shape descriptor.
62  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
63  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
64  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
65 
// M-direction warp-gemm iterations each warp performs per block tile.
66  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
67 
68  static_assert(num_warps == MWarps * NWarps * KWarps);
69 
70  // KWarps > 1 isn't supported
71  static_assert(KWarps == 1);
72 
// make_2d_static_tile_distribution() -- signature line (original 73) missing.
74  {
75  if constexpr(PreshuffleQuant)
76  {
77  // # of elements per thread
78  static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
// Preshuffled path: X splits into (X0 = XPerTile / warp_size, X1 = warp_size),
// presumably so each lane owns one column of a warp-wide slice -- confirm
// against the missing encoding payload.
79  constexpr index_t X1 = warp_size;
80  constexpr index_t X0 = XPerTile / warp_size;
81 
82  constexpr index_t Y1 = MWarps;
83  constexpr index_t Y0 = YPerTile / Y1;
// NOTE(review): original lines 84-89 (encoding payload) missing here.
90  sequence<0, 0>>{});
91  }
92  else
93  {
94  // # of elements per thread
95  constexpr index_t X = XPerTile;
96 
97  constexpr index_t YR = 1;
// Guard against MIterPerWarp == 0 (block tile smaller than one warp tile).
98  constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
99  constexpr index_t Y1 = MWarps;
100  constexpr index_t Y2 = WarpGemm::kM;
// NOTE(review): Y2 == WarpGemm::kM makes the assert below tautological; it
// documents intent but cannot fail.
101  static_assert(Y2 >= WarpGemm::kM,
102  "Scales for all rows must be available within the warp.");
103  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
// NOTE(review): original lines 104-109 (encoding payload) missing here.
110  sequence<0, 0>>{});
111  }
112  }
// make_2d_static_tile_distribution_transposed() -- signature line (original
// 113) missing. Same decomposition as above but with the scale factors laid
// out along X instead of Y.
114  {
115 
116  constexpr index_t Y0 = YPerTile;
117  constexpr index_t X0 = 1;
118  constexpr index_t X1 = MIterPerWarp ? MIterPerWarp : 1;
119  constexpr index_t X2 = MWarps;
120  constexpr index_t X3 = WarpGemm::kM;
121 
// NOTE(review): X3 == WarpGemm::kM makes the assert below tautological.
122  static_assert(X3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
123  static_assert(X0 * X1 * X2 * X3 == XPerTile,
124  "X0, X1, X2, X3 must cover the blocktile along X.");
125 
// NOTE(review): original lines 126-131 (encoding payload) missing here.
132  sequence<1, 0>>{});
133  }
134 };
135 
// Second scale tile-distribution helper (definition at original line 144).
// NOTE(review): Doxygen extraction -- original lines 142-143 (the struct's
// name), 176 (the make_2d_static_tile_distribution signature, per this
// file's own index) and 178-183 (the tile_distribution_encoding payload)
// are missing and must be recovered from the real header.
136 template <typename BlockGemmShape,
137  typename WarpGemm,
138  index_t BlockSize,
139  index_t YPerTile,
140  index_t XPerTile,
141  index_t VecSize>
144 {
145  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
146  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
147  static constexpr index_t warp_size = get_warp_size();
148  static constexpr index_t num_warps = BlockSize / get_warp_size();
149 
// Warp counts along M/N/K taken from the block-gemm shape descriptor.
150  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
151  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
152  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
153 
// M-direction warp-gemm iterations per warp per block tile.
154  static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
155 
156  static_assert(num_warps == MWarps * NWarps * KWarps);
157 
158  // KWarps > 1 isn't supported
159  static_assert(KWarps == 1);
160 
161  // # of elements per thread
162  static constexpr index_t X = XPerTile;
// XR = 2: presumably a replication factor along X -- confirm against the
// missing encoding payload.
163  static constexpr index_t XR = 2;
164 
165  // Number of iters per warp
166  // MIters are indexed using (Y0, Y1)
167  static constexpr index_t Y0 = MIterPerWarp;
168 
169  // # of warps in Y dim
170  static constexpr index_t Y1 = MWarps;
171 
// Rows covered by one warp-gemm along M.
172  static constexpr index_t Y2 = WarpGemm::kM;
173 
174  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
175 
// make_2d_static_tile_distribution() -- signature (original line 176) and
// encoding payload (original lines 178-183) missing from this extraction.
177  {
184  sequence<0, 0>>{});
185  }
186 };
187 
188 // TODO:: might need to update
189 template <typename BlockGemmShape,
190  typename WarpGemm,
191  index_t BlockSize,
192  index_t KPerTile,
193  index_t NPerTile,
194  index_t NPerQ,
195  index_t KPerQ,
196  typename BQLayout = tensor_layout::gemm::ColumnMajor,
197  bool PreshuffleQuant = false>
199 {
200  static constexpr index_t warp_size = get_warp_size();
201  static constexpr index_t num_warps = BlockSize / get_warp_size();
202 
203  static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
204  static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
205  static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
206 
207  static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
208 
209  static_assert(num_warps == MWarps * NWarps * KWarps);
210  static_assert(KWarps == 1);
211 
213  {
214  // Preshuffle only supported for ColumnMajor currently
215  static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
216  "PreshuffleQuant only supported for ColumnMajor BQLayout");
217 
218  if constexpr(PreshuffleQuant)
219  {
220  // =============================================================================
221  // PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
222  // =============================================================================
223  // For pre-shuffled quantization, the BQ scale tensor has been reorganized
224  // (pre-shuffled) to optimize memory access patterns during dequantization.
225  //
226  // Tile Dimensions:
227  // - K-axis (Y in encoding): Corresponds to the K-dimension iteration
228  // - N-axis (X in encoding): Flattened scale index combining N and K groups
229  //
230  // The encoding distributes work across threads such that each thread loads
231  // the correct pre-shuffled scale for its corresponding B-matrix elements.
232  // =============================================================================
233  if constexpr(NPerQ <= WarpGemm::kN)
234  {
235  // =========================================================================
236  // CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN)
237  // =========================================================================
238  // Multiple quantization scales exist within a single warp's N-dimension.
239  // Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp.
240  //
241  // Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
242  // → 2 scales per warp in N, 2 K-groups per block
243 
244  // N1: Number of K-dimension quantization groups per block,
245  // Each K-group of KPerQ elements shares the same scale.
246  // N0: Number of scales per warp in N-dimension, Since NPerQ
247  // <= WarpGemm::kN, each warp handles multiple scales.
248  // N2: Elements per thread
249  // NR1: Elements sharing the same scale in N-dimension
250  // NR0: Interleave factor to ensure full warp utilization
251  // K1: Number of warps distributed along this dimension
252  // K0: Iterations per warp to cover the K-tile
253  // KR: No replication in K-dimension
254  constexpr auto N1 = BlockGemmShape::kK / KPerQ;
255  constexpr auto N0 = WarpGemm::kN / NPerQ;
256  constexpr auto N2 = 1;
257  constexpr auto NR1 = NPerQ;
258  constexpr auto NR0 =
259  (warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
260  constexpr auto K1 = NWarps;
261  constexpr auto K0 = KPerTile / K1;
262  constexpr auto KR = 1;
263 
270  sequence<0, 2>>{});
271  }
272  else if constexpr(NPerQ < WarpGemm::kN * NWarps)
273  {
274  // =========================================================================
275  // CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN *
276  // NWarps)
277  // =========================================================================
278  // Each warp handles exactly one quantization scale in N-dimension.
279  // Some warps share the same scale (KR > 1 creates warp grouping).
280  //
281  // Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
282  // → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
283 
284  // KR: Number of warps sharing the same scale
285  // K1: Number of distinct warp groups (unique scales)
286  // K0: Iterations to cover K-tile per warp group
287  // N1: K-dimension quantization groups
288  // N0: Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
289  // N2: Elements per thread
290  // NR1: Scale broadcast factor (full NPerQ)
291  // NR0: Remaining interleave factor
292 
293  constexpr auto KR = NPerQ / WarpGemm::kN;
294  constexpr auto K1 = NWarps / KR;
295  constexpr auto K0 = KPerTile / K1;
296  constexpr auto N1 = BlockGemmShape::kK / KPerQ;
297  constexpr auto N0 = 1;
298  constexpr auto N2 = 1;
299  constexpr auto NR1 = NPerQ;
300  constexpr auto NR0 =
301  (warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
302 
309  sequence<0, 2>>{});
310  }
311  else
312  {
313  // =========================================================================
314  // CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps)
315  // =========================================================================
316  // The quantization group spans ALL warps in N-dimension.
317  // All warps share the same scale value for their N-tiles.
318  //
319  // Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
320  // → 128 >= 16*4=64, so all 4 warps use the same scale
321 
322  // N1: K-dimension quantization groups
323  // N0: Minimal (1) since scale is shared across N
324  // N2: Elements per thread
325  // NR1: Fixed broadcast size
326  // NR0: Remaining interleave factor
327 
328  constexpr auto N1 = BlockGemmShape::kK / KPerQ;
329  constexpr auto N0 = 1;
330  constexpr auto N2 = 1;
331  constexpr auto NR1 = 32;
332  constexpr auto NR0 =
333  (warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
340  sequence<0, 2>>{});
341  }
342  }
343  else
344  {
372  if constexpr(NPerQ < WarpGemm::kN)
373  {
374  // Case 1: Fine-grained - multiple quantization scales within a single warp
375  // N dimension needs to be partitioned the same way regardless of layout
376  constexpr index_t NR = 1; // No N replication needed
377  constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim
378  constexpr index_t N1 = NWarps; // Number of warps in N-dim
379  constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp
380 
381  static_assert(N0 * N1 * N2 == NPerTile,
382  "N0, N1, N2 must cover the blocktile along N dimension.");
383 
384  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
385  {
386  // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y
393  sequence<0, 0>>{});
394  }
395  else
396  {
397  // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X
404  sequence<0, 0>>{});
405  }
406  }
407  else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
408  {
409  // Case 2: Medium-grained - one quantization scale per warp
410  constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor
411  constexpr auto N1 = NWarps / NR; // Warps per unique scale
412  constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension
413 
414  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
415  {
416  // ColumnMajor: [(N0, N1), K] - N on Y-axis
423  sequence<0, 0>>{});
424  }
425  else
426  {
427  // RowMajor: [K, (N0, N1)] - N on X-axis
434  sequence<0, 0>>{});
435  }
436  }
437  else // NPerQ > WarpGemm::kN * NWarps
438  {
439  // Case 3: Coarse-grained - quantization group spans all warps
440  // All warps in N-dimension share the same quantization scale
441  if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
442  {
443  // ColumnMajor: [N, K]
450  sequence<0, 0>>{});
451  }
452  else
453  {
454  // RowMajor: [K, N]
461  sequence<0, 0>>{});
462  }
463  }
464  }
465  }
466 };
467 
// Compile-time triple of quantization group sizes along (M, N, K), read from
// a GroupSizes sequence.
// NOTE(review): original line 469 (the struct's name) is missing from this
// extraction; GetName() suggests it is the quant-group-shape descriptor.
468 template <typename GroupSizes>
470 {
// Group sizes along M, N and K respectively.
471  static constexpr index_t kM = GroupSizes::at(number<0>{});
472  static constexpr index_t kN = GroupSizes::at(number<1>{});
473  static constexpr index_t kK = GroupSizes::at(number<2>{});
474 
// Human-readable identifier for kernel naming; presumably yields
// "quant_group_shape_<kM>x<kN>x<kK>" -- confirm against ck_tile::concat.
475  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
476  {
477  return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
478  }
479 };
480 
481 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: gemm_group_quant_utils.hpp:470
static constexpr index_t kM
Definition: gemm_group_quant_utils.hpp:471
static constexpr index_t kK
Definition: gemm_group_quant_utils.hpp:473
static constexpr index_t kN
Definition: gemm_group_quant_utils.hpp:472
static CK_TILE_HOST const std::string GetName()
Definition: gemm_group_quant_utils.hpp:475
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: gemm_group_quant_utils.hpp:144
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:151
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:176
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:150
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:154
static constexpr index_t X
Definition: gemm_group_quant_utils.hpp:162
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:152
static constexpr index_t Y0
Definition: gemm_group_quant_utils.hpp:167
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:148
static constexpr index_t Y2
Definition: gemm_group_quant_utils.hpp:172
static constexpr index_t Y1
Definition: gemm_group_quant_utils.hpp:170
static constexpr index_t XR
Definition: gemm_group_quant_utils.hpp:163
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:147
Definition: gemm_group_quant_utils.hpp:57
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:64
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:62
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:73
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:59
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:63
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution_transposed()
Definition: gemm_group_quant_utils.hpp:113
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:60
static constexpr index_t MIterPerWarp
Definition: gemm_group_quant_utils.hpp:66
Definition: gemm_group_quant_utils.hpp:199
static constexpr index_t NIterPerWarp
Definition: gemm_group_quant_utils.hpp:207
static constexpr index_t MWarps
Definition: gemm_group_quant_utils.hpp:203
static constexpr index_t NWarps
Definition: gemm_group_quant_utils.hpp:204
static constexpr index_t num_warps
Definition: gemm_group_quant_utils.hpp:201
static constexpr index_t warp_size
Definition: gemm_group_quant_utils.hpp:200
static constexpr CK_TILE_HOST_DEVICE auto make_2d_static_tile_distribution()
Definition: gemm_group_quant_utils.hpp:212
static constexpr index_t KWarps
Definition: gemm_group_quant_utils.hpp:205
Definition: static_encoding_pattern.hpp:108
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192