/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File
blockwise_gemm_pipeline_wmmaops_v1.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 
12 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
13  index_t BlockSize,
14  typename ADataType,
15  typename BDataType,
16  typename ComputeTypeA,
17  typename ComputeTypeB,
18  typename AccDataType,
19  typename AWmmaTileDesc,
20  typename BWmmaTileDesc,
21  index_t ABlockTransferSrcScalarPerVector,
22  index_t BBlockTransferSrcScalarPerVector,
23  index_t MPerBlock,
24  index_t NPerBlock,
25  index_t KPerBlock,
26  index_t MPerWmma,
27  index_t NPerWmma,
28  index_t MRepeat,
29  index_t NRepeat,
30  index_t KPack,
31  index_t KInner,
32  bool TransposeC = false,
33  bool BSkipLDS = false>
35 {
36 };
37 
38 template <index_t BlockSize,
39  typename ADataType,
40  typename BDataType,
41  typename ComputeTypeA,
42  typename ComputeTypeB,
43  typename AccDataType,
44  typename AWmmaTileDesc,
45  typename BWmmaTileDesc,
46  index_t ABlockTransferSrcScalarPerVector,
47  index_t BBlockTransferSrcScalarPerVector,
48  index_t MPerBlock,
49  index_t NPerBlock,
50  index_t KPerBlock,
51  index_t MPerWmma,
52  index_t NPerWmma,
53  index_t MRepeat,
54  index_t NRepeat,
55  index_t KPack,
56  index_t KInner,
57  bool TransposeC>
59  BlockSize,
60  ADataType,
61  BDataType,
62  ComputeTypeA,
63  ComputeTypeB,
64  AccDataType,
65  AWmmaTileDesc,
66  BWmmaTileDesc,
67  ABlockTransferSrcScalarPerVector,
68  BBlockTransferSrcScalarPerVector,
69  MPerBlock,
70  NPerBlock,
71  KPerBlock,
72  MPerWmma,
73  NPerWmma,
74  MRepeat,
75  NRepeat,
76  KPack,
77  KInner,
78  TransposeC,
79  false>
81  ADataType,
82  BDataType,
83  ComputeTypeA,
84  ComputeTypeB,
85  AccDataType,
86  AWmmaTileDesc,
87  BWmmaTileDesc,
88  ABlockTransferSrcScalarPerVector,
89  BBlockTransferSrcScalarPerVector,
90  MPerBlock,
91  NPerBlock,
92  KPerBlock,
93  MPerWmma,
94  NPerWmma,
95  MRepeat,
96  NRepeat,
97  KPack,
98  KInner,
99  TransposeC>
100 {
101  // GlobalPrefetchStages: 1
102  // LocalPreFillStages: 1
103  // LocalPreFetchStages: 0
104  // LocalSharedMemoryBuffer: 1
106  ADataType,
107  BDataType,
108  ComputeTypeA,
109  ComputeTypeB,
110  AccDataType,
111  AWmmaTileDesc,
112  BWmmaTileDesc,
113  ABlockTransferSrcScalarPerVector,
114  BBlockTransferSrcScalarPerVector,
115  MPerBlock,
116  NPerBlock,
117  KPerBlock,
118  MPerWmma,
119  NPerWmma,
120  MRepeat,
121  NRepeat,
122  KPack,
123  KInner,
124  TransposeC>;
125  using Base::I0;
126  using Base::I1;
127  using typename Base::HotLoopInstList;
128 
129  using Base::A_K1;
130  using Base::A_KRow;
131  using Base::B_K1;
132  using Base::B_KRow;
133  using Base::KRepeat;
134  using Base::WmmaK;
135 
136  using Base::wmma_gemm;
137 
138  using Base::CalculateCThreadOriginDataIndex;
139  using Base::
140  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
141  using Base::GetCThreadBuffer;
142  using Base::
143  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
144 
145  using Base::a_block_desc_k0_m0_m1_m2_k1;
146  using Base::b_block_desc_k0_n0_n1_n2_k1;
147 
148  using typename Base::Empty;
149 
150  static constexpr index_t PrefetchStages = 1;
151  static constexpr index_t PrefillStages = 1;
152  static constexpr index_t GlobalBufferNum = 1;
153 
154  static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
155  {
156  return num_loop > PrefetchStages;
157  }
158 
160  {
161  ignore = num_loop;
162  return TailNumber::Full;
163  }
164 
165  template <bool HasMainLoop,
166  TailNumber TailNum,
167  typename AGridDesc,
168  typename ABlockDesc,
169  typename ABlockTransfer,
170  typename AGridBuffer,
171  typename ABlockBuffer,
172  typename ABlockTransferStep,
173  typename BGridDesc,
174  typename BBlockDesc,
175  typename BBlockTransfer,
176  typename BGridBuffer,
177  typename BBlockBuffer,
178  typename BBlockTransferStep,
179  typename CThreadBuffer,
180  typename AScaleStruct,
181  typename BScaleStruct,
182  typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
183  __device__ void Run(const AGridDesc& a_grid_desc,
184  const ABlockDesc& a_block_desc,
185  ABlockTransfer& a_blockwise_copy,
186  const AGridBuffer& a_grid_buf,
187  ABlockBuffer& a_block_buf,
188  const ABlockTransferStep& a_block_copy_step,
189  const BGridDesc& b_grid_desc,
190  const BBlockDesc& b_block_desc,
191  BBlockTransfer& b_blockwise_copy,
192  const BGridBuffer& b_grid_buf,
193  BBlockBuffer& b_block_buf,
194  const BBlockTransferStep& b_block_copy_step,
195  CThreadBuffer& c_thread_buf,
196  AScaleStruct&,
197  BScaleStruct& b_scale_struct,
198  index_t num_loop,
199  index_t num_loop_per_scale) const
200  {
201  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
202 
203  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
204  a_thread_desc_.GetElementSpaceSize());
205  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
206  b_thread_desc_.GetElementSpaceSize());
207 
208  // Global prefetch 1
209  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
210  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
211 
212  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
213  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
214 
215  // Scales global load
216  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
217 
218  // Local prefill 1
219  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
220  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
221 
222  // Initialize C
223  c_thread_buf.Clear();
224 
225  auto blockwise_gemm_func = [&]() {
226  // Local load
227  static_for<0, KRepeat, 1>{}([&](auto k0) {
228  static_for<0, MRepeat, 1>{}([&](auto m0) {
229  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
230  make_tuple(I0, m0, k0, I0, I0, I0, I0),
231  a_block_buf,
232  a_thread_desc_,
233  make_tuple(I0, I0, I0, I0, I0, I0, I0),
234  a_thread_buf);
235  if constexpr(m0 == I0)
236  {
237  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
238  {
239  static_for<0, NRepeat, 1>{}([&](auto n0) {
240  b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
241  make_tuple(I0, n0, k0, I0, I0, I0, I0),
242  b_block_buf,
243  b_thread_desc_,
244  make_tuple(I0, n0, I0, I0, I0, I0, I0),
245  b_thread_buf);
246  });
247  }
248  else
249  {
250  static_for<0, NRepeat, 1>{}([&](auto n0) {
251  b_thread_copy_.Run(
252  b_block_desc_k0_n0_n1_n2_k1,
253  make_tuple(I0, n0, k0, I0, I0, I0, I0),
254  b_block_buf,
255  b_scale_struct.scale_thread_bufs(
256  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
257  k0 / BScaleStruct::num_scale_krepeat>{}],
258  b_thread_desc_,
259  make_tuple(I0, n0, I0, I0, I0, I0, I0),
260  b_thread_buf);
261  });
262  }
263  }
264 
265  static_for<0, KInner, 1>{}([&](auto k_inner) {
266  static_for<0, NRepeat, 1>{}([&](auto n0) {
267  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
268  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
269 
270  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
271  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
272  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
273  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
275  I0,
276  I0,
277  I0,
278  I0,
279  I0,
280  Number<kk % A_K1>{}))>{}];
281  });
282  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
283  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
284  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
285  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
287  n0,
288  I0,
289  I0,
290  I0,
291  I0,
292  Number<kk % B_K1>{}))>{}];
293  });
294 
295  using wmma_input_type_a =
296  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
297  using wmma_input_type_b =
298  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
299 
300  constexpr index_t c_offset =
301  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
302 
303  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
304  b_thread_vec.template AsType<wmma_input_type_b>(),
305  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
306  });
307  });
308  });
309  });
310  };
311 
312  // main body
313  if constexpr(HasMainLoop)
314  {
315  index_t i = 0;
316  do
317  {
318  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
319  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
320 
321  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
322  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
323 
324  block_sync_lds();
325  blockwise_gemm_func();
326 
327  block_sync_lds();
328  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
329  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
330  {
331  block_sync_lds();
332  }
333  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
334  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
335 
336  constexpr index_t num_ds_write_inst =
337  HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
338 
339  constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
340  HotLoopInstList::B_Buffer_Load_Inst_Num;
342  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
343  });
344  static_for<0, KRepeat, 1>{}([&](auto) {
345  static_for<0, MRepeat, 1>{}([&](auto m0) {
346  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
347  if constexpr(m0 == I0)
348  {
349  static_for<0, NRepeat, 1>{}([&](auto) {
350  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
351  });
352  }
353  static_for<0, KInner, 1>{}([&](auto) {
354  static_for<0, NRepeat, 1>{}([&](auto) {
355  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
356  });
357  });
358  });
359  });
361  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
362  });
363 
364  i += 1;
365  } while(i < (num_loop - 1));
366  }
367 
368  // tail
369  if constexpr(TailNum == TailNumber::Full)
370  {
371  block_sync_lds();
372  blockwise_gemm_func();
373  }
374  }
375 
376  template <bool HasMainLoop,
377  TailNumber TailNum,
378  typename AGridDesc,
379  typename ABlockDesc,
380  typename ABlockTransfer,
381  typename AGridBuffer,
382  typename ABlockBuffer,
383  typename ABlockTransferStep,
384  typename BGridDesc,
385  typename BBlockDesc,
386  typename BBlockTransfer,
387  typename BGridBuffer,
388  typename BBlockBuffer,
389  typename BBlockTransferStep,
390  typename CThreadBuffer,
391  typename AScaleStruct,
392  typename BScaleStruct,
394  !ck::is_same_v<BScaleStruct, Empty>,
395  bool>::type = false>
396  __device__ void Run(const AGridDesc& a_grid_desc,
397  const ABlockDesc& a_block_desc,
398  ABlockTransfer& a_blockwise_copy,
399  const AGridBuffer& a_grid_buf,
400  ABlockBuffer& a_block_buf,
401  const ABlockTransferStep& a_block_copy_step,
402  const BGridDesc& b_grid_desc,
403  const BBlockDesc& b_block_desc,
404  BBlockTransfer& b_blockwise_copy,
405  const BGridBuffer& b_grid_buf,
406  BBlockBuffer& b_block_buf,
407  const BBlockTransferStep& b_block_copy_step,
408  CThreadBuffer& c_thread_buf,
409  AScaleStruct& a_scale_struct,
410  BScaleStruct& b_scale_struct,
411  index_t num_loop,
412  index_t num_loop_per_scale) const
413  {
414  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
415  static constexpr auto NumScaleKBlock =
416  Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
417 
418  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
419  Base::a_thread_desc_.GetElementSpaceSize());
420  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
421  Base::b_thread_desc_.GetElementSpaceSize());
422 
423  using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
424  auto c_scale_struct = CScaleStruct{};
425 
426  // Global prefetch 1
427  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
428  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
429 
430  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
431  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
432 
433  // Scales global load
434  a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
435  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
436 
437  c_scale_struct.Load(a_scale_struct, b_scale_struct);
438 
439  // Local prefill 1
440  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
441  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
442 
443  // Initialize C
444  c_thread_buf.Clear();
445 
446  auto blockwise_gemm_func = [&]() {
447  // Local load
448  static_for<0, KRepeat, 1>{}([&](auto k0) {
449  static_for<0, MRepeat, 1>{}([&](auto m0) {
450  Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
451  make_tuple(I0, m0, k0, I0, I0, I0, I0),
452  a_block_buf,
453  Base::a_thread_desc_,
454  make_tuple(I0, m0, k0, I0, I0, I0, I0),
455  a_thread_buf);
456  });
457  static_for<0, NRepeat, 1>{}([&](auto n0) {
458  Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
459  make_tuple(I0, n0, k0, I0, I0, I0, I0),
460  b_block_buf,
461  Base::b_thread_desc_,
462  make_tuple(I0, n0, k0, I0, I0, I0, I0),
463  b_thread_buf);
464  });
465  });
466 
467  static_for<0, MRepeat, 1>{}([&](auto m0) {
468  static_for<0, NRepeat, 1>{}([&](auto n0) {
469  static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
470  c_scale_struct.Clear();
471  static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
472  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
473  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
474 
475  static_for<0, KInner, 1>{}([&](auto k_inner) {
476  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
477  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
478  constexpr index_t k_index =
479  kscale0 * (KRepeat / NumScaleKBlock) + k0;
480  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
481  a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
483  m0,
484  k_index,
485  I0,
486  I0,
487  I0,
488  Number<kk % A_K1>{}))>{}];
489  });
490  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
491  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
492  constexpr index_t k_index =
493  kscale0 * (KRepeat / NumScaleKBlock) + k0;
494  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
495  b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
497  n0,
498  k_index,
499  I0,
500  I0,
501  I0,
502  Number<kk % B_K1>{}))>{}];
503  });
504 
505  using wmma_input_type_a =
506  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
507  using wmma_input_type_b =
508  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
509 
510  wmma_gemm.Run(
511  a_thread_vec.template AsType<wmma_input_type_a>(),
512  b_thread_vec.template AsType<wmma_input_type_b>(),
513  c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
514  Number<0>{}));
515  });
516  });
517  c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
518  });
519  });
520  });
521  };
522 
523  // main body
524  if constexpr(HasMainLoop)
525  {
526  index_t i = 0;
527  do
528  {
529  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
530  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
531 
532  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
533  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
534 
535  a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
536  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
537 
538  block_sync_lds();
539  blockwise_gemm_func();
540 
541  block_sync_lds();
542  c_scale_struct.Load(a_scale_struct, b_scale_struct);
543 
544  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
545  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
546 
547  i += 1;
548  } while(i < (num_loop - 1));
549  }
550 
551  // tail
552  if constexpr(TailNum == TailNumber::Full)
553  {
554  block_sync_lds();
555  blockwise_gemm_func();
556  }
557  }
558 
559  protected:
560  // A[MRepeat, I1, I1, KPack]
561  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
562  make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
563 
564  // B[NRepeat, N1, N2, KPack]
565  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
566  Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
567 
568  using AThreadCopy =
570  ComputeTypeA,
571  decltype(a_block_desc_k0_m0_m1_m2_k1),
572  decltype(a_thread_desc_),
573  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
575  6,
576  A_K1,
577  A_K1>;
578 
579  using BThreadCopy =
581  ComputeTypeB,
582  decltype(b_block_desc_k0_n0_n1_n2_k1),
583  decltype(b_thread_desc_),
584  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
586  6,
587  B_K1,
588  B_K1>;
589 
590  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
591  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
592  using Base::c_thread_desc_;
593 };
594 
595 template <index_t BlockSize,
596  typename ADataType,
597  typename BDataType,
598  typename ComputeTypeA,
599  typename ComputeTypeB,
600  typename AccDataType,
601  typename AWmmaTileDesc,
602  typename BWmmaTileDesc,
603  index_t ABlockTransferSrcScalarPerVector,
604  index_t BBlockTransferSrcScalarPerVector,
605  index_t MPerBlock,
606  index_t NPerBlock,
607  index_t KPerBlock,
608  index_t MPerWmma,
609  index_t NPerWmma,
610  index_t MRepeat,
611  index_t NRepeat,
612  index_t KPack,
613  index_t KInner,
614  bool TransposeC>
616  BlockSize,
617  ADataType,
618  BDataType,
619  ComputeTypeA,
620  ComputeTypeB,
621  AccDataType,
622  AWmmaTileDesc,
623  BWmmaTileDesc,
624  ABlockTransferSrcScalarPerVector,
625  BBlockTransferSrcScalarPerVector,
626  MPerBlock,
627  NPerBlock,
628  KPerBlock,
629  MPerWmma,
630  NPerWmma,
631  MRepeat,
632  NRepeat,
633  KPack,
634  KInner,
635  TransposeC,
636  false>
638  ADataType,
639  BDataType,
640  ComputeTypeA,
641  ComputeTypeB,
642  AccDataType,
643  AWmmaTileDesc,
644  BWmmaTileDesc,
645  ABlockTransferSrcScalarPerVector,
646  BBlockTransferSrcScalarPerVector,
647  MPerBlock,
648  NPerBlock,
649  KPerBlock,
650  MPerWmma,
651  NPerWmma,
652  MRepeat,
653  NRepeat,
654  KPack,
655  KInner,
656  TransposeC>
657 {
658  // GlobalPrefetchStages: 1
659  // LocalPreFillStages: 1
660  // LocalPreFetchStages: 0
661  // LocalSharedMemoryBuffer: 1
663  ADataType,
664  BDataType,
665  ComputeTypeA,
666  ComputeTypeB,
667  AccDataType,
668  AWmmaTileDesc,
669  BWmmaTileDesc,
670  ABlockTransferSrcScalarPerVector,
671  BBlockTransferSrcScalarPerVector,
672  MPerBlock,
673  NPerBlock,
674  KPerBlock,
675  MPerWmma,
676  NPerWmma,
677  MRepeat,
678  NRepeat,
679  KPack,
680  KInner,
681  TransposeC>;
682  using Base::I0;
683  using Base::I1;
684 
685  using Base::A_K1;
686  using Base::A_KRow;
687  using Base::B_K1;
688  using Base::B_KRow;
689  using Base::KRepeat;
690  using Base::WmmaK;
691 
692  using Base::wmma_gemm;
693 
694  using Base::CalculateCThreadOriginDataIndex;
695  using Base::
696  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
697  using Base::GetCThreadBuffer;
698  using Base::
699  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
700 
701  using Base::a_block_desc_k0_m0_m1_m2_k1;
702  using Base::b_block_desc_k0_n0_n1_n2_k1;
703 
704  using typename Base::Empty;
705 
707  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
708 
709  static constexpr index_t PrefetchStages = 1;
710  static constexpr index_t PrefillStages = 1;
711  static constexpr index_t GlobalBufferNum = 1;
712 
713  __host__ __device__ static bool BlockHasHotloop(index_t num_loop)
714  {
715  return num_loop > PrefetchStages;
716  }
717 
719  {
720  ignore = num_loop;
721  return TailNumber::Full;
722  }
723 
724  template <typename AScaleStruct, typename BScaleStruct>
725  struct KLoopParams
726  {
727  static constexpr auto KRepeatNoScale = 1;
728  static constexpr auto NumScaleKBlock =
729  Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
730  static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
731  };
732 
733  template <>
734  struct KLoopParams<Empty, Empty>
735  {
736  static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
737  static constexpr index_t NumScaleKBlock = 1;
738  static constexpr index_t KRepeatPerNumScaleKBlock = 1;
739  };
740 
741  template <bool HasMainLoop,
742  TailNumber TailNum,
743  typename AGridDesc,
744  typename ABlockDesc,
745  typename ABlockTransfer,
746  typename AGridBuffer,
747  typename ABlockBuffer,
748  typename ABlockTransferStep,
749  typename BGridDesc,
750  typename BBlockDesc,
751  typename BBlockTransfer,
752  typename BGridBuffer,
753  typename BBlockBuffer,
754  typename BBlockTransferStep,
755  typename CThreadBuffer,
756  typename AScaleStruct,
757  typename BScaleStruct,
758  typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
759  __device__ void Run(const AGridDesc& a_grid_desc,
760  const ABlockDesc& a_block_desc,
761  ABlockTransfer& a_blockwise_copy,
762  const AGridBuffer& a_grid_buf,
763  ABlockBuffer& a_block_buf,
764  const ABlockTransferStep& a_block_copy_step,
765  const BGridDesc& b_grid_desc,
766  const BBlockDesc& b_block_desc,
767  BBlockTransfer& b_blockwise_copy,
768  const BGridBuffer& b_grid_buf,
769  BBlockBuffer& b_block_buf,
770  const BBlockTransferStep& b_block_copy_step,
771  CThreadBuffer& c_thread_buf,
772  AScaleStruct&,
773  BScaleStruct& b_scale_struct,
774  index_t num_loop,
775  index_t num_loop_per_scale) const
776  {
777  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
778 
779  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
780  a_thread_desc_.GetElementSpaceSize());
781  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
782  b_thread_desc_.GetElementSpaceSize());
783 
784  // Global prefetch 1
785  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
786  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
787 
788  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
789  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
790 
791  // Scales global load
792  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
793 
794  // Local prefill 1
795  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
796  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
797 
798  // Initialize C
799  c_thread_buf.Clear();
800 
801  auto blockwise_gemm_func = [&]() {
802  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
803  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
804  static_for<0, MRepeat, 1>{}([&](auto m0) {
805  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
806  make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
807  a_block_buf,
808  a_thread_desc_,
809  make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
810  a_thread_buf);
811  });
812  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
813  {
814  static_for<0, NRepeat, 1>{}([&](auto n0) {
815  b_thread_copy_.Run(
816  b_block_desc_k0_n0_n1_n2_k1,
817  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
818  b_block_buf,
819  b_thread_desc_,
820  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
821  b_thread_buf);
822  });
823  }
824  else
825  {
826  static_for<0, NRepeat, 1>{}([&](auto n0) {
827  b_thread_copy_.Run(
828  b_block_desc_k0_n0_n1_n2_k1,
829  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
830  b_block_buf,
831  b_scale_struct.scale_thread_bufs(I0)[Number<
832  n0 * BScaleStruct::num_scale_k_block +
833  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
834  b_thread_desc_,
835  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
836  b_thread_buf);
837  });
838  }
839  });
840 
841  __builtin_amdgcn_sched_barrier(0);
842  // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
843  // but except the first, as we can shorten non-MAC cluster a bit and there's no
844  // observable negative impact. The desired effect is waves in a workgroup
845  // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
846  // resource from other workgroups and reducing the chance of latency hiding by
847  // waiting for the rest of the workgroup at the eventual sync point.
848  if constexpr(k0_offset != 0 || KRepeat == 1)
849  {
850  __builtin_amdgcn_s_barrier();
851  __builtin_amdgcn_sched_barrier(0);
852  }
853  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
854  static_for<0, KInner, 1>{}([&](auto k_inner) {
855  static_for<0, MRepeat, 1>{}([&](auto m0) {
856  static_for<0, NRepeat, 1>{}([&](auto n0) {
857  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
858  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
859 
860  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
861  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
862  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
863  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
865  m0,
866  k0_inner,
867  I0,
868  I0,
869  I0,
870  Number<kk % A_K1>{}))>{}];
871  });
872  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
873  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
874  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
875  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
877  n0,
878  k0_inner,
879  I0,
880  I0,
881  I0,
882  Number<kk % B_K1>{}))>{}];
883  });
884 
885  using wmma_input_type_a =
886  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
887  using wmma_input_type_b =
888  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
889 
890  constexpr index_t c_offset =
891  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
892 
893  // The block_sync_lds() here performs double duty:
894  // A) safeguard against data hazard.
895  // B) reduce VMEM FIFO congestion by applying small delays to
896  // different wavefronts.
897  // It is performed near the end of MAC cluster to minimize lgkmcnt
898  // penalty
899  if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
900  m0 == MRepeat - 1 && n0 == NRepeat - 1)
901  {
902  __builtin_amdgcn_sched_barrier(0);
903  block_sync_lds();
904  __builtin_amdgcn_sched_barrier(0);
905  }
906  wmma_gemm.Run(
907  a_thread_vec.template AsType<wmma_input_type_a>(),
908  b_thread_vec.template AsType<wmma_input_type_b>(),
909  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
910  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
911  {
912  __builtin_amdgcn_sched_barrier(0);
913  __builtin_amdgcn_s_setprio(1);
914  __builtin_amdgcn_sched_barrier(0);
915  }
916  });
917  });
918  });
919  });
920 
921  __builtin_amdgcn_sched_barrier(0);
922  __builtin_amdgcn_s_setprio(0);
923  __builtin_amdgcn_sched_barrier(0);
924  });
925  };
926 
927  // main body
928  if constexpr(HasMainLoop)
929  {
930  index_t i = 0;
931  do
932  {
933  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
934  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
935 
936  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
937  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
938 
939  block_sync_lds();
940  blockwise_gemm_func();
941 
942  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
943  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
944  {
945  block_sync_lds();
946  }
947  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
948  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
949 
950  i += 1;
951  } while(i < (num_loop - 1));
952  }
953 
954  // tail
955  if constexpr(TailNum == TailNumber::Full)
956  {
957  block_sync_lds();
958  blockwise_gemm_func();
959  }
960  }
961 
962  protected:
963  static constexpr auto a_thread_desc_ =
965  Number<MRepeat>{},
966  Number<KRepeatPerCluster>{},
967  I1,
968  I1,
969  I1,
970  Number<A_K1>{}),
971  make_tuple(Number<A_K1>{},
972  Number<KPack / A_KRow>{},
973  Number<KPack / A_KRow * MRepeat>{},
974  I0,
975  I0,
976  I0,
977  I1));
978 
979  static constexpr auto b_thread_desc_ =
981  Number<NRepeat>{},
982  Number<KRepeatPerCluster>{},
983  I1,
984  I1,
985  I1,
986  Number<B_K1>{}),
987  make_tuple(Number<B_K1>{},
988  Number<KPack / B_KRow>{},
989  Number<KPack / B_KRow * NRepeat>{},
990  I0,
991  I0,
992  I0,
993  I1));
994 
995  using AThreadCopy =
997  ComputeTypeA,
998  decltype(a_block_desc_k0_m0_m1_m2_k1),
999  decltype(a_thread_desc_),
1000  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
1002  6,
1003  A_K1,
1004  A_K1>;
1005 
1006  using BThreadCopy =
1008  ComputeTypeB,
1009  decltype(b_block_desc_k0_n0_n1_n2_k1),
1010  decltype(b_thread_desc_),
1011  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
1013  6,
1014  B_K1,
1015  B_K1>;
1016 
1017  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
1018  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
1019  using Base::c_thread_desc_;
1020 };
1021 
1022 template <index_t BlockSize,
1023  typename ADataType,
1024  typename BDataType,
1025  typename ComputeTypeA,
1026  typename ComputeTypeB,
1027  typename AccDataType,
1028  typename AWmmaTileDesc,
1029  typename BWmmaTileDesc,
1030  index_t ABlockTransferSrcScalarPerVector,
1031  index_t BBlockTransferSrcScalarPerVector,
1032  index_t MPerBlock,
1033  index_t NPerBlock,
1034  index_t KPerBlock,
1035  index_t MPerWmma,
1036  index_t NPerWmma,
1037  index_t MRepeat,
1038  index_t NRepeat,
1039  index_t KPack,
1040  index_t KInner,
1041  bool TransposeC>
1043  BlockSize,
1044  ADataType,
1045  BDataType,
1046  ComputeTypeA,
1047  ComputeTypeB,
1048  AccDataType,
1049  AWmmaTileDesc,
1050  BWmmaTileDesc,
1051  ABlockTransferSrcScalarPerVector,
1052  BBlockTransferSrcScalarPerVector,
1053  MPerBlock,
1054  NPerBlock,
1055  KPerBlock,
1056  MPerWmma,
1057  NPerWmma,
1058  MRepeat,
1059  NRepeat,
1060  KPack,
1061  KInner,
1062  TransposeC,
1063  true>
1065  ADataType,
1066  BDataType,
1067  ComputeTypeA,
1068  ComputeTypeB,
1069  AccDataType,
1070  AWmmaTileDesc,
1071  BWmmaTileDesc,
1072  ABlockTransferSrcScalarPerVector,
1073  BBlockTransferSrcScalarPerVector,
1074  MPerBlock,
1075  NPerBlock,
1076  KPerBlock,
1077  MPerWmma,
1078  NPerWmma,
1079  MRepeat,
1080  NRepeat,
1081  KPack,
1082  KInner,
1083  TransposeC>
1084 {
1085  // GlobalPrefetchStages: 2
1086  // LocalPreFillStages: 1
1087  // LocalPreFetchStages: 1
1088  // LocalSharedMemoryBuffer: 1
1090  ADataType,
1091  BDataType,
1092  ComputeTypeA,
1093  ComputeTypeB,
1094  AccDataType,
1095  AWmmaTileDesc,
1096  BWmmaTileDesc,
1097  ABlockTransferSrcScalarPerVector,
1098  BBlockTransferSrcScalarPerVector,
1099  MPerBlock,
1100  NPerBlock,
1101  KPerBlock,
1102  MPerWmma,
1103  NPerWmma,
1104  MRepeat,
1105  NRepeat,
1106  KPack,
1107  KInner,
1108  TransposeC>;
1109  using Base::I0;
1110  using Base::I1;
1111  using Base::MWaves;
1112  using Base::WaveSize;
1113  using typename Base::HotLoopInstList;
1114 
1115  using Base::A_K1;
1116  using Base::A_KRow;
1117  using Base::B_K1;
1118  using Base::B_KRow;
1119  using Base::KRepeat;
1120  using Base::WmmaK;
1121 
1122  using Base::wmma_gemm;
1123 
1124  using Base::CalculateCThreadOriginDataIndex;
1125  using Base::
1126  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
1127  using Base::GetCThreadBuffer;
1128  using Base::
1129  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
1130 
1131  using Base::a_block_desc_k0_m0_m1_m2_k1;
1132  using Base::b_block_desc_k0_n0_n1_n2_k1;
1133 
1134  using typename Base::Empty;
1135 
// Pipeline shape for this BSkipLDS specialization: two global prefetch
// stages, one LDS prefill stage, and two register-side buffers for B
// (see StaticallyIndexedArray<..., Number<2>{}> b_thread_bufs in Run()).
 1136  static constexpr index_t PrefetchStages = 2;
 1137  static constexpr index_t PrefillStages = 1;
 1138  static constexpr index_t GlobalBufferNum = 2;
 1139 
// A main (hot) loop is needed only when the K-loop iteration count exceeds
// what the prefetch stages alone can cover (num_loop > 2).
 1140  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
 1141 
// Selects the tail variant by loop-count parity: Even when num_loop is
// divisible by 2, Odd otherwise. Run() dispatches on this via
// `if constexpr(TailNum == TailNumber::Even / ::Odd)`.
// NOTE(review): the declaration line (original line 1142) was lost in this
// rendered listing; presumably
// `__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)`
// -- confirm against the repository source.
 1143  {
 1144  return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
 1145  }
1146 
// Emits compiler scheduling hints (__builtin_amdgcn_sched_group_barrier)
// that interleave WMMA/MFMA-group instructions (mask 0x008) with VMEM reads
// (0x020), DS writes (0x200) and DS reads (0x100) across the hot loop, so
// memory latency is hidden behind matrix math. Purely a codegen hint; no
// runtime effect on results.
// NOTE(review): the static_for opener lines for the first two groups
// (original lines 1154 and 1168) are missing from this rendered listing;
// presumably they iterate over num_buffer_load_inst_b and
// num_buffer_load_inst_a respectively -- confirm against the repository.
 1147  __device__ static constexpr auto HotLoopScheduler()
 1148  {
 1149  constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
 1150  constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
 1151  constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
 1152  constexpr auto wmma_interleave = 2;
 1153  // B global
 1155  ignore = i;
 1156  if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
 1157  {
 1158  __builtin_amdgcn_sched_group_barrier(0x008, 2 * wmma_interleave, 0);
 1159  }
 1160  else
 1161  {
 1162  __builtin_amdgcn_sched_group_barrier(0x008, wmma_interleave, 0);
 1163  }
 1164  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
 1165  });
 1166 
 1167  // A global
 1169  ignore = i;
 1170  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
 1171  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
 1172  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
 1173  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
 1174  });
 1175 
 1176  // A local
 1177  static_for<0, num_ds_read_inst_a, 1>{}([&](auto i) {
 1178  ignore = i;
 1179  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
 1180  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
 1181  });
 1182  }
