gridwise_gemm_wmma_cshuffle_v3.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/utility/env.hpp"

namespace ck {

// "Universal" GEMM kernel with SplitK support, operating on multiple A, B, and
// D tensors. Elementwise operations can be applied to each tensor
// respectively; the CDE_op is an elementwise operation applied to the C and
// all D tensors.
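//
// A minimal sketch (an illustrative assumption, not code from this file) of
// the CDE_op contract with one C and two D tensors: the functor combines the
// GEMM accumulator value with all D values to produce the E value.
//
//   struct ExampleCDEOp
//   {
//       template <typename E, typename C, typename D0, typename D1>
//       __host__ __device__ void
//       operator()(E& e, const C& c, const D0& d0, const D1& d1) const
//       {
//           e = ck::type_convert<E>(c + d0 + d1);
//       }
//   };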
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename AsDataType,
          typename BsDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          BlockGemmPipelineScheduler BlkGemmPipeSched,
          BlockGemmPipelineVersion BlkGemmPipelineVer,
          typename ComputeTypeA,
          typename ComputeTypeB,
          bool PermuteA,
          bool PermuteB,
          bool IsBPreShuffled = false,
          bool ForceThreadTileTransfer = false,
          bool IsFusedKernel = false>
struct GridwiseGemm_wmma_cshuffle_v3
    : GridwiseGemm_wmma_cshuffle_v3_common<
          ALayout,
          BLayout,
          DsLayout,
          ELayout,
          AsDataType,
          BsDataType,
          AccDataType,
          CShuffleDataType,
          DsDataType,
          EDataType,
          AElementwiseOperation,
          BElementwiseOperation,
          CDEElementwiseOperation,
          GemmSpec,
          BlockSize,
          MPerBlock,
          NPerBlock,
          KPerBlock,
          AK1Value,
          BK1Value,
          MPerWmma,
          NPerWmma,
          MRepeat,
          NRepeat,
          ABlockTransferThreadClusterLengths_AK0_M_AK1,
          ABlockTransferThreadClusterArrangeOrder,
          ABlockTransferSrcAccessOrder,
          ABlockTransferSrcVectorDim,
          ABlockTransferSrcScalarPerVector,
          ABlockTransferDstScalarPerVector_AK1,
          AThreadTransferSrcResetCoordinateAfterRun,
          ABlockLdsExtraM,
          BBlockTransferThreadClusterLengths_BK0_N_BK1,
          BBlockTransferThreadClusterArrangeOrder,
          BBlockTransferSrcAccessOrder,
          BBlockTransferSrcVectorDim,
          BBlockTransferSrcScalarPerVector,
          BBlockTransferDstScalarPerVector_BK1,
          BThreadTransferSrcResetCoordinateAfterRun,
          BBlockLdsExtraN,
          CShuffleMRepeatPerShuffle,
          CShuffleNRepeatPerShuffle,
          CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          CDEShuffleBlockTransferScalarPerVectors,
          BlkGemmPipeSched,
          BlkGemmPipelineVer,
          ComputeTypeA,
          ComputeTypeB,
          PermuteA,
          PermuteB,
          IsBPreShuffled,
          ForceThreadTileTransfer,
          IsFusedKernel>
{
    using Base = GridwiseGemm_wmma_cshuffle_v3_common<
        ALayout,
        BLayout,
        DsLayout,
        ELayout,
        AsDataType,
        BsDataType,
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1Value,
        BK1Value,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        PermuteA,
        PermuteB,
        IsBPreShuffled,
        ForceThreadTileTransfer,
        IsFusedKernel>;

    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::I3;
    using Base::I4;
    using Base::I5;
    using Base::I6;
    using Base::I7;

    using Base::AK0Number;
    using Base::AK1Number;
    using Base::BK0Number;
    using Base::BK1Number;

    using Base::APackedSize;
    using Base::BPackedSize;

    using Base::CalculateAK0Padded;
    using Base::CalculateBK0Padded;
    using Base::CalculateKPadded;
    using Base::CalculateKRead;
    using Base::CalculateMBlock;
    using Base::CalculateMPadded;
    using Base::CalculateNBlock;
    using Base::CalculateNPadded;

    using Base::MakeAsGridDescriptor_AK0_M_AK1;
    using Base::MakeBsGridDescriptor_BK0_N_BK1;
    using Base::MakeDsGridDescriptor_M_N;
    using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
    using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    using Base::NumATensor;
    using Base::NumBTensor;
    using Base::NumDTensor;
    using typename Base::AsGridPointer;
    using typename Base::BsGridPointer;
    using typename Base::DsGridPointer;
    using AsDataType_ = AsDataType;
    using BsDataType_ = BsDataType;

    struct Problem
    {
        __host__ Problem() = default;
        __host__ __device__ Problem(index_t M_,
                                    index_t N_,
                                    index_t K_,
                                    std::array<index_t, NumATensor> StrideAs_,
                                    std::array<index_t, NumBTensor> StrideBs_,
                                    std::array<index_t, NumDTensor> StrideDs_,
                                    index_t StrideE_,
                                    index_t KBatch_)
            : M{M_},
              N{N_},
              K{K_},
              StrideAs{StrideAs_},
              StrideBs{StrideBs_},
              StrideDs{StrideDs_},
              StrideE{StrideE_},
              KBatch{KBatch_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, KBatch_)},
              KPadded{CalculateKPadded(K_, KBatch_)},
              AK0{CalculateAK0Padded(K_, KBatch_)},
              BK0{CalculateBK0Padded(K_, KBatch_)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)},
              Kt{K_}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K
                      << ", " << "SAs: {";
            static_for<0, NumATensor, 1>{}([&](auto i) {
                std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
            });
            std::cout << "}, " << "SBs: {";
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
            });
            std::cout << "}, ";
            if constexpr(NumDTensor > 0)
            {
                std::cout << "SDs: { ";
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
                });
                std::cout << " }, ";
            }
            std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
                      << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                      << ", " << "NBlock: " << NBlock << "}" << std::endl;
        }
        index_t M;
        index_t N;
        index_t K;
        std::array<index_t, NumATensor> StrideAs;
        std::array<index_t, NumBTensor> StrideBs;
        std::array<index_t, NumDTensor> StrideDs;
        index_t StrideE;
        index_t KBatch;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
        index_t Kt;
    };
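
    // Minimal usage sketch (hypothetical values; assumes row-major A/E,
    // column-major B, one A, one B, and no D tensors):
    //
    //   Problem problem{/*M=*/3840, /*N=*/4096, /*K=*/4096,
    //                   /*StrideAs=*/{4096}, /*StrideBs=*/{4096},
    //                   /*StrideDs=*/{}, /*StrideE=*/4096, /*KBatch=*/1};
    //   problem.Print(); // prints sizes, strides, and the derived padded values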

    // Argument
    struct Argument : public Problem
    {
        __host__ Argument() = default;
        __host__ __device__ Argument(std::array<const void*, NumATensor> p_as_grid_,
                                     std::array<const void*, NumBTensor> p_bs_grid_,
                                     std::array<const void*, NumDTensor> p_ds_grid_,
                                     EDataType* p_e_grid_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     std::array<index_t, NumATensor> StrideAs_,
                                     std::array<index_t, NumBTensor> StrideBs_,
                                     std::array<index_t, NumDTensor> StrideDs_,
                                     index_t StrideE_,
                                     index_t k_batch_,
                                     AElementwiseOperation a_element_op_,
                                     BElementwiseOperation b_element_op_,
                                     CDEElementwiseOperation cde_element_op_,
                                     bool is_reduce_ = false)
            : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
              p_as_grid{},
              p_bs_grid{},
              p_ds_grid{},
              p_e_grid{p_e_grid_},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              cde_element_op{cde_element_op_},
              is_reduce(is_reduce_)
        {
            // populate pointer, desc for As
            static_for<0, NumATensor, 1>{}([&](auto i) {
                using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;

                // A pointer
                p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
            });

            // populate pointer, desc for Bs
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;

                // B pointer
                p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
            });

            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
            });
        }

        __host__ __device__ inline bool IsReduceAdd() const
        {
            return (Problem::KBatch > 1) && is_reduce;
        }

        __host__ __device__ inline bool IsAtomicAdd() const
        {
            return (Problem::KBatch > 1) && (!is_reduce);
        }

        AsGridPointer p_as_grid;
        BsGridPointer p_bs_grid;
        DsGridPointer p_ds_grid;
        EDataType* p_e_grid;

        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        CDEElementwiseOperation cde_element_op;

        // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
        bool is_reduce;
    };
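
    // Illustrative host-side note (an assumption about typical use): with
    // KBatch > 1, the launcher selects the E-tensor global memory operation
    // from these predicates, e.g.
    //
    //   const auto e_mem_op = arg.IsAtomicAdd()
    //                             ? InMemoryDataOperationEnum::AtomicAdd
    //                             : InMemoryDataOperationEnum::Set;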

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
        {
            // Note: the xdl implementation of multiple AB supports one layout
            // but multiple strides, so we create an array of offsets with the
            // same values. It should be fixed later on, once we have a more
            // flexible thread transfer.
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                static_for<0, NumATensor, 1>{}(
                    [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                static_for<0, NumATensor, 1>{}(
                    [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
            }

            if constexpr(IsBPreShuffled)
            {
                static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
            }
            else
            {
                if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
                {
                    static_for<0, NumBTensor, 1>{}([&](auto i) {
                        b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
                    });
                }
                else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
                {
                    if constexpr(!PermuteB)
                    {
                        static_for<0, NumBTensor, 1>{}(
                            [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
                    }
                    else
                    {
                        const int k0_offset = karg.KRead * karg.N;
                        static_for<0, NumBTensor, 1>{}(
                            [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
                    }
                }
            }

            if(k_id < karg.KBatch - 1)
            {
                karg.K = karg.KRead;
            }
            else
            {
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            }

            if(karg.IsReduceAdd())
            {
                c_reduce_offset = k_id * karg.M * karg.N;
            }
            else
            {
                c_reduce_offset = 0;
            }
        }

        std::array<index_t, NumATensor> a_k_split_offset;
        std::array<index_t, NumBTensor> b_k_split_offset;
        index_t c_reduce_offset;
    };
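
    // Worked example (illustrative numbers): with K = 1000, KBatch = 4, and
    // KRead = 256, batches k_id = 0..2 each see karg.K = 256, while the last
    // batch sees karg.K = 1000 - 256 * 3 = 232; with is_reduce set, batch k_id
    // writes its partial result at c_reduce_offset = k_id * M * N.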

    using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;

    // return block_id to C matrix tile idx (m0, n0) mapping
    // if arch = gfx942
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;

    __device__ static index_t GetKBlockPerScale() { return 1; }

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename Block2CTileMap,
              typename EpilogueArgument,
              int BlockMapMBlockIndex = 0,
              int BlockMapNBlockIndex = 1>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const Problem& problem,
                               const Block2CTileMap& block_2_ctile_map,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args,
                               const index_t A_k_id = 0,
                               const index_t B_k_id = 0)
    {
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
        const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
        const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
            K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
        const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
        const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                e_grid_desc_m_n, problem.MBlock, problem.NBlock);

        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            decltype(as_grid_desc_ak0_m_ak1),
            decltype(bs_grid_desc_bk0_n_bk1),
            decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
            Block2CTileMap,
            EpilogueArgument,
            BlockMapMBlockIndex,
            BlockMapNBlockIndex>(p_as_grid,
                                 p_bs_grid,
                                 p_ds_grid,
                                 p_e_grid,
                                 p_shared,
                                 as_grid_desc_ak0_m_ak1,
                                 bs_grid_desc_bk0_n_bk1,
                                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                 e_grid_desc_mblock_mperblock_nblock_nperblock,
                                 block_2_ctile_map,
                                 a_element_op,
                                 b_element_op,
                                 cde_element_op,
                                 epilogue_args,
                                 A_k_id,
                                 B_k_id);
    }

    // Overload to pass in custom As/Bs/Ds/E grid descriptors.
    // Used for contraction operations, where tensor transforms are non-trivial.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename AsGridDescriptor_AK0_M_AK1,
              typename BsGridDescriptor_BK0_N_BK1,
              typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
              typename Block2CTileMap,
              typename EpilogueArgument,
              int BlockMapMBlockIndex = 0,
              int BlockMapNBlockIndex = 1>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const AsGridDescriptor_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
                               const BsGridDescriptor_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2CTileMap& block_2_ctile_map,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args,
                               const index_t A_k_id = 0,
                               const index_t B_k_id = 0)
    {
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id =
            __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
        const index_t block_n_id =
            __builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);

        // Scale structs (Empty)
        using Scale = typename BlockwiseGemmPipe::Empty;
        auto a_scale_struct = Scale{};
        auto b_scale_struct = Scale{};

        const index_t num_k_block_per_scale = GetKBlockPerScale();

        Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                           decltype(bs_grid_desc_bk0_n_bk1),
                           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(a_scale_struct),
                           decltype(b_scale_struct),
                           decltype(epilogue_args),
                           HasMainKBlockLoop,
                           EGlobalMemoryDataOperation,
                           TailNum>(p_as_grid,
                                    p_bs_grid,
                                    p_ds_grid,
                                    p_e_grid,
                                    p_shared,
                                    as_grid_desc_ak0_m_ak1,
                                    bs_grid_desc_bk0_n_bk1,
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op,
                                    block_m_id,
                                    block_n_id,
                                    num_k_block_per_scale,
                                    a_scale_struct,
                                    b_scale_struct,
                                    epilogue_args,
                                    A_k_id,
                                    B_k_id);
    }

    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
                               EDataType* p_e_grid,
                               void* p_shared,
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CDEElementwiseOperation cde_element_op,
                               EpilogueArgument& epilogue_args)
    {
        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
            EpilogueArgument>(p_as_grid,
                              p_bs_grid,
                              p_ds_grid,
                              p_e_grid,
                              p_shared,
                              problem,
                              DefaultBlock2CTileMap(problem),
                              a_element_op,
                              b_element_op,
                              cde_element_op,
                              epilogue_args);
    }

    // Wrapper function to have the __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename Block2CTileMap,
              typename EpilogueArgument,
              int BlockMapMBlockIndex = 0,
              int BlockMapNBlockIndex = 1>
    __device__ static void Run(void* p_shared,
                               const SplitKBatchOffset& splitk_batch_offset,
                               Argument& karg,
                               const Block2CTileMap& block_2_ctile_map,
                               EpilogueArgument& epilogue_args,
                               const index_t A_k_id = 0,
                               const index_t B_k_id = 0)
    {
        // shift A matrix pointers for splitk
        AsGridPointer p_as_grid_splitk;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
                                  splitk_batch_offset.a_k_split_offset[i];
        });

        // shift B matrix pointers for splitk
        BsGridPointer p_bs_grid_splitk;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
                                  splitk_batch_offset.b_k_split_offset[i];
        });

        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
            EpilogueArgument,
            BlockMapMBlockIndex,
            BlockMapNBlockIndex>(p_as_grid_splitk,
                                 p_bs_grid_splitk,
                                 karg.p_ds_grid,
                                 karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
                                 p_shared,
                                 karg,
                                 block_2_ctile_map,
                                 karg.a_element_op,
                                 karg.b_element_op,
                                 karg.cde_element_op,
                                 epilogue_args,
                                 A_k_id,
                                 B_k_id);
    }
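
    // Sketch of the shared __global__ entry point this wrapper is meant for
    // (an illustrative assumption; names and the shared-memory helper are not
    // part of this header):
    //
    //   template <typename GridwiseGemm, bool HasMainKBlockLoop, ...>
    //   __global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg, ...)
    //   {
    //       __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    //       auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
    //       GridwiseGemm::template Run<HasMainKBlockLoop, ...>(
    //           p_shared, splitk_batch_offset, karg, epilogue_args);
    //   }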

    // Wrapper function to have the __global__ function in common
    // between gemm_universal, b_scale, ab_scale, etc.
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const SplitKBatchOffset& splitk_batch_offset,
                               Argument& karg,
                               EpilogueArgument& epilogue_args,
                               const index_t A_k_id = 0,
                               const index_t B_k_id = 0)
    {
        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
            EpilogueArgument>(p_shared,
                              splitk_batch_offset,
                              karg,
                              DefaultBlock2CTileMap(karg),
                              epilogue_args,
                              A_k_id,
                              B_k_id);
    }

    __device__ __host__ static auto DefaultBlock2CTileMap(const Problem& problem)
    {
        return DefaultBlock2CTileMap(problem.M, problem.N);
    }
    __device__ __host__ static auto DefaultBlock2CTileMap(const index_t M, const index_t N)
    {
        return Block2CTileMap{M, N, 4};
    }
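
    // Usage sketch (illustrative; assumes the tile map exposes a grid-size
    // helper as other CK block-to-ctile maps do): the default map linearizes
    // the MBlock x NBlock output tiles for a 1D launch, e.g.
    //
    //   const auto map       = DefaultBlock2CTileMap(problem);
    //   const index_t blocks = map.CalculateGridSize(problem.M, problem.N);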

    // Run method for convolution bwd_data (grid descriptors are passed as
    // arguments, not generated internally)
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename Block2CTileMapExt,
              typename ComputePtrOffsetOfBatch,
              typename ComputePtrOffsetOfN,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              bool CTranspose,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2CTileMapExt& block_2_ctile_map,
                               const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                               const ComputePtrOffsetOfN compute_ptr_offset_of_n,
                               const index_t num_k_per_block,
                               Argument& karg,
                               EpilogueArgument& epilogue_args)
    {
        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
        const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
        const index_t k_idx =
            __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
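
        // Illustrative decode (assuming this launch layout): blockIdx.z
        // enumerates (n_idx, k_batch) pairs, so with karg.KBatch = 2 and
        // blockIdx.z = 5, n_idx = 5 / 2 = 2, the K-batch index is
        // 5 - 2 * 2 = 1, and k_idx = 1 * num_k_per_block.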

        // offset base pointer for each work-group
        const long_index_t a_batch_offset =
            CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
                       : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
                       : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const long_index_t e_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

        const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

        const long_index_t a_n_offset =
            CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
        const long_index_t b_n_offset =
            CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
        const long_index_t e_n_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

        AsGridPointer p_as_grid_;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            p_as_grid_(i) =
                static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
        });

        BsGridPointer p_bs_grid_;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            p_bs_grid_(i) =
                static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
        });

        DsGridPointer p_ds_grid_grp;
        static_for<0, NumDTensor, 1>{}(
            [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });

        // Currently supporting one A and one B
        const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return a_grid_desc_ak0_m_ak1;
            },
            Number<NumATensor>{});

        const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return b_grid_desc_bk0_n_bk1;
            },
            Number<NumBTensor>{});
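
        // Note (illustrative): generate_tuple simply replicates the single
        // descriptor, i.e. as_grid_desc_ak0_m_ak1 holds NumATensor copies of
        // a_grid_desc_ak0_m_ak1 (and likewise for B), matching the
        // one-A/one-B support noted above.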

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // AScale struct (Empty)
        using AScale = typename BlockwiseGemmPipe::Empty;
        auto a_scale_struct = AScale{};

        // BScale struct (Empty)
        using BScale = typename BlockwiseGemmPipe::Empty;
        auto b_scale_struct = BScale{};

        const index_t num_k_block_per_scale = GetKBlockPerScale();

        Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                           decltype(bs_grid_desc_bk0_n_bk1),
                           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(a_scale_struct),
                           decltype(b_scale_struct),
                           decltype(epilogue_args),
                           HasMainKBlockLoop,
                           EGlobalMemoryDataOperation,
                           TailNum>(p_as_grid_,
                                    p_bs_grid_,
                                    p_ds_grid_grp,
                                    karg.p_e_grid + e_batch_offset + e_n_offset,
                                    p_shared,
                                    as_grid_desc_ak0_m_ak1,
                                    bs_grid_desc_bk0_n_bk1,
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                                    karg.a_element_op,
                                    karg.b_element_op,
                                    karg.cde_element_op,
                                    block_m_id,
                                    block_n_id,
                                    num_k_block_per_scale,
                                    a_scale_struct,
                                    b_scale_struct,
                                    epilogue_args,
                                    k_idx,
                                    k_idx,
                                    karg.KBatch);
    }

    // Run method for convolution (grid descriptors are passed as arguments,
    // not generated internally)
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename ComputePtrOffsetOfBatch,
              index_t NumGroupsToMerge,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                               const index_t num_k_per_block,
                               Argument& karg,
                               EpilogueArgument& epilogue_args)
    {
        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
        const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

        const long_index_t a_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const long_index_t e_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

        AsGridPointer p_as_grid_;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
        });

        BsGridPointer p_bs_grid_;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
        });

        const auto ds_grid_desc_m_n =
            MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);

        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n, karg.MBlock, karg.NBlock);

        const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return a_grid_desc_ak0_m_ak1;
            },
            Number<NumATensor>{});

        const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return b_grid_desc_bk0_n_bk1;
            },
            Number<NumBTensor>{});

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // Scale structs (Empty)
        using Scale = typename BlockwiseGemmPipe::Empty;
        auto b_scale_struct = Scale{};
        auto a_scale_struct = Scale{};

        const index_t num_k_block_per_scale = GetKBlockPerScale();

        Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                           decltype(bs_grid_desc_bk0_n_bk1),
                           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(a_scale_struct),
                           decltype(b_scale_struct),
                           decltype(epilogue_args),
                           HasMainKBlockLoop,
                           CGlobalMemoryDataOperation,
                           TailNum>(p_as_grid_,
                                    p_bs_grid_,
                                    karg.p_ds_grid,
                                    karg.p_e_grid + e_batch_offset,
                                    p_shared,
                                    as_grid_desc_ak0_m_ak1,
                                    bs_grid_desc_bk0_n_bk1,
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    karg.a_element_op,
                                    karg.b_element_op,
                                    karg.cde_element_op,
                                    block_m_id,
                                    block_n_id,
                                    num_k_block_per_scale,
                                    a_scale_struct,
                                    b_scale_struct,
                                    epilogue_args,
                                    k_idx,
                                    k_idx,
                                    karg.KBatch);
    }

    // Run method for convolution fwd (grid descriptors are passed as arguments,
    // not generated internally)
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename ComputePtrOffsetOfBatch,
              typename ComputePtrOffsetOfN,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
                               const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
                               [[maybe_unused]] const index_t num_k_per_block,
                               Argument& karg,
                               EpilogueArgument& epilogue_args)
    {
        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
        const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
        // offset base pointer for each work-group
        const long_index_t a_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const long_index_t e_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

        const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

        const long_index_t a_n_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
        const long_index_t b_n_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
        const long_index_t e_n_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

        const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);

        AsGridPointer p_as_grid_;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            p_as_grid_(i) =
                static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
        });

        BsGridPointer p_bs_grid_;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            p_bs_grid_(i) =
                static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
        });

        DsGridPointer p_ds_grid_grp;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
            p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
                               ds_batch_offset[i] + ds_n_offset[i];
        });

        // Currently supporting one A and one B
        const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return a_grid_desc_ak0_m_ak1;
            },
            Number<NumATensor>{});

        const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return b_grid_desc_bk0_n_bk1;
            },
            Number<NumBTensor>{});

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // AScale struct (Empty)
        using AScale = typename BlockwiseGemmPipe::Empty;
        auto a_scale_struct = AScale{};

        // BScale struct (Empty)
        using BScale = typename BlockwiseGemmPipe::Empty;
        auto b_scale_struct = BScale{};

        const index_t num_k_block_per_scale = GetKBlockPerScale();

        Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                           decltype(bs_grid_desc_bk0_n_bk1),
                           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(a_scale_struct),
                           decltype(b_scale_struct),
                           decltype(epilogue_args),
                           HasMainKBlockLoop,
                           EGlobalMemoryDataOperation,
                           TailNum>(p_as_grid_,
                                    p_bs_grid_,
                                    p_ds_grid_grp,
                                    karg.p_e_grid + e_batch_offset + e_n_offset,
                                    p_shared,
                                    as_grid_desc_ak0_m_ak1,
                                    bs_grid_desc_bk0_n_bk1,
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                    e_grid_desc_mblock_mperblock_nblock_nperblock,
                                    karg.a_element_op,
                                    karg.b_element_op,
                                    karg.cde_element_op,
                                    block_m_id,
                                    block_n_id,
                                    num_k_block_per_scale,
                                    a_scale_struct,
                                    b_scale_struct,
                                    epilogue_args);
    }
};

} // namespace ck