/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp Source File
gridwise_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
16 
18 
19 #define DEBUG_LOG 0
20 
21 namespace ck {
22 
23 // Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
24 // kernel function Blockers:
25 // 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on
26 // two lds chunks.
27 // 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
28 // buffer when we declare __shared__ inside blkgemmpipe
29 
// NOTE(review): the enum's declaration line was lost in this rendering of the
// source. From the enumerator names these values appear to select which fused
// activation ("act(x) * y" style) the epilogue applies — TODO confirm against
// the original header.
31 {
32  gelu_and_mul = 0,
33  silu_and_mul = 1
34 };
35 
// Device entry point for the block-scaled MoE GEMM using the single-LDS-buffer
// pipeline. The whole kernel is a thin trampoline: it carves the split-K
// offsets from blockIdx.z and forwards every field of the Argument struct to
// GridwiseGemm::Run. A real body is emitted only for gfx9/gfx11/gfx12 targets;
// on any other target the argument is consumed via `ignore` to silence
// unused-parameter warnings.
36 template <typename GridwiseGemm,
37  bool HasMainKBlockLoop,
38  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39  index_t MinimumOccupancy = 1,
40  TailNumber TailNum = TailNumber::Even>
41 __global__ void
42 #if CK_USE_LAUNCH_BOUNDS
43 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44 #endif
45  // __attribute__((amdgpu_waves_per_eu(1, 1)))
46  kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47 {
48 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
50  {
// Single statically-sized LDS allocation, shared over the kernel's lifetime by
// the A/B tile staging and (per GetSharedMemoryNumberOfByte) the C shuffle.
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
// blockIdx.z indexes the split-K batch; SplitKBatchOffset turns it into
// element offsets for A, B and their block-scale tensors.
53  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54 
55  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56  karg.p_sorted_token_ids,
57  karg.p_sorted_expert_ids,
58  karg.p_max_token_id,
59  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
61  karg.p_ds_grid,
62  karg.p_c_grid,
63  karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
64  karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
65  p_shared,
66  karg,
67  karg.a_element_op,
68  karg.b_element_op,
69  karg.c_element_op);
70  }
71 #else
// Unsupported target: reference karg so the compiler does not warn.
72  ignore = karg;
73 #endif // end of if (defined(__gfx9__))
74 }
75 
// Device entry point for the block-scaled MoE GEMM using the double-LDS-buffer
// pipeline: two separate __shared__ allocations (p_shared / p_shared1) are
// declared here — per the file-head comment, two distinct declarations are the
// only way to guarantee the pipeline ping-pongs between two LDS chunks.
// Otherwise identical trampoline to kernel_moe_gemm, forwarding to Run_2Lds.
76 template <typename GridwiseGemm,
77  bool HasMainKBlockLoop,
78  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
79  index_t MinimumOccupancy = 1,
80  TailNumber TailNum = TailNumber::Even>
81 __global__ void
82 #if CK_USE_LAUNCH_BOUNDS
83 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
84 #endif
85  // __attribute__((amdgpu_waves_per_eu(1, 1)))
86  kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
87 {
88 #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
89  if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
90  {
// Two LDS buffers of equal size for the double-buffered (ping-pong) pipeline.
91  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
92  __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
93 
// blockIdx.z selects the split-K batch this workgroup operates on.
94  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
95 
96  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
97  karg.p_sorted_token_ids,
98  karg.p_sorted_expert_ids,
99  karg.p_max_token_id,
100  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
101  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
102  karg.p_ds_grid,
103  karg.p_c_grid,
104  karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
105  karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
106  p_shared,
107  p_shared1,
108  karg,
109  karg.a_element_op,
110  karg.b_element_op,
111  karg.c_element_op);
112  }
113 #else
// Unsupported target: reference karg so the compiler does not warn.
114  ignore = karg;
115 #endif // end of if (defined(__gfx9__))
116 }
117 
118 template <typename ALayout,
119  typename BLayout,
120  typename DsLayout,
121  typename CLayout,
122  typename ADataType,
123  typename BDataType,
124  typename AccDataType,
125  typename CShuffleDataType,
126  typename DsDataType,
127  typename CDataType,
128  typename AElementwiseOperation,
129  typename BElementwiseOperation,
130  typename CElementwiseOperation,
132  index_t BlockSize,
133  index_t ScaleBlockM,
134  index_t ScaleBlockN,
135  index_t ScaleBlockK,
136  index_t MPerBlock,
137  index_t NPerBlock,
138  index_t KPerBlock,
139  index_t AK1Value,
140  index_t BK1Value,
141  index_t MPerXdl,
142  index_t NPerXdl,
143  index_t MXdlPerWave,
144  index_t NXdlPerWave,
145  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146  typename ABlockTransferThreadClusterArrangeOrder,
147  typename ABlockTransferSrcAccessOrder,
148  index_t ABlockTransferSrcVectorDim,
149  index_t ABlockTransferSrcScalarPerVector,
150  index_t ABlockTransferDstScalarPerVector_AK1,
151  bool AThreadTransferSrcResetCoordinateAfterRun,
152  index_t ABlockLdsExtraM,
153  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
154  typename BBlockTransferThreadClusterArrangeOrder,
155  typename BBlockTransferSrcAccessOrder,
156  index_t BBlockTransferSrcVectorDim,
157  index_t BBlockTransferSrcScalarPerVector,
158  index_t BBlockTransferDstScalarPerVector_BK1,
159  bool BThreadTransferSrcResetCoordinateAfterRun,
160  index_t BBlockLdsExtraN,
161  index_t CShuffleMXdlPerWavePerShuffle,
162  index_t CShuffleNXdlPerWavePerShuffle,
163  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
164  typename CDEShuffleBlockTransferScalarPerVectors,
167  index_t ActivationOperation = 0,
168  bool NSwizzle = false,
169  bool IsInputGemm = true,
170  bool IsSplitK = false,
171  bool MulRoutedWeight = true,
172  typename IndexType = index_t,
173  typename ComputeTypeA = CDataType,
174  typename ComputeTypeB = ComputeTypeA,
175  typename LDSTypeA = ADataType,
176  typename LDSTypeB = BDataType,
177  bool NonTemporalLoadB = false>
179 {
180  using AScaleType = float;
181  using BScaleType = float;
182 
183  static constexpr auto I0 = Number<0>{};
184  static constexpr auto I1 = Number<1>{};
185  static constexpr auto I2 = Number<2>{};
186  static constexpr auto I3 = Number<3>{};
187  static constexpr auto I4 = Number<4>{};
188  static constexpr auto I5 = Number<5>{};
189  static constexpr auto I6 = Number<6>{};
190  static constexpr auto I7 = Number<7>{};
191 
193  CDEShuffleBlockTransferScalarPerVectors{}[I0];
194  // K1 should be Number<...>
195  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
196  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
197  static constexpr auto AK1Number = Number<AK1Value>{};
198  static constexpr auto BK1Number = Number<BK1Value>{};
199  static constexpr auto BlockSizeNumber = Number<BlockSize>{};
200 
201  static constexpr index_t NumDTensor = DsDataType::Size();
202 
204  static constexpr index_t KPack =
206  static constexpr index_t KGroup = []() {
208  // On gfx950, we have a mfma that required 32 f8 elements as input,
209  // splited into 2 groups of 16 f8 elements.
210  // the 2 groups is not contiguous in the B preshuffed layout.
211  // and we do not want it to be contiguous in the B preshuffled layout
212  // because a memory instruction can only read 16 f8 elements at a time.
213  return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
214  else
215  return 1;
216  }();
217  static constexpr index_t KLane =
219  static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
220  static constexpr index_t NLane = NPerXdl;
221  static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
222  // static constexpr index_t NumTokens = 1;
223  static constexpr index_t SortedTileSize = MPerBlock;
224 
225  static constexpr auto MakeDsGridPointer()
226  {
227  return generate_tuple(
228  [&](auto i) {
229  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
230 
231  return static_cast<const DDataType*>(nullptr);
232  },
234  }
235 
236  using DsGridPointer = decltype(MakeDsGridPointer());
237 
239 
240  static constexpr index_t APackedSize = []() {
242  return 2;
243  else
244  return 1;
245  }();
246 
247  static constexpr index_t BPackedSize = []() {
249  return 2;
250  else
251  return 1;
252  }();
253 
// Host-side launch-grid computation as (gridx, gridy, gridz).
// With NSwizzle the M and N tile counts are flattened into grid.x (grid.y = 1);
// otherwise grid.x covers N tiles and grid.y covers M tiles.
// grid.z carries the split-K batches: 1 when KBatch == 1, else ceil(K / (KPerBlock * KBatch)).
254  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
255  {
256  const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
257  const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
258  const index_t gridx = NSwizzle ? nblock * mblock : nblock;
259  const index_t gridy = NSwizzle ? 1 : mblock;
260  const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
261 
262  return std::make_tuple(gridx, gridy, gridz);
263  }
264 
// M rounded up to the next multiple of MPerBlock (whole number of M tiles).
265  __host__ __device__ static auto CalculateMPadded(index_t M)
266  {
267  return math::integer_least_multiple(M, MPerBlock);
268  }
269 
// N rounded up to the next multiple of NPerBlock (whole number of N tiles).
270  __host__ __device__ static auto CalculateNPadded(index_t N)
271  {
272  return math::integer_least_multiple(N, NPerBlock);
273  }
274 
// N0 extent of the preshuffled B layout: number of NLane-wide (NLane == NPerXdl)
// column groups, rounded up.
275  __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
276  {
277  return math::integer_divide_ceil(N, NLane);
278  }
279  __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
280  {
282  }
283 
// K rounded up to the next multiple of KPerBlock (whole number of K tiles).
284  __host__ __device__ static auto CalculateKPadded(index_t K)
285  {
286  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
287  }
288 
// AK0 extent for the A descriptor (K measured in AK1-element groups).
// Without split-K, K is taken as-is (assumed divisible by AK1Value — enforced
// elsewhere by CheckValidity); with split-K the padded total K_Batch * KPerBlock
// is used instead of the raw K.
289  __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
290  {
291  // auto K_t = K_Batch * KPerBlock;
292  // return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
293  return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
294  }
295 
// BK0 extent for the B descriptor (K measured in BK1-element groups).
// Mirrors CalculateAK0Padded: raw K / BK1Value without split-K, padded
// K_Batch * KPerBlock / BK1Value with split-K.
296  __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
297  {
298  // auto K_t = K_Batch * KPerBlock;
299  // return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
300  return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
301  }
302 
// Effective K after split-K padding: unchanged for a single batch, otherwise
// each of the K_Batch batches is assigned a full KPerBlock slab.
303  __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
304  {
305  // auto K_t = K_Batch * KPerBlock;
306  // return (K + K_t - 1) / K_t * KPerBlock;
307  return K_Batch == 1 ? K : K_Batch * KPerBlock;
308  }
309 
// K extent a batch actually reads. Without split-K, K is rounded up to the
// lcm of the A/B vector widths (AK1, BK1) so both sides can use full vector
// loads; with split-K, the whole padded K_Batch * KPerBlock is used.
310  __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
311  {
312  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
313  // auto K_t = K_Batch * KReadVec;
314  // return (K + K_t - 1) / K_t * KReadVec;
315  return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec
316  : K_Batch * KPerBlock;
317  }
318 
// Number of M tiles: ceil(M / MPerBlock).
319  __host__ __device__ static auto CalculateMBlock(index_t M)
320  {
321  return math::integer_divide_ceil(M, MPerBlock);
322  }
323 
// Number of N tiles: ceil(N / NPerBlock).
324  __host__ __device__ static auto CalculateNBlock(index_t N)
325  {
326  return math::integer_divide_ceil(N, NPerBlock);
327  }
328 
329  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
330  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
331  {
332  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
333  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
334 
336  TileDesc_K0_MN_K1{},
342  }
343 
344  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
345  IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
346  {
347  const auto a_grid_desc_mraw_kraw = [&]() {
348  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
349  {
350  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
351  }
352  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
353  {
354  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
355  }
356  }();
357 
359 
360  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
361  GemmSpec == GemmSpecialization::MNKPadding)
362  {
363  // pad both M and K
364  const auto a_grid_desc_m_k =
365  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
367  make_right_pad_transform(K, KPad - K)),
370 
371  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
372  a_grid_desc_m_k,
377 
378  return a_grid_desc_ak0_m_ak1;
379  }
380  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
381  GemmSpec == GemmSpecialization::MNPadding)
382  {
383  // pad M, but not K
384  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
385  a_grid_desc_mraw_kraw,
387  make_right_pad_transform(M, MPad - M)),
390 
391  return a_grid_desc_ak0_m_ak1;
392  }
393  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
394  GemmSpec == GemmSpecialization::NKPadding)
395  {
396  // pad K, but not M
397  const auto a_grid_desc_m_k = transform_tensor_descriptor(
398  a_grid_desc_mraw_kraw,
402 
403  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
404  a_grid_desc_m_k,
409 
410  return a_grid_desc_ak0_m_ak1;
411  }
412  else
413  {
414  // not pad M or K
415  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
416  a_grid_desc_mraw_kraw,
421  return a_grid_desc_ak0_m_ak1;
422  }
423  }
424 
425  __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
426  {
427  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
428  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
429  constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
431  make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
432  make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
433  }
434 
435  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
436  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
437  {
438  const auto b_grid_desc_nraw_kraw = [&]() {
440  {
441  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
442  }
444  {
445  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
446  }
447  }();
448 
450 
451  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
452  GemmSpec != GemmSpecialization::Default),
453  "pk_i4_t does not support padding");
454 
455  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
456  GemmSpec == GemmSpecialization::MNKPadding)
457  {
458  // pad both N and K
459  const auto b_grid_desc_n_k =
460  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
462  make_right_pad_transform(K, KPad - K)),
465 
466  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
467  b_grid_desc_n_k,
472 
473  return b_grid_desc_bk0_n_bk1;
474  }
475  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
476  GemmSpec == GemmSpecialization::MNPadding)
477  {
478  // pad N, but not K
479  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
480  b_grid_desc_nraw_kraw,
482  make_right_pad_transform(N, NPad - N)),
485 
486  return b_grid_desc_bk0_n_bk1;
487  }
488  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
489  GemmSpec == GemmSpecialization::MKPadding)
490  {
491  // pad K, but not N
492  const auto b_grid_desc_n_k = transform_tensor_descriptor(
493  b_grid_desc_nraw_kraw,
497 
498  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
499  b_grid_desc_n_k,
504 
505  return b_grid_desc_bk0_n_bk1;
506  }
507  else
508  {
509  // not pad N or K
510  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
511  b_grid_desc_nraw_kraw,
516 
517  return b_grid_desc_bk0_n_bk1;
518  }
519  }
520 
// Build the per-wave MMA tile view of an A block descriptor (AK0, M, AK1)
// by delegating to MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>,
// where MWaves is the number of waves tiling the M dimension of the block.
521  template <typename ABlockDesc_AK0_M_AK1>
522  __host__ __device__ static constexpr auto
523  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
524  {
525  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
526 
527  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
528  }
529 
// B-side counterpart of MakeAMmaTileDescriptor_M0_M1_M2_K: wraps a B block
// descriptor (BK0, N, BK1) using the class-level NWave constant instead of a
// locally computed wave count.
530  template <typename BBlockDesc_BK0_N_BK1>
531  __host__ __device__ static constexpr auto
532  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
533  {
534  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
535  }
536 
537  template <typename ELayout>
538  __host__ __device__ static auto MakeCGridDescriptor_M_N(
539  IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
540  {
541  const auto c_grid_desc_mraw_nraw = [&]() {
543  {
544  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
545  }
547  {
548  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
549  }
550  }();
551 
552  // pad M and N
553  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
555  make_right_pad_transform(N, NPad - N)),
558  }
559 
560  template <typename DLayout>
561  __host__ __device__ static auto
563  {
564  const auto c_grid_desc_mraw_nraw = [&]() {
566  {
567  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
568  }
570  {
571  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
572  }
573  }();
574 
575  // pad M and N
576  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
578  make_right_pad_transform(N, NPad - N)),
581  }
582 
583  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
584  index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
585  {
586  return generate_tuple(
587  [&](auto i) {
588  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
589  return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
590  },
592  }
593 
594  template <typename DsGridDesc>
596  const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
597  {
598  return generate_tuple(
599  [&](auto i) {
601  ds_grid_desc_m_n[i], MBlock, NBlock);
602  },
604  }
605 
606  using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
607 
608  struct Problem
609  {
610  __host__ __device__ Problem(index_t NumTokens_,
611  index_t TopK_,
612  index_t M_,
613  index_t N_,
614  index_t K_,
615  index_t StrideA_,
616  index_t StrideB_,
617  std::array<index_t, NumDTensor> StrideDs_,
618  index_t StrideC_,
619  index_t KBatch_)
620  : NumTokens{NumTokens_},
621  TopK{TopK_},
622  M{M_},
623  N{N_},
624  K{K_},
625  StrideA{StrideA_},
626  StrideB{StrideB_},
627  StrideDs{StrideDs_},
628  StrideC{StrideC_},
629  KBatch{KBatch_},
632  KRead{CalculateKRead(K_, KBatch_)},
633  KPadded{CalculateKPadded(K_, KBatch_)},
634  AK0{CalculateAK0Padded(K_, KBatch_)},
635  BK0{CalculateBK0Padded(K_, KBatch_)},
636  MBlock{CalculateMBlock(M_)},
638  {
639  }
640 
// Host-only debugging aid: dump all problem dimensions, strides, and the
// derived padded/blocked quantities to stdout in a single line.
641  __host__ void Print() const
642  {
643  std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
644  << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
645  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
646  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
647  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
648  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
649  << "NBlock: " << NBlock << "}" << std::endl;
650  }
651 
659  std::array<index_t, NumDTensor> StrideDs;
670  };
671 
672  // Argument
674  {
675  __host__ Argument(const index_t* p_sorted_token_ids_,
676  const index_t* p_sorted_expert_ids_,
677  const index_t* p_max_token_id_,
678  const ADataType* p_a_grid_,
679  const BDataType* p_b_grid_,
680  std::array<const void*, NumDTensor> p_ds_grid_,
681  CDataType* p_c_grid_,
682  index_t NumTokens_,
683  index_t TopK_,
684  index_t M_,
685  index_t N_,
686  index_t K_,
687  index_t StrideA_,
688  index_t StrideB_,
689  std::array<index_t, NumDTensor> StrideDs_,
690  index_t StrideC_,
691  const AScaleType* p_a_scale_grid_,
692  const BScaleType* p_b_scale_grid_,
693  index_t k_batch_,
694  AElementwiseOperation a_element_op_,
695  BElementwiseOperation b_element_op_,
696  CElementwiseOperation c_element_op_)
697  : Problem{NumTokens_,
698  TopK_,
699  M_,
700  N_,
701  K_,
702  StrideA_,
703  StrideB_,
704  StrideDs_,
705  StrideC_,
706  k_batch_},
707  p_sorted_token_ids{p_sorted_token_ids_},
708  p_sorted_expert_ids{p_sorted_expert_ids_},
709  p_max_token_id{p_max_token_id_},
710  p_a_grid{p_a_grid_},
711  p_b_grid{p_b_grid_},
712  p_ds_grid{},
713  p_c_grid{p_c_grid_},
714  p_a_scale_grid{p_a_scale_grid_},
715  p_b_scale_grid{p_b_scale_grid_},
716  a_element_op{a_element_op_},
717  b_element_op{b_element_op_},
718  c_element_op{c_element_op_}
719  {
720 
721  // populate pointer, desc for Ds
722  static_for<0, NumDTensor, 1>{}([&](auto i) {
723  using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
724 
725  // D pointer
726  p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
727  });
728  }
729 
733  const ADataType* p_a_grid;
734  const BDataType* p_b_grid;
736  CDataType* p_c_grid;
737 
740 
741  const AElementwiseOperation a_element_op;
742  const BElementwiseOperation b_element_op;
743  const CElementwiseOperation c_element_op;
744  };
745 
747  {
748  __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
749  {
750  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
751  {
752  a_k_split_offset = k_id * karg.KRead / APackedSize;
754  }
755  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
756  {
757  a_k_split_offset = k_id * karg.KRead * karg.StrideA;
759  }
760 
761  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
762  {
763  b_k_split_offset = k_id * karg.KRead * karg.StrideB;
765  }
766  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
767  {
768  // KPack * NLane * KLane * K0 * N0
769  b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
770  bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
771  }
772 
773  // if(k_id < karg.KBatch - 1)
774  // {
775  // karg.K = karg.KRead;
776  // }
777  // else
778  // {
779  // karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
780  // }
781  }
782 
787  };
788 
789  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
790  {
791  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
792  constexpr index_t WaveSize = BlockSize / (MWave * NWave);
793  // A matrix in LDS memory, dst of blockwise copy
794  if constexpr(ABlockLdsExtraM)
795  {
799  }
800  // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
801  // in some cases.
803  {
804  constexpr auto a_lds_block_desc =
807 
808  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
809  a_lds_block_desc,
815 
816  return a_lds_block_desc_permuted;
817  }
818  else // ColumnMajor A
819  {
820  // kfold and mpair dimension is not always required.
821  // more dimension in merge_transform increase the difficulty of generating immarg offset
822  // for compiler.
823  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
824  constexpr auto M1 = MPerBlock / M0;
825 
826  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
827  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
828  constexpr auto KThreadRead = WaveSize / MPerXdl;
829  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
830 
831  constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
832  ? 1
833  : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
834  constexpr auto KThreadReadPerm =
835  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
836  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
837  : KThreadRead;
838 
839  // 1<=mpair<=n0
840  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
841  ? 1
842  : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
843  ? M0
844  : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
845 
846  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
850  Number<kfold * M0 / mpair>{},
851  Number<mpair>{},
852  AK1Number));
853 
854  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
855  a_lds_block_desc,
856  make_tuple(
860  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
863  make_tuple(
865  make_tuple(
867 
868  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
869  a_lds_block_desc_permuted,
870  make_tuple(
878  Sequence<1>{},
879  Sequence<2>{},
880  Sequence<3>{},
881  Sequence<4>{},
882  Sequence<5>{}),
884  Sequence<2>{},
885  Sequence<0, 3>{},
886  Sequence<4, 5>{},
887  Sequence<6>{},
888  Sequence<7>{}));
889 
890  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
891  a_lds_block_desc_unmerged,
894  Number<KThreadWrite / kfold / KThreadReadPerm>{},
895  Number<kfold>{},
902 
903  return a_lds_block_desc_ak0_m_ak1;
904  }
905  }
906 
907  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
908  {
909  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
912  }
913 
915  {
916  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
917 
918  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
920  make_tuple(I1,
922  I1,
924 
925  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
926  }
927 
930  BlkGemmPipelineVer,
931  BlkGemmPipeSched,
932  BlockSize,
933  ADataType,
934  BDataType,
935  ComputeTypeA,
936  AccDataType,
943  ABlockTransferSrcScalarPerVector,
944  BBlockTransferSrcScalarPerVector,
945  MPerBlock,
946  NPerBlock,
947  KPerBlock,
948  ScaleBlockM,
949  ScaleBlockN,
950  ScaleBlockK,
951  MPerXdl,
952  NPerXdl,
953  MXdlPerWave,
954  NXdlPerWave,
955  KPack,
956  IsInputGemm && !IsSplitK > ())>;
957 
958  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
959  {
960  // LDS allocation for A and B: be careful of alignment
961  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
962  // lds max alignment
963  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
964 
965  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
966  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
967 
968  // LDS allocation for C shuffle in LDS
969  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
971 
972  constexpr auto c_block_size =
973  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
974 
975  return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
976  c_block_size * sizeof(CShuffleDataType));
977  }
978 
980 
981  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
982  __host__ static constexpr bool CheckValidity(const Argument& karg)
983  {
984  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
985  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
986  "Invalid tuning param!");
987 
988  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
993  {
994  if(!(karg.M % MPerBlock == 0))
995  {
996 #if DEBUG_LOG
997  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
998  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
999  << std::endl;
1000 
1001 #endif // DEBUG_LOG
1002  return false;
1003  }
1004  }
1005 
1006  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1011  {
1012  if(!(karg.N % NPerBlock == 0))
1013  {
1014 #if DEBUG_LOG
1015  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1016  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1017  << std::endl;
1018 
1019 #endif // DEBUG_LOG
1020  return false;
1021  }
1022  }
1023 
1024  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1028  {
1029 
1030  auto K_t = karg.KBatch * KPerBlock;
1031  if(!(karg.K % K_t == 0))
1032  {
1033 #if DEBUG_LOG
1034  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1035  << karg.K << " " << __FILE__ << ":" << __LINE__
1036  << ", in function: " << __func__ << std::endl;
1037 
1038 #endif // DEBUG_LOG
1039  return false;
1040  }
1041  }
1042  else
1043  {
1044  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1045  auto K_t = karg.KBatch * KReadVec;
1046  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1047  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1048  {
1049  return false;
1050  }
1051  }
1052 
1054  {
1055  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1056  {
1057 #if DEBUG_LOG
1058  std::cout << "Arg K (" << karg.K
1059  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1060  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1061  << __LINE__ << ", in function: " << __func__ << std::endl;
1062 
1063 #endif // DEBUG_LOG
1064  return false;
1065  }
1066  }
1067  else
1068  {
1069  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1070  {
1071 #if DEBUG_LOG
1072  std::cout << "Arg M (" << karg.M
1073  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1074  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1075  << __LINE__ << ", in function: " << __func__ << std::endl;
1076 
1077 #endif // DEBUG_LOG
1078  return false;
1079  }
1080  }
1081 
1083  {
1084  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1085  {
1086 #if DEBUG_LOG
1087  std::cout << "Arg N (" << karg.N
1088  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1089  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1090  << __LINE__ << ", in function: " << __func__ << std::endl;
1091 
1092 #endif // DEBUG_LOG
1093  return false;
1094  }
1095  }
1096  else
1097  {
1098  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1099  {
1100 #if DEBUG_LOG
1101  std::cout << "Arg K (" << karg.K
1102  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1103  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1104  << __LINE__ << ", in function: " << __func__ << std::endl;
1105 
1106 #endif // DEBUG_LOG
1107  return false;
1108  }
1109  }
1110 
1112  {
1114  {
1115 #if DEBUG_LOG
1116  std::cout << "Arg N (" << karg.N
1117  << ") value is not a multiple of "
1118  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1119  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1120  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1121 
1122 #endif // DEBUG_LOG
1123  return false;
1124  }
1125  }
1126  else
1127  {
1129  {
1130 #if DEBUG_LOG
1131  std::cout << "Arg M (" << karg.M
1132  << ") value is not a multiple of "
1133  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1134  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1135  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1136 
1137 #endif // DEBUG_LOG
1138  return false;
1139  }
1140  }
1141 
1142  // check gridwise gemm pipeline
1143 #if 0
1144  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1145 
1146  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1147  {
1148  return false;
1149  }
1150 #endif
1151  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1152  return true;
1153  }
1154 
// Whether K yields enough KPerBlock iterations for the selected blockwise
// pipeline to run its main (hot) loop; decided by the pipeline itself.
1155  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1156  {
1157  const index_t num_loop = K / KPerBlock;
1158 
1159  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1160  }
1161 
1162  __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1163  {
1164  const index_t num_loop = K / KPerBlock;
1165 
1166  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1167  }
1168 
1169  template <typename CGridDesc>
1171  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1172  {
1173  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1174  c_grid_desc_m_n,
1179 
1180  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1181  }
1182 
1183  // return block_id to C matrix tile idx (m0, n0) mapping
1184  // if arch = gfx942
1185  // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1186  // NPerBlock>;
1187 
1188  template <bool HasMainKBlockLoop,
1189  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1190  TailNumber TailNum = TailNumber::Odd>
1191  __device__ static void Run(const index_t* p_sorted_token_ids,
1192  const index_t* p_sorted_expert_ids,
1193  const index_t* p_max_token_id,
1194  const ADataType* p_a_grid,
1195  const BDataType* p_b_grid,
1196  DsGridPointer& p_ds_grid,
1197  CDataType* p_c_grid,
1198  const AScaleType* p_a_scale_grid,
1199  const BScaleType* p_b_scale_grid,
1200  void* p_shared,
1201  const Problem& problem,
1202  AElementwiseOperation a_element_op,
1203  BElementwiseOperation b_element_op,
1204  CElementwiseOperation c_element_op)
1205  {
1206 #if defined(__gfx942__) || defined(__gfx950__)
1207  constexpr auto b_coherence_flag = NonTemporalLoadB
1208  ? AmdBufferCoherenceEnum::WAVE_NT1
1210 #else
1211  constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
1212 #endif
1213  ignore = b_element_op;
1214  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
1215  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1216  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1217  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1218  problem.MPadded,
1219  problem.K,
1220  problem.KPadded,
1221  problem.StrideA,
1222  problem.AK0);
1223  const auto b_grid_desc_bpreshuffled =
1224  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1225  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1226  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1227  problem.MPadded,
1228  problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1229  problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
1230  problem.StrideC);
1231 
1232  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1233  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
1234  : problem.NumTokens * problem.TopK,
1235  ScaleBlockM),
1236  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1237  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1238  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1239  make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1240  ScaleBlockN),
1241  math::integer_divide_ceil(problem.K, ScaleBlockK)),
1242  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1243 
1244  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1246  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1247  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1248  // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1249  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1250  if(expert_block_id * MPerBlock >= max_token_id)
1251  return;
1252  const index_t expert_id =
1253  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1254  const auto block_mn = [&]() -> std::pair<int, int> {
1255  if constexpr(NSwizzle)
1256  {
1257  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1258  const index_t prefix_block = ecnt_prefix * problem.NBlock;
1259  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1260  const index_t expert_swizzle =
1261  ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1262  const index_t bid_new = blockIdx.x - prefix_block;
1263  const index_t nid = __builtin_amdgcn_readfirstlane(
1264  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1265  const index_t mid =
1266  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1267  return {nid, mid};
1268  }
1269  else
1270  {
1271  return {blockIdx.x, blockIdx.y};
1272  }
1273  }();
1274  const index_t block_n_id = block_mn.first;
1275  const index_t block_m_id = block_mn.second;
1276  const index_t token0 =
1277  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1278 
1279  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1280  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1281  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1282  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1283  constexpr auto AKThreads = AK0Threads * AK1Threads;
1284  constexpr auto AMRepeats = MPerBlock / AMThreads;
1285  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1286 
1287  if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1288  return;
1290  static_for<0, AMRepeats, 1>{}([&](auto m0) {
1291  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1292  index_t token_offset = fused_token & 0xffffff;
1293  if constexpr(!IsInputGemm)
1294  {
1295  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1296  }
1297  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1298  });
1299  const index_t expert_stride =
1300  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1301  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1302  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
1303  math::integer_divide_ceil(problem.K, ScaleBlockK));
1304 
1305  // N0, K0, Blocksize*KPack
1306  const index_t n_block_data_idx_on_grid =
1307  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1308 
1309  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1310  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1311  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
1312  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1313  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1314 
1315  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1316  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1317  const auto b_scale_grid_buf =
1318  make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
1319  p_b_scale_grid + expert_id * expert_scale_stride,
1320  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1321 
1322  // A matrix in LDS memory, dst of blockwise copy
1323  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1324 
1325  // B matrix in LDS memory, dst of blockwise copy
1326  // dummy
1327  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1328  // A matrix blockwise copy
1329  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1331  AElementwiseOperation,
1335  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1336  ABlockTransferThreadClusterArrangeOrder,
1337  ADataType,
1338  LDSTypeA,
1339  decltype(a_grid_desc_ak0_m_ak1),
1340  decltype(a_block_desc_ak0_m_ak1),
1341  ABlockTransferSrcAccessOrder,
1343  ABlockTransferSrcVectorDim,
1344  2,
1345  ABlockTransferSrcScalarPerVector,
1346  ABlockTransferDstScalarPerVector_AK1,
1347  1,
1348  1,
1349  AThreadTransferSrcResetCoordinateAfterRun,
1350  true,
1351  IndexType,
1352  1,
1353  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1354  make_multi_index(0, 0, 0),
1355  a_element_op,
1356  a_block_desc_ak0_m_ak1,
1357  make_multi_index(0, 0, 0),
1359  gather_offsets);
1360 
1361  // Thread-wise copy
1362  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1363  auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1364  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1365 
1366  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1367  BDataType,
1368  BDataType,
1369  decltype(b_grid_desc_bpreshuffled),
1370  decltype(b_block_desc_bk0_n_bk1),
1373  3,
1374  BBlockTransferSrcScalarPerVector,
1375  BThreadTransferSrcResetCoordinateAfterRun,
1376  true>(b_grid_desc_bpreshuffled,
1377  make_multi_index(n_block_data_idx_on_grid,
1379  0,
1380  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1381 
1382  // LDS allocation for A and B: be careful of alignment
1383  // Cast after lds
1384  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1385  static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1386 
1387  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1388  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1389 
1390  // Blockwise GEMM pipeline
1391  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1392  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1393  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1394  decltype(c_thread_buf) c_thread_buf_up;
1395 
1396  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1397  problem.KBatch == 1
1398  ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1399  KPerBlock
1400  : problem.KBatch);
1401  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1402  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1403  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1404 
1405  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1406  // ScaleSliceSizeK is first dimension in C scale for packed math
1407  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1409 
1410  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1411  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1412  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1413  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1414  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1415 
1416  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1418 
1419  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1421 
1422  // get each thread's offset in the scale tensor
1423  // A scale
1424  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1425 
1426  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1427  return;
1428  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
1429  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
1430  const index_t fused_token =
1431  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1432  index_t token_offset = fused_token & 0xffffff;
1433  if constexpr(!IsInputGemm)
1434  {
1435  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1436  }
1437  scale_gather_offsets(m0) =
1438  token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
1439  });
1440 
1441  auto a_scale_thread_copy =
1443  AScaleType,
1444  decltype(a_scale_grid_desc_am_ak),
1445  decltype(a_scale_thread_desc),
1448  1,
1449  ScaleSliceSizeK,
1450  1,
1451  false,
1452  MXdlPerWave>(
1453  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
1454 
1455  auto b_scale_thread_copy =
1457  BScaleType,
1458  decltype(b_scale_grid_desc_bn_ak),
1459  decltype(b_scale_thread_desc),
1462  1,
1463  ScaleSliceSizeK,
1464  1,
1465  false>(
1466  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1467 
1468  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1469  constexpr auto a_scale_thread_slice_copy_step =
1470  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
1471  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1472 
1473  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1474  if constexpr(IsInputGemm && !IsSplitK)
1475  {
1476  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1477  const auto b_grid_buf_up =
1478  make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
1479  p_b_grid_up +
1480  expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
1481  b_grid_desc_bpreshuffled.GetElementSpaceSize());
1482  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1483  BDataType,
1484  BDataType,
1485  decltype(b_grid_desc_bpreshuffled),
1486  decltype(b_block_desc_bk0_n_bk1),
1489  3,
1490  BBlockTransferSrcScalarPerVector,
1491  BThreadTransferSrcResetCoordinateAfterRun,
1492  true>(b_grid_desc_bpreshuffled,
1493  make_multi_index(n_block_data_idx_on_grid,
1495  0,
1496  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1497  const BScaleType* p_b_scale_grid_up =
1498  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1499  const auto b_scale_grid_buf_up =
1500  make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
1501  p_b_scale_grid_up + expert_id * expert_scale_stride,
1502  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1503  auto b_scale_thread_copy_up =
1505  BScaleType,
1506  decltype(b_scale_grid_desc_bn_ak),
1507  decltype(b_scale_thread_desc),
1510  1,
1511  ScaleSliceSizeK,
1512  1,
1513  false>(
1514  b_scale_grid_desc_bn_ak,
1515  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1516 
1517  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1518  a_grid_desc_ak0_m_ak1,
1519  a_block_desc_ak0_m_ak1,
1520  a_blockwise_copy,
1521  a_grid_buf,
1522  a_block_buf,
1523  a_block_slice_copy_step,
1524 
1525  b_grid_desc_bpreshuffled,
1526  b_block_desc_bk0_n_bk1,
1527  b_blockwise_copy,
1528  b_blockwise_copy_up,
1529  b_grid_buf,
1530  b_grid_buf_up,
1531  b_block_buf,
1532  b_block_slice_copy_step,
1533 
1534  c_scale_thread_desc,
1535  c_thread_buf,
1536  c_thread_buf_up,
1537 
1538  a_scale_grid_desc_am_ak,
1539  a_scale_thread_desc,
1540  a_scale_thread_copy,
1541  a_scale_grid_buf,
1542  a_scale_thread_slice_copy_step,
1543 
1544  b_scale_grid_desc_bn_ak,
1545  b_scale_thread_desc,
1546  b_scale_thread_copy,
1547  b_scale_thread_copy_up,
1548  b_scale_grid_buf,
1549  b_scale_grid_buf_up,
1550  b_scale_thread_slice_copy_step,
1551 
1552  num_k_block_main_loop);
1553  }
1554  else
1555  {
1556  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1557  a_grid_desc_ak0_m_ak1,
1558  a_block_desc_ak0_m_ak1,
1559  a_blockwise_copy,
1560  a_grid_buf,
1561  a_block_buf,
1562  a_block_slice_copy_step,
1563 
1564  b_grid_desc_bpreshuffled,
1565  b_block_desc_bk0_n_bk1,
1566  b_blockwise_copy,
1567  b_grid_buf,
1568  b_block_buf,
1569  b_block_slice_copy_step,
1570 
1571  c_scale_thread_desc,
1572  c_thread_buf,
1573 
1574  a_scale_grid_desc_am_ak,
1575  a_scale_thread_desc,
1576  a_scale_thread_copy,
1577  a_scale_grid_buf,
1578  a_scale_thread_slice_copy_step,
1579 
1580  b_scale_grid_desc_bn_ak,
1581  b_scale_thread_desc,
1582  b_scale_thread_copy,
1583  b_scale_grid_buf,
1584  b_scale_thread_slice_copy_step,
1585 
1586  num_k_block_main_loop);
1587  }
1588 
1589  // shuffle C and write out
1590  {
1591  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1592  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1593  "wrong!");
1594 
1595  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1596 
1597  // transposed XDL
1598  // TODO: hacky, fix it!
1599  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1600  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1601 
1602  // TODO: hacky, fix it!
1603  // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1604  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1605  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1606 
1607  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1608  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1609  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1610  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1611  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1612  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1613  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1614  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1615 
1616  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1617  static_assert(M0 * M1 * M2 == MPerBlock);
1618  static_assert(N4 == 4 || N4 == 8);
1619  const index_t m1 = get_warp_local_1d_id() / NWave;
1620  const index_t m2 = threadIdx.x % get_warp_size() % M2;
1621 
1622  float topk_weight;
1623  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1624  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1625  if constexpr(MulRoutedWeight)
1626  {
1627  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1628  topk_weight = p_ds_grid[I0][m_pos];
1629  }
1630  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
1631  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
1632  constexpr index_t c_offset =
1633  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1634  make_tuple(m0, n0, n2 * N4 + n4));
1635  constexpr auto cidx = Number<c_offset>{};
1636  if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
1637  {
1638  if constexpr(ActivationOperation == Activation::silu_and_mul)
1639  {
1640  float gate = c_thread_buf[cidx];
1641  float up = c_thread_buf_up[cidx];
1642  if constexpr(MulRoutedWeight)
1643  {
1644  gate = gate * topk_weight;
1645  up = up * topk_weight;
1646  }
1648  {
1649  gate *= 16;
1650  up *= 16;
1651  }
1653  c_thread_buf(cidx) = gate * up;
1654  }
1655  else if(ActivationOperation == Activation::gelu_and_mul)
1656  {
1657  float gate = c_thread_buf[cidx];
1658  float up = c_thread_buf_up[cidx];
1659  if constexpr(MulRoutedWeight)
1660  {
1661  gate = gate * topk_weight;
1662  up = up * topk_weight;
1663  }
1665  {
1666  gate *= 16;
1667  up *= 16;
1668  }
1670  c_thread_buf(cidx) = gate * up;
1671  }
1672  }
1673  else
1674  {
1675  if constexpr(MulRoutedWeight)
1676  {
1677  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
1678  }
1679  }
1680  });
1681  });
1682  });
1683  });
1684 
1685  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1687 
1688  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1689  static_cast<CShuffleDataType*>(p_shared),
1690  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1691 
1692  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1693  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1694  make_tuple(
1697  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1698  M1, // M1 = MWave
1699  M2)), // M2 = MPerXdl
1702  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1703  N1, // N1 = NWave
1704  N2, // N2 * N3 * N4 = NPerXdl
1705  N3,
1706  N4))),
1708  make_tuple(
1710 
1711  // calculate origin of thread output tensor on global memory
1712  // blockwise GEMM c matrix starting index
1713  const auto c_thread_mtx_on_block =
1714  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1715 
1716  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1717  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1718 
1719  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1723  make_tuple(Sequence<0>{}));
1724 
1725  const auto m_thread_data_on_block_idx =
1726  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1727  make_multi_index(m_thread_data_on_block));
1728 
1729  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1731  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1733  make_tuple(Sequence<0>{}));
1734 
1735  const auto n_thread_data_on_block_idx =
1736  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1737  make_multi_index(n_thread_data_on_block));
1738 
1739  // shuffle: threadwise copy C from VGPR to LDS
1740  auto c_thread_copy_vgpr_to_lds =
1742  CShuffleDataType,
1743  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1744  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1746  Sequence<CShuffleMXdlPerWavePerShuffle,
1747  CShuffleNXdlPerWavePerShuffle,
1748  I1,
1749  I1,
1750  I1,
1751  N2,
1752  I1,
1753  N4>,
1755  7,
1756  1,
1758  1,
1759  true>{
1760  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1761  make_multi_index(0,
1762  0,
1763  m_thread_data_on_block_idx[I1],
1764  n_thread_data_on_block_idx[I1],
1765  m_thread_data_on_block_idx[I2],
1766  n_thread_data_on_block_idx[I2],
1767  n_thread_data_on_block_idx[I3],
1768  n_thread_data_on_block_idx[I4]),
1770 
1771  using EDataType = CDataType;
1772 
1773  const auto ds_grid_desc_m_n =
1774  MakeDsGridDescriptor_M_N(problem.M,
1775  problem.MPadded,
1776  problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1777  problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
1778  problem.StrideDs);
1779 
1780  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1782  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1783 
1784  const auto ds_grid_buf = generate_tuple(
1785  [&](auto i) {
1786  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
1787  const DDataType* ptr_ = p_ds_grid[i];
1788  // hack logic here to support different kind of strides. todo fix it.
1789  // ascale t, 1; bscale E, N, 1, move ptr to E
1790  return make_dynamic_buffer<AddressSpaceEnum::Global>(
1791  ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1792  },
1793  Number<NumDTensor>{});
1794 
1795  // tuple of reference to C/Ds tensor descriptors
1796  const auto c_ds_desc_refs = concat_tuple_of_reference(
1797  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1798  generate_tie([&](auto i) -> const auto& // return type should be reference
1799  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1800  Number<NumDTensor>{}));
1801 
1802  // tuple of reference to C/Ds tensor descriptors
1803  const auto c_ds_buf_refs = concat_tuple_of_reference(
1804  tie(c_shuffle_block_buf),
1805  generate_tie([&](auto i) -> const auto& // return type should be reference
1806  { return ds_grid_buf[i]; },
1807  Number<NumDTensor>{}));
1808 
1809  // tuple of starting index of C/Ds blockwise copy
1810  const auto idx_c_ds_block_begin =
1813  [&](auto) {
1814  return make_multi_index(block_m_id, 0, block_n_id, 0);
1815  // return make_multi_index(block_work_idx[I0], 0,
1816  // block_work_idx[I1], 0);
1817  },
1818  Number<NumDTensor>{}));
1819 
1820  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1821  c_grid_desc_mblock_mperblock_nblock_nperblock;
1822 
1823  using CDEBlockTransferCluster =
1824  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1825  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1826  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
1827  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1829  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1831  decltype(c_ds_desc_refs),
1832  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1833  CElementwiseOperation,
1834  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1835  // support arbitray type
1836  Sequence<1,
1837  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1838  1,
1839  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1840  CDEBlockTransferCluster,
1841  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1842  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1843  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1844  3, // index_t SrcVectorDim,
1845  3, // index_t DstVectorDim,
1846  CDEShuffleBlockTransferScalarPerVectors,
1851  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1852  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1853  IndexType,
1854  1, // ScatterDim
1855  true, // OutputScatter: false, only use scatter weights
1856  scatter_weight_idx // ScatterWeightIdx: ascale
1857  >{c_ds_desc_refs,
1858  idx_c_ds_block_begin,
1859  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1860  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1861  c_element_op};
1862 
1863  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1864  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1865  // space filling curve for threadwise C in VGPR
1866  constexpr auto sfc_c_vgpr =
1869  Sequence<CShuffleMXdlPerWavePerShuffle,
1870  CShuffleNXdlPerWavePerShuffle,
1871  1,
1872  1,
1873  1,
1874  N2,
1875  1,
1876  N4>>{};
1877 
1878  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1879 
1880  // space filling curve for shuffled blockwise C/D/E
1881  constexpr auto sfc_cde_block =
1884  Sequence<1,
1885  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1886  1,
1887  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1888 
1889  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1890  constexpr auto EMThreads =
1891  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1892  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1893  constexpr auto ENThreads =
1894  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1895  static_for<0, num_access, 1>{}([&](auto access_id) {
1896  // make sure it's safe to write to LDS
1898 
1899  auto dstidx = sfc_cde_block.GetIndex(access_id);
1900  const index_t c_token_pos =
1901  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1902  static_for<0, EMRepeats, 1>{}([&](auto m0) {
1903  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1904  index_t token_offset = fused_token & 0xffffff;
1905  if constexpr(IsInputGemm)
1906  {
1907  token_offset = token_offset * problem.TopK + (fused_token >> 24);
1908  }
1909  scatter_offsets(m0) =
1910  token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
1911  });
1912 
1913  block_sync_lds();
1914 
1915  // each thread write its data from VGPR to LDS
1916  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1917  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1918  c_thread_buf,
1919  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1920  c_shuffle_block_buf);
1921 
1922  // make sure it's safe to read from LDS
1923  block_sync_lds();
1924 
1925  // each block copy its data from LDS to global
1926  cde_block_copy_lds_and_global.Run(
1927  c_ds_desc_refs,
1928  c_ds_buf_refs,
1929  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1930  tie(c_grid_buf),
1931  scatter_offsets);
1932 
1933  if constexpr(access_id < num_access - 1)
1934  {
1935  constexpr auto cde_lds_and_global_step =
1936  sfc_cde_block.GetForwardStep(access_id);
1937 
1938  // move on Ds
1939  static_for<0, NumDTensor, 1>{}([&](auto i) {
1940  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1941  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1942  });
1943 
1944  // move on E
1945  cde_block_copy_lds_and_global.MoveDstSliceWindow(
1946  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1947  I0,
1948  cde_lds_and_global_step);
1949  }
1950  });
1951  }
1952  }
1953 
1954  template <bool HasMainKBlockLoop,
1955  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1956  TailNumber TailNum = TailNumber::Odd>
1957  __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1958  const index_t* p_sorted_expert_ids,
1959  const index_t* p_max_token_id,
1960  const ADataType* p_a_grid,
1961  const BDataType* p_b_grid,
1962  DsGridPointer& p_ds_grid,
1963  CDataType* p_c_grid,
1964  const AScaleType* p_a_scale_grid,
1965  const BScaleType* p_b_scale_grid,
1966  void* p_shared,
1967  void* p_shared1,
1968  const Problem& problem,
1969  AElementwiseOperation a_element_op,
1970  BElementwiseOperation b_element_op,
1971  CElementwiseOperation c_element_op)
1972  {
1973 #if defined(__gfx942__) || defined(__gfx950__)
1974  constexpr auto b_coherence_flag = NonTemporalLoadB
1975  ? AmdBufferCoherenceEnum::WAVE_NT1
1977 #else
1978  constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
1979 #endif
1980  ignore = b_element_op;
1981  index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1982  index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1983  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1984  IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1985  problem.MPadded,
1986  problem.K,
1987  problem.KPadded,
1988  problem.StrideA,
1989  problem.AK0);
1990  const auto b_grid_desc_bpreshuffled =
1991  MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1992  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1993  IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1994  problem.MPadded,
1995  problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
1996  problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
1997  problem.StrideC);
1998 
1999  const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2000  make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens
2001  : problem.NumTokens * problem.TopK,
2002  ScaleBlockM),
2003  math::integer_divide_ceil(problem.K, ScaleBlockK)),
2004  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
2005  const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2006  make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
2007  math::integer_divide_ceil(problem.K, ScaleBlockK)),
2008  make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
2009  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2011  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2012  const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2013  const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2014  if(expert_block_id * MPerBlock >= max_token_id)
2015  return;
2016  const index_t expert_id =
2017  __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2018  const auto block_mn = [&]() -> std::pair<int, int> {
2019  if constexpr(NSwizzle)
2020  {
2021  const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2022  const index_t prefix_block = ecnt_prefix * problem.NBlock;
2023  const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2024  const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
2025  const index_t bid_new = blockIdx.x - prefix_block;
2026  const index_t nid = __builtin_amdgcn_readfirstlane(
2027  bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2028  const index_t mid =
2029  __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2030  return {nid, mid};
2031  }
2032  else
2033  {
2034  return {blockIdx.x, blockIdx.y};
2035  }
2036  }();
2037  const index_t block_n_id = block_mn.first;
2038  const index_t block_m_id = block_mn.second;
2039 
2040  const index_t token0 =
2041  __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2042 
2043  // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2044  constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2045  constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2046  constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2047  constexpr auto AKThreads = AK0Threads * AK1Threads;
2048  constexpr auto AMRepeats = MPerBlock / AMThreads;
2049  const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2050 
2051  if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2052  token0 >= problem.NumTokens)
2053  return;
2055  gather_offsets; //= p_sorted_token_ids[token_pos];
2056  static_for<0, AMRepeats, 1>{}([&](auto m0) {
2057  const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2058  index_t token_offset = fused_token & 0xffffff;
2059  if constexpr(!IsInputGemm)
2060  {
2061  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2062  }
2063  gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2064  });
2065  const index_t expert_stride =
2066  __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2067  const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2068  math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
2069  math::integer_divide_ceil(problem.K, ScaleBlockK));
2070  // N0, K0, Blocksize*KPack
2071  const index_t n_block_data_idx_on_grid =
2072  __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2073 
2074  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2075  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2076  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
2077  p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2078  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2079 
2080  const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2081  p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2082  const auto b_scale_grid_buf =
2083  make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
2084  p_b_scale_grid + expert_id * expert_scale_stride,
2085  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2086 
2087  // A matrix in LDS memory, dst of blockwise copy
2088  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2089 
2090  // B matrix in LDS memory, dst of blockwise copy
2091  // dummy
2092  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2093  // A matrix blockwise copy
2094  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2096  AElementwiseOperation,
2100  ABlockTransferThreadClusterLengths_AK0_M_AK1,
2101  ABlockTransferThreadClusterArrangeOrder,
2102  ADataType,
2103  LDSTypeA,
2104  decltype(a_grid_desc_ak0_m_ak1),
2105  decltype(a_block_desc_ak0_m_ak1),
2106  ABlockTransferSrcAccessOrder,
2108  ABlockTransferSrcVectorDim,
2109  2,
2110  ABlockTransferSrcScalarPerVector,
2111  ABlockTransferDstScalarPerVector_AK1,
2112  1,
2113  1,
2114  AThreadTransferSrcResetCoordinateAfterRun,
2115  true,
2116  IndexType,
2117  1,
2118  BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2119  make_multi_index(0, 0, 0),
2120  a_element_op,
2121  a_block_desc_ak0_m_ak1,
2122  make_multi_index(0, 0, 0),
2124  gather_offsets);
2125 
2126  // Thread-wise copy
2127  // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2128  auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2129  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2130  auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2131  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2132  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2133 
2134  auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2135  BDataType,
2136  BDataType,
2137  decltype(b_grid_desc_bpreshuffled),
2138  decltype(b_block_desc_bk0_n_bk1),
2141  3,
2142  BBlockTransferSrcScalarPerVector,
2143  BThreadTransferSrcResetCoordinateAfterRun,
2144  true>(b_grid_desc_bpreshuffled,
2145  make_multi_index(n_block_data_idx_on_grid,
2147  0,
2148  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2149 
2150  // LDS allocation for A and B: be careful of alignment
2151  // Cast after lds
2152  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2153  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2154  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2155  static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2156  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2157 
2158  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2159  constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2160 
2161  // Blockwise GEMM pipeline
2162  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2163  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2164  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2165  decltype(c_thread_buf) c_thread_buf_up;
2166 
2167  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2168  problem.KBatch == 1
2169  ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2170  KPerBlock
2171  : problem.KBatch);
2172 
2173  // scale
2174  constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2175  constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
2176  constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
2177 
2178  // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
2179  // ScaleSliceSizeK is first dimension in C scale for packed math
2180  constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
2182 
2183  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2184  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2185  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
2186  auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
2187  (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
2188 
2189  constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
2191 
2192  constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
2194 
2195  // get each thread's offset in the scale tensor
2196  // A scale
2197  const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2198 
2199  if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2200  return;
2201  StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
2202  static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
2203  const index_t fused_token =
2204  p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2205  index_t token_offset = fused_token & 0xffffff;
2206  if constexpr(!IsInputGemm)
2207  {
2208  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2209  }
2210  scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2211  math::integer_divide_ceil(problem.K, ScaleBlockK);
2212  });
2213 
2214  auto a_scale_thread_copy =
2216  AScaleType,
2217  decltype(a_scale_grid_desc_am_ak),
2218  decltype(a_scale_thread_desc),
2221  1,
2222  ScaleSliceSizeK,
2223  1,
2224  false,
2225  MXdlPerWave>(
2226  a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets);
2227 
2228  auto b_scale_thread_copy =
2230  BScaleType,
2231  decltype(b_scale_grid_desc_bn_ak),
2232  decltype(b_scale_thread_desc),
2235  1,
2236  ScaleSliceSizeK,
2237  1,
2238  false>(
2239  b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2240 
2241  // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
2242  constexpr auto a_scale_thread_slice_copy_step =
2243  make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK));
2244  constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2245 
2246  constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
2247  if constexpr(IsInputGemm && !IsSplitK)
2248  {
2249  const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2250  const auto b_grid_buf_up =
2251  make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
2252  p_b_grid_up +
2253  expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
2254  b_grid_desc_bpreshuffled.GetElementSpaceSize());
2255  auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2256  BDataType,
2257  BDataType,
2258  decltype(b_grid_desc_bpreshuffled),
2259  decltype(b_block_desc_bk0_n_bk1),
2262  3,
2263  BBlockTransferSrcScalarPerVector,
2264  BThreadTransferSrcResetCoordinateAfterRun,
2265  true>(b_grid_desc_bpreshuffled,
2266  make_multi_index(n_block_data_idx_on_grid,
2268  0,
2269  KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2270  const BScaleType* p_b_scale_grid_up =
2271  p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2272  const auto b_scale_grid_buf_up =
2273  make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
2274  p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2275  b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2276  auto b_scale_thread_copy_up =
2278  BScaleType,
2279  decltype(b_scale_grid_desc_bn_ak),
2280  decltype(b_scale_thread_desc),
2283  1,
2284  ScaleSliceSizeK,
2285  1,
2286  false>(
2287  b_scale_grid_desc_bn_ak,
2288  make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2289 
2290  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2291  a_grid_desc_ak0_m_ak1,
2292  a_block_desc_ak0_m_ak1,
2293  a_blockwise_copy,
2294  a_grid_buf,
2295  a_block_bufs,
2296  a_block_slice_copy_step,
2297  b_grid_desc_bpreshuffled,
2298  b_block_desc_bk0_n_bk1,
2299  b_blockwise_copy,
2300  b_blockwise_copy_up,
2301  b_grid_buf,
2302  b_grid_buf_up,
2303  b_block_bufs,
2304  b_block_slice_copy_step,
2305  c_scale_thread_desc,
2306  c_thread_buf,
2307  c_thread_buf_up,
2308  a_scale_grid_desc_am_ak,
2309  a_scale_thread_desc,
2310  a_scale_thread_copy,
2311  a_scale_grid_buf,
2312  a_scale_thread_slice_copy_step,
2313  b_scale_grid_desc_bn_ak,
2314  b_scale_thread_desc,
2315  b_scale_thread_copy,
2316  b_scale_thread_copy_up,
2317  b_scale_grid_buf,
2318  b_scale_grid_buf_up,
2319  b_scale_thread_slice_copy_step,
2320  num_k_block_main_loop);
2321  }
2322  else
2323  {
2324  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2325  a_grid_desc_ak0_m_ak1,
2326  a_block_desc_ak0_m_ak1,
2327  a_blockwise_copy,
2328  a_grid_buf,
2329  a_block_bufs,
2330  a_block_slice_copy_step,
2331  b_grid_desc_bpreshuffled,
2332  b_block_desc_bk0_n_bk1,
2333  b_blockwise_copy,
2334  b_grid_buf,
2335  b_block_bufs,
2336  b_block_slice_copy_step,
2337  c_scale_thread_desc,
2338  c_thread_buf,
2339  a_scale_grid_desc_am_ak,
2340  a_scale_thread_desc,
2341  a_scale_thread_copy,
2342  a_scale_grid_buf,
2343  a_scale_thread_slice_copy_step,
2344  b_scale_grid_desc_bn_ak,
2345  b_scale_thread_desc,
2346  b_scale_thread_copy,
2347  b_scale_grid_buf,
2348  b_scale_thread_slice_copy_step,
2349  num_k_block_main_loop);
2350  }
2351 
2352  // shuffle C and write out
2353  {
2354 
2355  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2356  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2357  "wrong!");
2358 
2359  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2360 
2361  // transposed XDL
2362  // TODO: hacky, fix it!
2363  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2364  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2365 
2366  // TODO: hacky, fix it!
2367  // only used to get lengths
2368  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2369  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2370 
2371  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2372  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2373  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2374  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2375  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2376  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2377  constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2378  constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2379 
2380  static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2381  static_assert(M0 * M1 * M2 == MPerBlock);
2382  static_assert(N4 == 4 || N4 == 8);
2383  const index_t m1 = get_warp_local_1d_id() / NWave;
2384  const index_t m2 = threadIdx.x % get_warp_size() % M2;
2385 
2386  float topk_weight;
2387  static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2388  static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2389  if constexpr(MulRoutedWeight)
2390  {
2391  const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2392  topk_weight = p_ds_grid[I0][m_pos];
2393  }
2394  static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
2395  static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
2396  constexpr index_t c_offset =
2397  blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2398  make_tuple(m0, n0, n2 * N4 + n4));
2399  constexpr auto cidx = Number<c_offset>{};
2400  if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
2401  {
2402  if constexpr(ActivationOperation == Activation::silu_and_mul)
2403  {
2404  float gate = c_thread_buf[cidx];
2405  float up = c_thread_buf_up[cidx];
2406  if constexpr(MulRoutedWeight)
2407  {
2408  gate = gate * topk_weight;
2409  up = up * topk_weight;
2410  }
2412  {
2413  gate *= 16;
2414  up *= 16;
2415  }
2417  c_thread_buf(cidx) = gate * up;
2418  }
2419  else if(ActivationOperation == Activation::gelu_and_mul)
2420  {
2421  float gate = c_thread_buf[cidx];
2422  float up = c_thread_buf_up[cidx];
2423  if constexpr(MulRoutedWeight)
2424  {
2425  gate = gate * topk_weight;
2426  up = up * topk_weight;
2427  }
2429  {
2430  gate *= 16;
2431  up *= 16;
2432  }
2434  c_thread_buf(cidx) = gate * up;
2435  }
2436  }
2437  else
2438  {
2439  if constexpr(MulRoutedWeight)
2440  {
2441  c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2442  }
2443  }
2444 
2445  });
2446  });
2447  });
2448  });
2449 
2450  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2452 
2453  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2454  static_cast<CShuffleDataType*>(p_shared),
2455  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2456 
2457  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
2458  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2459  make_tuple(
2462  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2463  M1, // M1 = MWave
2464  M2)), // M2 = MPerXdl
2467  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2468  N1, // N1 = NWave
2469  N2, // N2 * N3 * N4 = NPerXdl
2470  N3,
2471  N4))),
2473  make_tuple(
2475 
2476  // calculate origin of thread output tensor on global memory
2477  // blockwise GEMM c matrix starting index
2478  const auto c_thread_mtx_on_block =
2479  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2480 
2481  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2482  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2483 
2484  const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2488  make_tuple(Sequence<0>{}));
2489 
2490  const auto m_thread_data_on_block_idx =
2491  m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2492  make_multi_index(m_thread_data_on_block));
2493 
2494  const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2496  make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
2498  make_tuple(Sequence<0>{}));
2499 
2500  const auto n_thread_data_on_block_idx =
2501  n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2502  make_multi_index(n_thread_data_on_block));
2503 
2504  // shuffle: threadwise copy C from VGPR to LDS
2505  auto c_thread_copy_vgpr_to_lds =
2507  CShuffleDataType,
2508  decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2509  decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2511  Sequence<CShuffleMXdlPerWavePerShuffle,
2512  CShuffleNXdlPerWavePerShuffle,
2513  I1,
2514  I1,
2515  I1,
2516  N2,
2517  I1,
2518  N4>,
2520  7,
2521  1,
2523  1,
2524  true>{
2525  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2526  make_multi_index(0,
2527  0,
2528  m_thread_data_on_block_idx[I1],
2529  n_thread_data_on_block_idx[I1],
2530  m_thread_data_on_block_idx[I2],
2531  n_thread_data_on_block_idx[I2],
2532  n_thread_data_on_block_idx[I3],
2533  n_thread_data_on_block_idx[I4]),
2535 
2536  using EDataType = CDataType;
2537 
2538  const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2539  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2540 
2541  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2543  ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2544 
2545  const auto ds_grid_buf = generate_tuple(
2546  [&](auto i) {
2547  return make_dynamic_buffer<AddressSpaceEnum::Global>(
2548  p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2549  },
2550  Number<NumDTensor>{});
2551 
2552  // tuple of reference to C/Ds tensor descriptors
2553  const auto c_ds_desc_refs = concat_tuple_of_reference(
2554  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2555  generate_tie([&](auto i) -> const auto& // return type should be reference
2556  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2557  Number<NumDTensor>{}));
2558 
2559  // tuple of reference to C/Ds tensor descriptors
2560  const auto c_ds_buf_refs = concat_tuple_of_reference(
2561  tie(c_shuffle_block_buf),
2562  generate_tie([&](auto i) -> const auto& // return type should be reference
2563  { return ds_grid_buf[i]; },
2564  Number<NumDTensor>{}));
2565 
2566  // tuple of starting index of C/Ds blockwise copy
2567  const auto idx_c_ds_block_begin =
2570  [&](auto) {
2571  return make_multi_index(block_m_id, 0, block_n_id, 0);
2572  // return make_multi_index(block_work_idx[I0], 0,
2573  // block_work_idx[I1], 0);
2574  },
2575  Number<NumDTensor>{}));
2576 
2577  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2578  c_grid_desc_mblock_mperblock_nblock_nperblock;
2579 
2580  using CDEBlockTransferCluster =
2581  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2582  const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2583  constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix
2584  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2586  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2588  decltype(c_ds_desc_refs),
2589  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2590  CElementwiseOperation,
2591  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2592  // support arbitray type
2593  Sequence<1,
2594  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2595  1,
2596  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2597  CDEBlockTransferCluster,
2598  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2599  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2600  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2601  3, // index_t SrcVectorDim,
2602  3, // index_t DstVectorDim,
2603  CDEShuffleBlockTransferScalarPerVectors,
2608  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2609  Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2610  IndexType,
2611  1, // ScatterDim
2612  true, // OutputScatter: false, only use scatter weights
2613  scatter_weight_idx // ScatterWeightIdx: ascale
2614  >{c_ds_desc_refs,
2615  idx_c_ds_block_begin,
2616  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2617  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2618  c_element_op};
2619 
2620  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2621  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2622  // space filling curve for threadwise C in VGPR
2623  constexpr auto sfc_c_vgpr =
2626  Sequence<CShuffleMXdlPerWavePerShuffle,
2627  CShuffleNXdlPerWavePerShuffle,
2628  1,
2629  1,
2630  1,
2631  N2,
2632  1,
2633  N4>>{};
2634 
2635  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2636 
2637  // space filling curve for shuffled blockwise C/D/E
2638  constexpr auto sfc_cde_block =
2641  Sequence<1,
2642  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2643  1,
2644  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2645 
2646  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2647  constexpr auto EMThreads =
2648  CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2649  constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2650  constexpr auto ENThreads =
2651  CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2652  static_for<0, num_access, 1>{}([&](auto access_id) {
2653  // make sure it's safe to write to LDS
2655  scatter_offsets; //= p_sorted_token_ids[c_token_pos];
2656 
2657  auto dstidx = sfc_cde_block.GetIndex(access_id);
2658  const index_t c_token_pos =
2659  block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2660  static_for<0, EMRepeats, 1>{}([&](auto m0) {
2661  const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2662  index_t token_offset = fused_token & 0xffffff;
2663  if constexpr(IsInputGemm)
2664  {
2665  token_offset = token_offset * problem.TopK + (fused_token >> 24);
2666  }
2667  scatter_offsets(m0) =
2668  token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
2669  });
2670 
2671  block_sync_lds();
2672 
2673  // each thread write its data from VGPR to LDS
2674  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2675  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2676  c_thread_buf,
2677  c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2678  c_shuffle_block_buf);
2679 
2680  // make sure it's safe to read from LDS
2681  block_sync_lds();
2682 
2683  // each block copy its data from LDS to global
2684  cde_block_copy_lds_and_global.Run(
2685  c_ds_desc_refs,
2686  c_ds_buf_refs,
2687  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2688  tie(c_grid_buf),
2689  scatter_offsets);
2690 
2691  if constexpr(access_id < num_access - 1)
2692  {
2693  constexpr auto cde_lds_and_global_step =
2694  sfc_cde_block.GetForwardStep(access_id);
2695 
2696  // move on Ds
2697  static_for<0, NumDTensor, 1>{}([&](auto i) {
2698  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2699  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2700  });
2701 
2702  // move on E
2703  cde_block_copy_lds_and_global.MoveDstSliceWindow(
2704  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2705  I0,
2706  cde_lds_and_global_step);
2707  }
2708  });
2709  }
2710  }
2711 };
2712 
2713 } // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition: device_base.hpp:251
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:54
Definition: ck.hpp:270
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:45
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:835
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:279
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Block GEMM pipeline version enumeration.
Definition: scheduler_enum.hpp:17
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
int64_t long_index_t
Definition: ck.hpp:302
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:219
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:185
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:832
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:84
Definition: gridwise_moe_gemm_blockscale.hpp:674
DsGridPointer p_ds_grid
Definition: gridwise_moe_gemm_blockscale.hpp:735
const index_t * p_sorted_token_ids
Definition: gridwise_moe_gemm_blockscale.hpp:730
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, const AScaleType *p_a_scale_grid_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_gemm_blockscale.hpp:675
const AScaleType * p_a_scale_grid
Definition: gridwise_moe_gemm_blockscale.hpp:738
const BScaleType * p_b_scale_grid
Definition: gridwise_moe_gemm_blockscale.hpp:739
const index_t * p_max_token_id
Definition: gridwise_moe_gemm_blockscale.hpp:732
CDataType * p_c_grid
Definition: gridwise_moe_gemm_blockscale.hpp:736
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_gemm_blockscale.hpp:731
const BElementwiseOperation b_element_op
Definition: gridwise_moe_gemm_blockscale.hpp:742
const BDataType * p_b_grid
Definition: gridwise_moe_gemm_blockscale.hpp:734
const CElementwiseOperation c_element_op
Definition: gridwise_moe_gemm_blockscale.hpp:743
const ADataType * p_a_grid
Definition: gridwise_moe_gemm_blockscale.hpp:733
const AElementwiseOperation a_element_op
Definition: gridwise_moe_gemm_blockscale.hpp:741
Definition: gridwise_moe_gemm_blockscale.hpp:609
index_t TopK
Definition: gridwise_moe_gemm_blockscale.hpp:653
index_t MBlock
Definition: gridwise_moe_gemm_blockscale.hpp:668
index_t KPadded
Definition: gridwise_moe_gemm_blockscale.hpp:665
index_t StrideC
Definition: gridwise_moe_gemm_blockscale.hpp:660
index_t MPadded
Definition: gridwise_moe_gemm_blockscale.hpp:662
index_t KRead
Definition: gridwise_moe_gemm_blockscale.hpp:664
index_t BK0
Definition: gridwise_moe_gemm_blockscale.hpp:667
__host__ __device__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_gemm_blockscale.hpp:610
index_t K
Definition: gridwise_moe_gemm_blockscale.hpp:656
index_t StrideB
Definition: gridwise_moe_gemm_blockscale.hpp:658
index_t AK0
Definition: gridwise_moe_gemm_blockscale.hpp:666
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_gemm_blockscale.hpp:659
index_t N
Definition: gridwise_moe_gemm_blockscale.hpp:655
index_t KBatch
Definition: gridwise_moe_gemm_blockscale.hpp:661
index_t NPadded
Definition: gridwise_moe_gemm_blockscale.hpp:663
index_t M
Definition: gridwise_moe_gemm_blockscale.hpp:654
index_t NBlock
Definition: gridwise_moe_gemm_blockscale.hpp:669
__host__ void Print() const
Definition: gridwise_moe_gemm_blockscale.hpp:641
index_t NumTokens
Definition: gridwise_moe_gemm_blockscale.hpp:652
index_t StrideA
Definition: gridwise_moe_gemm_blockscale.hpp:657
Definition: gridwise_moe_gemm_blockscale.hpp:747
index_t bscale_k_split_offset
Definition: gridwise_moe_gemm_blockscale.hpp:786
index_t ascale_k_split_offset
Definition: gridwise_moe_gemm_blockscale.hpp:785
index_t a_k_split_offset
Definition: gridwise_moe_gemm_blockscale.hpp:783
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_gemm_blockscale.hpp:748
index_t b_k_split_offset
Definition: gridwise_moe_gemm_blockscale.hpp:784
Definition: gridwise_moe_gemm_blockscale.hpp:179
static constexpr index_t KPack
Definition: gridwise_moe_gemm_blockscale.hpp:204
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_gemm_blockscale.hpp:583
static constexpr index_t SortedTileSize
Definition: gridwise_moe_gemm_blockscale.hpp:223
__host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:303
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_gemm_blockscale.hpp:958
static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_gemm_blockscale.hpp:1170
remove_cvref_t< decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm &&!IsSplitK >())> BlockwiseGemmPipe
Definition: gridwise_moe_gemm_blockscale.hpp:956
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_gemm_blockscale.hpp:344
static constexpr auto I4
Definition: gridwise_moe_gemm_blockscale.hpp:187
static constexpr auto I5
Definition: gridwise_moe_gemm_blockscale.hpp:188
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_gemm_blockscale.hpp:330
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_gemm_blockscale.hpp:1957
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_gemm_blockscale.hpp:435
__host__ static __device__ auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
Definition: gridwise_moe_gemm_blockscale.hpp:425
__host__ static __device__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:289
static constexpr auto AK0Number
Definition: gridwise_moe_gemm_blockscale.hpp:195
static constexpr auto BlockSizeNumber
Definition: gridwise_moe_gemm_blockscale.hpp:199
static constexpr auto I2
Definition: gridwise_moe_gemm_blockscale.hpp:185
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:1155
static constexpr auto BK0Number
Definition: gridwise_moe_gemm_blockscale.hpp:196
__host__ static __device__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:324
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_gemm_blockscale.hpp:1191
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:982
static constexpr index_t KRepeat
Definition: gridwise_moe_gemm_blockscale.hpp:219
static constexpr auto I7
Definition: gridwise_moe_gemm_blockscale.hpp:190
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_gemm_blockscale.hpp:523
__host__ static __device__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:270
static constexpr auto BK1Number
Definition: gridwise_moe_gemm_blockscale.hpp:198
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_gemm_blockscale.hpp:225
static constexpr auto I1
Definition: gridwise_moe_gemm_blockscale.hpp:184
__host__ static constexpr __device__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:1162
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_gemm_blockscale.hpp:907
static constexpr auto I0
Definition: gridwise_moe_gemm_blockscale.hpp:183
static constexpr auto I3
Definition: gridwise_moe_gemm_blockscale.hpp:186
float AScaleType
Definition: gridwise_moe_gemm_blockscale.hpp:180
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_gemm_blockscale.hpp:914
__host__ static __device__ auto CalculateBN0Shuffled(index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:275
__host__ static __device__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_gemm_blockscale.hpp:319
static constexpr index_t KLane
Definition: gridwise_moe_gemm_blockscale.hpp:217
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_gemm_blockscale.hpp:789
__host__ static __device__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:296
static constexpr auto I6
Definition: gridwise_moe_gemm_blockscale.hpp:189
float BScaleType
Definition: gridwise_moe_gemm_blockscale.hpp:181
static constexpr index_t NWave
Definition: gridwise_moe_gemm_blockscale.hpp:221
static constexpr index_t NumDTensor
Definition: gridwise_moe_gemm_blockscale.hpp:201
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_gemm_blockscale.hpp:192
__host__ static __device__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:284
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_gemm_blockscale.hpp:595
__host__ static __device__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:310
static constexpr index_t APackedSize
Definition: gridwise_moe_gemm_blockscale.hpp:240
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_gemm_blockscale.hpp:562
__host__ static __device__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_gemm_blockscale.hpp:265
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
Definition: gridwise_moe_gemm_blockscale.hpp:254
static constexpr index_t NLane
Definition: gridwise_moe_gemm_blockscale.hpp:220
static constexpr index_t BPackedSize
Definition: gridwise_moe_gemm_blockscale.hpp:247
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_gemm_blockscale.hpp:238
static constexpr auto AK1Number
Definition: gridwise_moe_gemm_blockscale.hpp:197
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))> DsGridDesc_M_N
Definition: gridwise_moe_gemm_blockscale.hpp:606
static constexpr index_t KGroup
Definition: gridwise_moe_gemm_blockscale.hpp:206
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_gemm_blockscale.hpp:538
__host__ static __device__ auto CalculateBK0Shuffled(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:279
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_gemm_blockscale.hpp:532
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_gemm_blockscale.hpp:236
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1255
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1861
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1808
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1855
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: threadwise_tensor_slice_transfer.hpp:440
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:118
Definition: amd_ck_fp8.hpp:36
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:270
Definition: unary_element_wise_operation.hpp:1041
Definition: unary_element_wise_operation.hpp:340
Definition: unary_element_wise_operation.hpp:1087