1183 
// Pipelined GEMM Run(), non-scaled path (SFINAE-enabled only when
// AScaleStruct is Empty; the scale arguments are accepted but unnamed).
// Structure visible in the code below:
//  * A tiles stage through LDS (RunRead -> RunWrite -> a_thread_copy_ into
//    VGPRs), double-buffered via the wmma_reg_buf / local_read_buf indices;
//  * B tiles bypass LDS: b_blockwise_copy.Run() writes directly into the
//    two register buffers b_thread_bufs(I0/I1);
//  * the main loop processes two K-block iterations per do-while trip
//    (LoopFunc(I0, I1); LoopFunc(I1, I0); i += 2) and applies
//    HotLoopScheduler() hints each trip;
//  * the tail runs one extra full iteration plus a final compute (Even) or
//    a single final compute on buffer I0 (Odd).
// NOTE(review): this is a Doxygen-rendered listing; the make_tuple(...)
// opener lines inside the CalculateOffset(...) index expressions (original
// lines 1300, 1313, 1386, 1398, 1449, 1461, 1500, 1512) are missing here --
// consult the repository source for the exact index tuples.
 1184  template <bool HasMainLoop,
 1185  TailNumber TailNum,
 1186  typename AGridDesc,
 1187  typename ABlockDesc,
 1188  typename ABlockTransfer,
 1189  typename AGridBuffer,
 1190  typename ABlockBuffer,
 1191  typename ABlockTransferStep,
 1192  typename BGridDesc,
 1193  typename BBlockDesc,
 1194  typename BBlockTransfer,
 1195  typename BGridBuffer,
 1196  typename BBlockBuffer,
 1197  typename BBlockTransferStep,
 1198  typename CThreadBuffer,
 1199  typename AScaleStruct,
 1200  typename BScaleStruct,
 1201  typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
 1202  __device__ void Run(const AGridDesc& a_grid_desc,
 1203  const ABlockDesc& a_block_desc,
 1204  ABlockTransfer& a_blockwise_copy,
 1205  const AGridBuffer& a_grid_buf,
 1206  ABlockBuffer& a_block_buf,
 1207  const ABlockTransferStep& a_block_copy_step,
 1208  const BGridDesc& b_grid_desc,
 1209  const BBlockDesc&,
 1210  BBlockTransfer& b_blockwise_copy,
 1211  const BGridBuffer& b_grid_buf,
 1212  BBlockBuffer&,
 1213  const BBlockTransferStep& b_block_copy_step,
 1214  CThreadBuffer& c_thread_buf,
 1215  AScaleStruct&,
 1216  BScaleStruct&,
 1217  index_t num_loop,
 1218  index_t) const
 1219  {
 1220  __builtin_amdgcn_sched_barrier(0);
 1221  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
 1222 
// Per-thread VGPR staging buffers for A and B fragments.
 1223  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
 1224  a_thread_desc_.GetElementSpaceSize());
 1225  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
 1226  b_thread_desc_.GetElementSpaceSize());
 1227 
