Composable Kernel: gridwise_gemm_xdlops_bwd_weight.hpp Source File
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {
20 
21 // Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
22 // be used for low_lengths that are known at compile time and are power of 2, otherwise performance
23 // will be very bad
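// Illustrative example: for low_lengths = (4, 8), container_reverse_exclusive_scan
// with math::multiplies yields low_lengths_scan_ = (8, 1) and the merged upper
// length is 4 * 8 = 32; CalculateLowerIndex below then maps upper index 27 to
// idx_low = (27 / 8, 27 % 8) = (3, 3).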
template <typename LowLengths>
struct Merge_v4_no_carry
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v4_no_carry() = default;

    __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        // division and mod
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp %= this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

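    // Note: unlike a full merge update, the "no carry" variant below rewrites only
    // the fastest-varying low dimension (NDimLow - 1) and copies the upper-index
    // diff straight into it; the borrow/carry propagation across the remaining low
    // dimensions is deliberately skipped. This is only correct when the caller
    // guarantees that a step never crosses a low-dimension boundary.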
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_up_diff,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0   = Number<0>{};
        constexpr auto INm1 = Number<NDimLow - 1>{};

        index_t tmp = idx_up_new[I0];

        idx_low(INm1)      = tmp;
        idx_diff_low(INm1) = idx_up_diff[I0];
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScan>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v4_no_carry, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
        print_multi_index(low_lengths_scan_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths)
{
    return Merge_v4_no_carry<LowLengths>{low_lengths};
}
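// Usage sketch (mirrors the block-descriptor helpers further below): the no-carry
// merge collapses an M0/M1 (or N0/N1) split of an LDS tile back into a single
// dimension, e.g.
//
//   transform_tensor_descriptor(
//       a_block_desc_k0_m0_m1_k1,
//       make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
//                  make_merge_transform_v4_no_carry(
//                      make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
//                  make_pass_through_transform(K1)),
//       make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
//       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));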

template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename CBlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool SplitKOffsetHack>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
                                  const FloatB* __restrict__ p_b_grid,
                                  FloatC* __restrict__ p_c_grid,
                                  const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
                                  const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
                                  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
                                  const AElementwiseOperation a_element_op,
                                  const BElementwiseOperation b_element_op,
                                  const CElementwiseOperation c_element_op,
                                  const CBlockClusterAdaptor c_block_cluster_adaptor,
                                  const long_index_t split_k_stride_a,
                                  const long_index_t split_k_stride_b,
                                  index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
    defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        GridwiseGemm::template Run<HasMainKBlockLoop, SplitKOffsetHack>(
            p_a_grid,
            p_b_grid,
            p_c_grid,
            p_shared,
            a_b_k0_m_k1_grid_desc,
            b_b_k0_n_k1_grid_desc,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            a_element_op,
            b_element_op,
            c_element_op,
            c_block_cluster_adaptor,
            split_k_stride_a,
            split_k_stride_b,
            k_batch);
    }
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_b_k0_m_k1_grid_desc;
    ignore = b_b_k0_n_k1_grid_desc;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
    ignore = c_block_cluster_adaptor;
    ignore = split_k_stride_a;
    ignore = split_k_stride_b;
    ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || ...)
}

template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CMNGridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          index_t ABlockLdsM1PerBlock,
          index_t ABlockLdsM0PerBlock,
          index_t ABlockLdsM1Padding,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          index_t BBlockLdsN1PerBlock,
          index_t BBlockLdsN0PerBlock,
          index_t BBlockLdsN1Padding,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          bool ABlockLdsExtraM1Wrw = false,
          bool BBlockLdsExtraN1Wrw = false,
          index_t NumGemmKPrefetchStage = 1,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename ComputeTypeA = FloatA,
          typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    // denorm test fix, required to work around fp16 mfma issue
    // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
    // when mfma is fixed, remove this section and update
    // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
    // throughout this file
#if CK_GFX90A_DENORM_WORKAROUND
    using FloatAAdjusted =
        conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
    using FloatBAdjusted =
        conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
#else
    using FloatAAdjusted =
        conditional_t<is_same_v<ComputeTypeA, ck::tf32_t>, float, ComputeTypeA>;
    using FloatBAdjusted =
        conditional_t<is_same_v<ComputeTypeB, ck::tf32_t>, float, ComputeTypeB>;
#endif

    // M0/M1/M1Padding
    static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
    static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
    static constexpr auto M1Padding  = Number<ABlockLdsM1Padding>{};

    // N0/N1/N1Padding
    static constexpr auto N1PerBlock = Number<BBlockLdsN1PerBlock>{};
    static constexpr auto N0PerBlock = Number<BBlockLdsN0PerBlock>{};
    static constexpr auto N1Padding  = Number<BBlockLdsN1Padding>{};

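    // Layout note: with the M0/M1 (N0/N1) split below, one K0 slice of a tile
    // occupies M0PerBlock rows of (M1PerBlock * K1 + M1Padding) elements, so the
    // row pitch is deliberately not a multiple of the LDS bank count. For
    // instance (hypothetical values) M1PerBlock = 32, K1 = 8, M1Padding = 4
    // gives a pitch of 32 * 8 + 4 = 260 elements, staggering consecutive rows
    // across banks to reduce LDS bank conflicts.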
    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                if constexpr(ABlockLdsExtraM1Wrw)
                {
                    constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
                        make_tuple(
                            Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
                        make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
                                   Number<M1PerBlock>{} * K1 + M1Padding,
                                   K1,
                                   I1));

                    constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
                        a_block_desc_k0_m0_m1_k1,
                        make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v4_no_carry(
                                       make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                    return a_block_desc_k0_m_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_b_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                if constexpr(ABlockLdsExtraM1Wrw)
                {
                    constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
                        make_tuple(Number<1>{},
                                   Number<K0PerBlock>{},
                                   Number<M0PerBlock>{},
                                   Number<M1PerBlock>{},
                                   K1),
                        make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} *
                                       (Number<M1PerBlock>{} * K1 + M1Padding),
                                   Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
                                   Number<M1PerBlock>{} * K1 + M1Padding,
                                   K1,
                                   I1));

                    constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
                        a_block_desc_b_k0_m0_m1_k1,
                        make_tuple(make_pass_through_transform(Number<1>{}),
                                   make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v4_no_carry(
                                       make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

                    return a_block_desc_b_k0_m_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                                   Number<MPerBlock + 1>{} * K1,
                                   K1,
                                   I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    max_lds_align);
            }
        }();

        return a_block_desc_b_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                if constexpr(BBlockLdsExtraN1Wrw)
                {
                    constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
                        make_tuple(
                            Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
                        make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
                                   Number<N1PerBlock>{} * K1 + N1Padding,
                                   K1,
                                   I1));

                    constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
                        b_block_desc_k0_n0_n1_k1,
                        make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v4_no_carry(
                                       make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                    return b_block_desc_k0_n_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0_n_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                if constexpr(BBlockLdsExtraN1Wrw)
                {
                    constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
                        make_tuple(Number<1>{},
                                   Number<K0PerBlock>{},
                                   Number<N0PerBlock>{},
                                   Number<N1PerBlock>{},
                                   K1),
                        make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} *
                                       (Number<N1PerBlock>{} * K1 + N1Padding),
                                   Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
                                   Number<N1PerBlock>{} * K1 + N1Padding,
                                   K1,
                                   I1));

                    constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
                        b_block_desc_b_k0_n0_n1_k1,
                        make_tuple(make_pass_through_transform(Number<1>{}),
                                   make_pass_through_transform(Number<K0PerBlock>{}),
                                   make_merge_transform_v4_no_carry(
                                       make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
                                   make_pass_through_transform(K1)),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

                    return b_block_desc_b_k0_n_k1_tmp;
                }
                else
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                                   Number<NPerBlock + 1>{} * K1,
                                   K1,
                                   I1));
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    max_lds_align);
            }
        }();

        return b_block_desc_b_k0_n_k1;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size = math::integer_least_multiple(
            b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_block_size =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();

        return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
                          b_block_space_size * sizeof(FloatBAdjusted)),
                         c_block_size * sizeof(FloatC));
    }
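    // The max() above reflects LDS reuse: the A/B tiles are live only during the
    // main K loop, after which Run() reinterprets the same shared memory as the C
    // shuffle buffer (static_cast<FloatC*>(p_shared)), so the allocation only has
    // to cover the larger of the two footprints.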

    template <
        InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
    __device__ static constexpr bool IsValidCompilationParameter()
    {
        return ck::tensor_operation::device::IsValidGemmCompilationParameter<
            BlockSize,
            MPerBlock,
            NPerBlock,
            MPerXdl,
            NPerXdl,
            MRepeat,
            NRepeat,
            FloatC,
            CGlobalMemoryDataOperation_>();
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                  const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                  const CMNGridDesc& c_m_n_grid_desc,
                  const Block2CTileMap& block_2_ctile_map)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXdl * MRepeat) == 0) &&
                          (NPerBlock % (NRepeat * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M      = a_b_k0_m_k1_grid_desc.GetLength(I2);
        const auto N      = b_b_k0_n_k1_grid_desc.GetLength(I2);
        const auto K0     = a_b_k0_m_k1_grid_desc.GetLength(I1);
        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

        // check gridwise gemm pipeline
        const auto num_k_loop = K0 / K0PerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
             K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
             K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
             K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
             KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
        {
            return false;
        }

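        // Reject problems whose A/B/C element space, in bytes, would exceed 2 GB
        // (1 << 31), i.e. would no longer be addressable with 32-bit signed
        // offsets on the device.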
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_b_k0_m_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
             b_b_k0_n_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB &&
             c_m_n_grid_desc.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const index_t num_loop = K0 / K0PerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }

    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        return transform_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
    {
        return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
            c_m_n_grid_desc, M01, N01, KBatch);
    }

    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       Number<CShuffleMRepeatPerShuffle * MWave * MPerXdl>{},
                       I1,
                       Number<CShuffleNRepeatPerShuffle * NWave * NPerXdl>{}));
    }

    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));

    template <bool HasMainKBlockLoop, bool SplitKOffsetHack = false>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                               const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor,
                               const long_index_t split_k_stride_a,
                               const long_index_t split_k_stride_b,
                               index_t k_batch)
    {
        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        // Use compile-time branching based on template parameters
        const long_index_t split_k_offset_a = SplitKOffsetHack ? k_batch_id * split_k_stride_a : 0;
        const long_index_t split_k_offset_b = SplitKOffsetHack ? k_batch_id * split_k_stride_b : 0;

        // When hack is enabled, buffer size equals the stride (calculated from descriptor's
        // CalculateOffset method in the device layer). This properly accounts for the
        // descriptor's transform pipeline and non-compact strides.
        // When hack is disabled, use the full element space size.
        const long_index_t a_buffer_size =
            SplitKOffsetHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();

        const long_index_t b_buffer_size =
            SplitKOffsetHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();

        ignore = k_batch; // k_batch value itself not used in this function

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid + split_k_offset_a, a_buffer_size);
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + split_k_offset_b, b_buffer_size);
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        if(!c_block_cluster_adaptor.ValidCTileIndex(
               make_tuple(block_work_idx[I1], block_work_idx[I2]),
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                FloatA,
                                                FloatAAdjusted,
                                                decltype(a_b_k0_m_k1_grid_desc),
                                                decltype(a_b_k0_m_k1_block_desc),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<0, 2, 1, 3>,
                                                ABlockTransferSrcVectorDim,
                                                3,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                a_b_k0_m_k1_grid_desc,
                make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_b_k0_m_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatB,
                                                FloatBAdjusted,
                                                decltype(b_b_k0_n_k1_grid_desc),
                                                decltype(b_b_k0_n_k1_block_desc),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 2, 1, 3>,
                                                BBlockTransferSrcVectorDim,
                                                3,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b_b_k0_n_k1_grid_desc,
                make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_b_k0_n_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr bool is_single_rate_mfma =
            (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
              K1 <= 4) ||
             (is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
             ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
              K1 < 32))
                ? true
                : false;
        constexpr auto is_scale_mfma = false;
        constexpr index_t KPack      = math::max(K1,
                                            MfmaSelector<ComputeTypeA,
                                                         MPerXdl,
                                                         NPerXdl,
                                                         ComputeTypeB,
                                                         is_single_rate_mfma,
                                                         is_scale_mfma>::selected_mfma.k_per_blk);

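        // KPack is the K length consumed per mfma issue: K1 widened to the selected
        // instruction's k_per_blk when that is larger. is_single_rate_mfma above
        // falls back to single-rate instructions when K1 is too small for the
        // double-rate variants (fp16/bf16 with K1 <= 4, int8 with K1 <= 8, f8/bf8
        // with K1 < 32), keeping each issue's K read within the vector loaded from
        // LDS.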
        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAAdjusted,
                                                                FloatBAdjusted,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                MRepeat,
                                                                NRepeat,
                                                                KPack,
                                                                ComputeTypeA,
                                                                ComputeTypeB>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
            b_k0_n_k1_block_desc.GetElementSpaceSize());

        // gridwise GEMM pipeline
        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
                                                          a_b_k0_m_k1_block_desc,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_b_k0_n_k1_grid_desc,
                                                          b_b_k0_n_k1_block_desc,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          K0BlockMainLoop);

        // output: register to global memory
        {
            constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
            constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
            constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
            constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
            constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

            constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
                GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatC*>(p_shared),
                c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            static_assert(M1 == MWave, "");
            static_assert(N1 == NWave, "");
            static_assert(M2 * M3 * M4 == MPerXdl, "");
            static_assert(N2 == NPerXdl, "");

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0), // freeze mblock
                    make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
                                                      M1,
                                                      M2,
                                                      M3,
                                                      M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                    make_freeze_transform(I0), // freeze nblock
                    make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
                                                      N1,
                                                      N2))), // N1 = NWave, N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMRepeatPerShuffle,
                                                            CShuffleNRepeatPerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // LDS to global
            auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // typename ThreadGroup,
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatC,               // typename SrcData,
                FloatC,               // typename DstData,
                decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                3,                                          // index_t VectorDim,
                CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun
                {c_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
                 c_element_op};

            constexpr auto mxdlperwave_forward_step =
                make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
            constexpr auto nxdlperwave_forward_step =
                make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
            constexpr auto nxdlperwave_backward_step =
                make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);

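            // The N dimension is swept serpentine-fashion (forward on even M steps,
            // backward on odd ones), so the destination slice window only ever moves
            // by a single shuffle step and never needs to be reset between M steps.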
            static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                constexpr auto mxdlperwave = mxdlperwave_iter;

                static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                    constexpr bool nxdlperwave_forward_sweep =
                        (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                    constexpr index_t nxdlperwave_value =
                        nxdlperwave_forward_sweep
                            ? nxdlperwave_iter
                            : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                    constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                    // make sure it's safe to do ds_write
                    block_sync_lds();

                    // VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(
                        c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                        make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        c_block_buf);

                    // make sure it's safe to do ds_read
                    block_sync_lds();

                    // LDS to global
                    c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
                                                   c_block_buf,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   c_grid_buf);

                    // move on nxdlperwave dimension
                    if constexpr(nxdlperwave_forward_sweep &&
                                 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            nxdlperwave_forward_step);
                    }
                    else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            nxdlperwave_backward_step);
                    }
                });

                // move on mxdlperwave dimension
                if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
                }
            });
        }
    }

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                               const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        if(!c_block_cluster_adaptor.ValidCTileIndex(
               make_tuple(block_work_idx[I1], block_work_idx[I2]),
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                FloatA,
                                                FloatAAdjusted,
                                                decltype(a_b_k0_m_k1_grid_desc),
                                                decltype(a_b_k0_m_k1_block_desc),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<0, 2, 1, 3>,
                                                ABlockTransferSrcVectorDim,
                                                3,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                a_b_k0_m_k1_grid_desc,
                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_b_k0_m_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatB,
                                                FloatBAdjusted,
                                                decltype(b_b_k0_n_k1_grid_desc),
                                                decltype(b_b_k0_n_k1_block_desc),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 2, 1, 3>,
                                                BBlockTransferSrcVectorDim,
                                                3,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b_b_k0_n_k1_grid_desc,
                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_b_k0_n_k1_block_desc,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        constexpr bool is_single_rate_mfma =
            (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
              K1 <= 4) ||
             (is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
             ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
              K1 < 32))
                ? true
                : false;
        constexpr auto is_scale_mfma = false;
        constexpr index_t KPack      = math::max(K1,
                                            MfmaSelector<ComputeTypeA,
                                                         MPerXdl,
                                                         NPerXdl,
                                                         ComputeTypeB,
                                                         is_single_rate_mfma,
                                                         is_scale_mfma>::selected_mfma.k_per_blk);

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAAdjusted,
                                                                FloatBAdjusted,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                MRepeat,
                                                                NRepeat,
                                                                KPack,
                                                                ComputeTypeA,
                                                                ComputeTypeB>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
            b_k0_n_k1_block_desc.GetElementSpaceSize());

        // gridwise GEMM pipeline
        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
                                                          a_b_k0_m_k1_block_desc,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_b_k0_n_k1_grid_desc,
                                                          b_b_k0_n_k1_block_desc,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          K0BlockMainLoop);

        // output: register to global memory
        {
            constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
            constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
            constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
            constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
            constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

            constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
                GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatC*>(p_shared),
                c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            static_assert(M1 == MWave, "");
            static_assert(N1 == NWave, "");
            static_assert(M2 * M3 * M4 == MPerXdl, "");
            static_assert(N2 == NPerXdl, "");

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0), // freeze mblock
                    make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
                                                      M1,
                                                      M2,
                                                      M3,
                                                      M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                    make_freeze_transform(I0), // freeze nblock
                    make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
                                                      N1,
                                                      N2))), // N1 = NWave, N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

1338  auto c_thread_copy_vgpr_to_lds =
1340  FloatC,
1341  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
1342  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1344  Sequence<CShuffleMRepeatPerShuffle,
1345  CShuffleNRepeatPerShuffle,
1346  I1,
1347  I1,
1348  M2,
1349  I1,
1350  M4,
1351  I1>,
1353  7,
1354  1,
1355  InMemoryDataOperationEnum::Set,
1356  1,
1357  true>{
1358  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1359  make_multi_index(0,
1360  0,
1361  m_thread_data_on_block_idx[I1],
1362  n_thread_data_on_block_idx[I1],
1363  m_thread_data_on_block_idx[I2],
1364  m_thread_data_on_block_idx[I3],
1365  m_thread_data_on_block_idx[I4],
1366  n_thread_data_on_block_idx[I2]),
1368 
            // LDS to global
            auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // typename ThreadGroup,
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatC,               // typename SrcData,
                FloatC,               // typename DstData,
                decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                3,                                          // index_t VectorDim,
                CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun
                {c_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
                 c_element_op};

            constexpr auto mxdlperwave_forward_step =
                make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
            constexpr auto nxdlperwave_forward_step =
                make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
            constexpr auto nxdlperwave_backward_step =
                make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);

            static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                constexpr auto mxdlperwave = mxdlperwave_iter;

                static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
                    constexpr bool nxdlperwave_forward_sweep =
                        (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                    constexpr index_t nxdlperwave_value =
                        nxdlperwave_forward_sweep
                            ? nxdlperwave_iter
                            : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                    constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                    // make sure it's safe to do ds_write
                    block_sync_lds();

                    // VGPR to LDS
                    c_thread_copy_vgpr_to_lds.Run(
                        c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                        make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                        c_block_buf);

                    // make sure it's safe to do ds_read
                    block_sync_lds();

                    // LDS to global
                    c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
                                                   c_block_buf,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   c_grid_buf);

                    // move on nxdlperwave dimension
                    if constexpr(nxdlperwave_forward_sweep &&
                                 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            nxdlperwave_forward_step);
                    }
                    else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                    {
                        c_block_copy_lds_to_global.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            nxdlperwave_backward_step);
                    }
                });

                // move on mxdlperwave dimension
                if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                {
                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
                }
            });
        }
    }
};

} // namespace ck