// Two register-resident B buffers -> B double buffering without LDS.
 1228  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
 1229  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
 1230 
 1231  // Global prefetch A1 B1
 1232  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 1233  b_blockwise_copy.Run(b_grid_desc,
 1234  b_grid_buf,
 1235  b_block_desc_k0_n0_n1_n2_k1,
 1236  b_block_origin_idx,
 1237  b_thread_bufs(I0));
 1238 
 1239  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 1240  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 1241  __builtin_amdgcn_sched_barrier(0);
 1242 
 1243  // Local prefill A1
 1244  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 1245 
 1246  // Global prefetch A2
 1247  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 1248  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 1249 
 1250  // Local prefetch A1
 1251  block_sync_lds();
 1252  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1253  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1254  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
 1255  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1256  a_block_buf,
 1257  a_thread_desc_,
 1258  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1259  a_thread_buf);
 1260  });
 1261  });
 1262 
 1263  // Initialize C
 1264  c_thread_buf.Clear();
 1265 
 1266  __builtin_amdgcn_sched_barrier(0);
 1267 
 1268  // main body
 1269  if constexpr(HasMainLoop)
 1270  {
 1271  index_t i = 0;
// Each do-while trip runs two pipeline iterations with the roles of the
// two B register buffers (and A LDS buffers) swapped.
 1272  do
 1273  {
 1274  auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
 1275  b_blockwise_copy.Run(b_grid_desc,
 1276  b_grid_buf,
 1277  b_block_desc_k0_n0_n1_n2_k1,
 1278  b_block_origin_idx,
 1279  b_thread_bufs(local_read_buf));
 1280 
 1281  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 1282 
 1283  block_sync_lds();
 1284 
 1285  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
 1286 
 1287  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
 1288  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 1289 
// Compute on the already-resident A VGPR fragments and B buffer
// 'wmma_reg_buf' while the next tiles are in flight.
 1290  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1291  static_for<0, NRepeat, 1>{}([&](auto n0) {
 1292  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1293  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
 1294  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
 1295  static_for<0, KInner, 1>{}([&](auto k_inner) {
 1296  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
 1297  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1298  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 1299  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 1301  m0,
 1302  k0,
 1303  I0,
 1304  I0,
 1305  I0,
 1306  Number<kk % A_K1>{}))>{}];
 1307  });
 1308  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
 1309  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1310  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 1311  b_thread_bufs[wmma_reg_buf]
 1312  [Number<b_thread_desc_.CalculateOffset(
 1314  I0,
 1315  I0,
 1316  n0,
 1317  I0,
 1318  k0,
 1319  Number<kk % B_K1>{}))>{}];
 1320  });
 1321  using wmma_input_type_a =
 1322  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 1323  using wmma_input_type_b =
 1324  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 1325 
 1326  constexpr index_t c_offset =
 1327  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
 1328 
 1329  wmma_gemm.Run(
 1330  a_thread_vec.template AsType<wmma_input_type_a>(),
 1331  b_thread_vec.template AsType<wmma_input_type_b>(),
 1332  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 1333  });
 1334  });
 1335  });
 1336  });
 1337 
 1338  block_sync_lds();
 1339 
 1340  // loop prefetch copy
 1341  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1342  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1343  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
 1344  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1345  a_block_buf,
 1346  a_thread_desc_,
 1347  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1348  a_thread_buf);
 1349  });
 1350  });
 1351 
 1352  HotLoopScheduler();
 1353  __builtin_amdgcn_sched_barrier(0);
 1354  };
 1355 
 1356  LoopFunc(I0, I1);
 1357  LoopFunc(I1, I0);
 1358 
 1359  i += 2;
 1360  } while(i < (num_loop - 2));
 1361  }
 1362 
 1363  // tail
// Even tail: one more full iteration (compute on B buffer I0, fetch into
// I1) followed by a final compute on buffer I1.
 1364  if constexpr(TailNum == TailNumber::Even)
 1365  {
 1366  b_blockwise_copy.Run(b_grid_desc,
 1367  b_grid_buf,
 1368  b_block_desc_k0_n0_n1_n2_k1,
 1369  b_block_origin_idx,
 1370  b_thread_bufs(I1));
 1371 
 1372  block_sync_lds();
 1373 
 1374  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 1375 
 1376  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1377  static_for<0, NRepeat, 1>{}([&](auto n0) {
 1378  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1379  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
 1380  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
 1381  static_for<0, KInner, 1>{}([&](auto k_inner) {
 1382  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
 1383  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1384  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 1385  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 1387  m0,
 1388  k0,
 1389  I0,
 1390  I0,
 1391  I0,
 1392  Number<kk % A_K1>{}))>{}];
 1393  });
 1394  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
 1395  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1396  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 1397  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
 1399  I0,
 1400  I0,
 1401  n0,
 1402  I0,
 1403  k0,
 1404  Number<kk % B_K1>{}))>{}];
 1405  });
 1406 
 1407  using wmma_input_type_a =
 1408  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 1409  using wmma_input_type_b =
 1410  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 1411 
 1412  constexpr index_t c_offset =
 1413  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
 1414 
 1415  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 1416  b_thread_vec.template AsType<wmma_input_type_b>(),
 1417  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 1418  });
 1419  });
 1420  });
 1421  });
 1422 
 1423  block_sync_lds();
 1424 
 1425  // tail Local Prefetch A1
 1426  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1427  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1428  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
 1429  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1430  a_block_buf,
 1431  a_thread_desc_,
 1432  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1433  a_thread_buf);
 1434  });
 1435  });
 1436 
 1437  __builtin_amdgcn_sched_barrier(0);
 1438 
 1439  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1440  static_for<0, NRepeat, 1>{}([&](auto n0) {
 1441  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1442  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
 1443  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
 1444  static_for<0, KInner, 1>{}([&](auto k_inner) {
 1445  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
 1446  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1447  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 1448  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 1450  m0,
 1451  k0,
 1452  I0,
 1453  I0,
 1454  I0,
 1455  Number<kk % A_K1>{}))>{}];
 1456  });
 1457  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
 1458  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1459  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 1460  b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
 1462  I0,
 1463  I0,
 1464  n0,
 1465  I0,
 1466  k0,
 1467  Number<kk % B_K1>{}))>{}];
 1468  });
 1469  using wmma_input_type_a =
 1470  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 1471  using wmma_input_type_b =
 1472  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 1473 
 1474  constexpr index_t c_offset =
 1475  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
 1476 
 1477  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 1478  b_thread_vec.template AsType<wmma_input_type_b>(),
 1479  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 1480  });
 1481  });
 1482  });
 1483  });
 1484  // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
 1485  // latency
 1486  // __builtin_amdgcn_sched_barrier(0);
 1487  }
// Odd tail: the last iteration's data is already resident; one final
// compute pass on B buffer I0 finishes the accumulation.
 1488  else if constexpr(TailNum == TailNumber::Odd)
 1489  {
 1490  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1491  static_for<0, NRepeat, 1>{}([&](auto n0) {
 1492  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1493  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
 1494  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
 1495  static_for<0, KInner, 1>{}([&](auto k_inner) {
 1496  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
 1497  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1498  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 1499  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 1501  m0,
 1502  k0,
 1503  I0,
 1504  I0,
 1505  I0,
 1506  Number<kk % A_K1>{}))>{}];
 1507  });
 1508  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
 1509  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1510  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 1511  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
 1513  I0,
 1514  I0,
 1515  n0,
 1516  I0,
 1517  k0,
 1518  Number<kk % B_K1>{}))>{}];
 1519  });
 1520  using wmma_input_type_a =
 1521  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 1522  using wmma_input_type_b =
 1523  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 1524 
 1525  constexpr index_t c_offset =
 1526  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
 1527 
 1528  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 1529  b_thread_vec.template AsType<wmma_input_type_b>(),
 1530  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 1531  });
 1532  });
 1533  });
 1534  });
 1535  }
 1536  }
1537 
// Pipelined GEMM Run(), scaled path (SFINAE-enabled when BScaleStruct is
// not Empty; the missing enable_if opener -- original line 1555 -- presumably
// also requires AScaleStruct to be non-Empty, matching the non-scaled
// overload's condition -- confirm against the repository source).
// Same double-buffered A-through-LDS / B-in-registers pipeline as the
// non-scaled overload, plus:
//  * per-K-slice scale factors loaded via
//    a_scale_struct/b_scale_struct.GlobalLoad<0>(...) every
//    num_loop_per_scale iterations;
//  * gemm_core_func accumulates each of the NumScaleKBlock K-slices into a
//    temporary (c_scale_struct.c_thread_buf_per_scale), then folds it into
//    c_thread_buf with the scales via UpdateCThreadBuf<kscale0, m0, n0>.
// NOTE(review): rendered listing; the make_tuple(...) opener lines inside
// CalculateOffset(...) (original lines 1607, 1622) and the descriptor-call
// openers are missing from this view.
 1538  template <bool HasMainLoop,
 1539  TailNumber TailNum,
 1540  typename AGridDesc,
 1541  typename ABlockDesc,
 1542  typename ABlockTransfer,
 1543  typename AGridBuffer,
 1544  typename ABlockBuffer,
 1545  typename ABlockTransferStep,
 1546  typename BGridDesc,
 1547  typename BBlockDesc,
 1548  typename BBlockTransfer,
 1549  typename BGridBuffer,
 1550  typename BBlockBuffer,
 1551  typename BBlockTransferStep,
 1552  typename CThreadBuffer,
 1553  typename AScaleStruct,
 1554  typename BScaleStruct,
 1556  !ck::is_same_v<BScaleStruct, Empty>,
 1557  bool>::type = false>
 1558  __device__ void Run(const AGridDesc& a_grid_desc,
 1559  const ABlockDesc& a_block_desc,
 1560  ABlockTransfer& a_blockwise_copy,
 1561  const AGridBuffer& a_grid_buf,
 1562  ABlockBuffer& a_block_buf,
 1563  const ABlockTransferStep& a_block_copy_step,
 1564  const BGridDesc& b_grid_desc,
 1565  const BBlockDesc&,
 1566  BBlockTransfer& b_blockwise_copy,
 1567  const BGridBuffer& b_grid_buf,
 1568  BBlockBuffer&,
 1569  const BBlockTransferStep& b_block_copy_step,
 1570  CThreadBuffer& c_thread_buf,
 1571  AScaleStruct& a_scale_struct,
 1572  BScaleStruct& b_scale_struct,
 1573  index_t num_loop,
 1574  index_t num_loop_per_scale) const
 1575  {
 1576  __builtin_amdgcn_sched_barrier(0);
 1577  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
// Number of K sub-blocks that share one scale value: the finer of the two
// operand scale granularities wins.
 1578  static constexpr auto NumScaleKBlock =
 1579  Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
 1580 
 1581  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
 1582  a_thread_desc_.GetElementSpaceSize());
 1583  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
 1584  b_thread_desc_.GetElementSpaceSize());
 1585 
 1586  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
 1587  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
 1588 
 1589  using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
 1590  auto c_scale_struct = CScaleStruct{};
 1591 
// One full MRepeat x NRepeat compute pass over B register buffer
// 'reg_buf', accumulating per-scale partials and folding them (scaled)
// into c_thread_buf.
 1592  auto gemm_core_func = [&](auto reg_buf) {
 1593  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1594  static_for<0, NRepeat, 1>{}([&](auto n0) {
 1595  static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
 1596  c_scale_struct.Clear();
 1597  static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
 1598  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
 1599  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
 1600  static_for<0, KInner, 1>{}([&](auto k_inner) {
 1601  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
 1602  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1603  constexpr index_t k_index =
 1604  kscale0 * (KRepeat / NumScaleKBlock) + k0;
 1605  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 1606  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 1608  m0,
 1609  k_index,
 1610  I0,
 1611  I0,
 1612  I0,
 1613  Number<kk % A_K1>{}))>{}];
 1614  });
 1615  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
 1616  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
 1617  constexpr index_t k_index =
 1618  kscale0 * (KRepeat / NumScaleKBlock) + k0;
 1619  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 1620  b_thread_bufs[reg_buf]
 1621  [Number<b_thread_desc_.CalculateOffset(
 1623  I0,
 1624  I0,
 1625  n0,
 1626  I0,
 1627  k_index,
 1628  Number<kk % B_K1>{}))>{}];
 1629  });
 1630  using wmma_input_type_a =
 1631  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 1632  using wmma_input_type_b =
 1633  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 1634  wmma_gemm.Run(
 1635  a_thread_vec.template AsType<wmma_input_type_a>(),
 1636  b_thread_vec.template AsType<wmma_input_type_b>(),
 1637  c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
 1638  Number<0>{}));
 1639  });
 1640  });
 1641  c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
 1642  });
 1643  });
 1644  });
 1645  };
 1646 
// Copies the current A LDS tile into the per-thread VGPR buffer.
 1647  auto a_local_prefetch_func = [&]() {
 1648  static_for<0, MRepeat, 1>{}([&](auto m0) {
 1649  static_for<0, KRepeat, 1>{}([&](auto k0) {
 1650  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
 1651  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1652  a_block_buf,
 1653  a_thread_desc_,
 1654  make_tuple(I0, m0, k0, I0, I0, I0, I0),
 1655  a_thread_buf);
 1656  });
 1657  });
 1658  };
 1659 
 1660  // Global prefetch A1 B1
 1661  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 1662  b_blockwise_copy.Run(b_grid_desc,
 1663  b_grid_buf,
 1664  b_block_desc_k0_n0_n1_n2_k1,
 1665  b_block_origin_idx,
 1666  b_thread_bufs(I0));
 1667 
 1668  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 1669  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 1670 
 1671  // Scales global load
 1672  a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
 1673  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
 1674 
 1675  __builtin_amdgcn_sched_barrier(0);
 1676 
 1677  c_scale_struct.Load(a_scale_struct, b_scale_struct);
 1678 
 1679  // Local prefill A1
 1680  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 1681 
 1682  // Global prefetch A2
 1683  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 1684  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 1685 
 1686  // Local prefetch A1
 1687  block_sync_lds();
 1688  a_local_prefetch_func();
 1689 
 1690  // Initialize C
 1691  c_thread_buf.Clear();
 1692 
 1693  __builtin_amdgcn_sched_barrier(0);
 1694 
 1695  // main body
 1696  if constexpr(HasMainLoop)
 1697  {
 1698  index_t i = 0;
// Two pipeline iterations per trip with swapped register buffers, as in
// the non-scaled overload; scales are re-fetched every
// num_loop_per_scale iterations.
 1699  do
 1700  {
 1701  auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
 1702  b_blockwise_copy.Run(b_grid_desc,
 1703  b_grid_buf,
 1704  b_block_desc_k0_n0_n1_n2_k1,
 1705  b_block_origin_idx,
 1706  b_thread_bufs(local_read_buf));
 1707 
 1708  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 1709 
 1710  block_sync_lds();
 1711 
 1712  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
 1713 
 1714  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
 1715  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 1716 
 1717  a_scale_struct.template GlobalLoad<0>(
 1718  (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
 1719  b_scale_struct.template GlobalLoad<0>(
 1720  (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
 1721 
 1722  gemm_core_func(wmma_reg_buf);
 1723 
 1724  block_sync_lds();
 1725 
 1726  // loop prefetch copy
 1727  a_local_prefetch_func();
 1728 
 1729  c_scale_struct.Load(a_scale_struct, b_scale_struct);
 1730 
 1731  // HotLoopScheduler();
 1732  __builtin_amdgcn_sched_barrier(0);
 1733  };
 1734 
 1735  LoopFunc(I0, I1);
 1736  LoopFunc(I1, I0);
 1737 
 1738  i += 2;
 1739  } while(i < (num_loop - 2));
 1740  }
 1741 
 1742  // tail
 1743  if constexpr(TailNum == TailNumber::Even)
 1744  {
 1745  b_blockwise_copy.Run(b_grid_desc,
 1746  b_grid_buf,
 1747  b_block_desc_k0_n0_n1_n2_k1,
 1748  b_block_origin_idx,
 1749  b_thread_bufs(I1));
 1750 
 1751  block_sync_lds();
 1752 
 1753  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 1754 
 1755  a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
 1756  b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
 1757 
 1758  gemm_core_func(I0);
 1759 
 1760  block_sync_lds();
 1761 
 1762  // tail Local Prefetch A1
 1763  a_local_prefetch_func();
 1764 
 1765  c_scale_struct.Load(a_scale_struct, b_scale_struct);
 1766 
 1767  __builtin_amdgcn_sched_barrier(0);
 1768 
 1769  gemm_core_func(I1);
 1770  // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
 1771  // latency
 1772  // __builtin_amdgcn_sched_barrier(0);
 1773  }
 1774  else if constexpr(TailNum == TailNumber::Odd)
 1775  {
 1776  gemm_core_func(I0);
 1777  }
 1778  }
1779 
 1780  protected:
// Thread-local register-tile descriptor for B (this BSkipLDS
// specialization keeps B in VGPRs rather than in LDS, so it overrides the
// base-class B descriptor). Layout: 1 x 1 x NRepeat x 1 x KRepeat x B_K1.
// NOTE(review): the constructor-call opener (original line 1782, presumably
// `make_naive_tensor_descriptor_packed(make_tuple(I1,`) is missing from
// this rendered listing -- confirm against the repository source.
 1781  static constexpr auto b_thread_desc_ =
 1783  I1,
 1784  I1,
 1785  Number<NRepeat>{},
 1786  I1,
 1787  Number<KRepeat>{},
 1788  Number<B_K1>{}));
 1789 
// Inherited A-side thread copy/descriptor and C thread descriptor.
 1790  using Base::a_thread_copy_;
 1791  using Base::a_thread_desc_;
 1792  using Base::c_thread_desc_;
 1793 };
1794 
1795 } // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:211
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:270
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
@ Full
Full tail iterations.
@ Empty
No tail iterations.
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
@ Interwave
Schedule across multiple wavefronts.
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
int32_t index_t
Definition: ck.hpp:301
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:759
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:183
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &a_scale_struct, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:396
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &, index_t num_loop, index_t) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1202
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &a_scale_struct, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1558
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:35
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:11