Composable Kernel documentation — generated source listing for
include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

universal_gemm_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
18 
19 namespace ck_tile {
20 
32 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
34 {
36  const std::array<const void*, NumATensor>& as_ptr_,
37  const std::array<const void*, NumBTensor>& bs_ptr_,
38  const std::array<const void*, NumDTensor>& ds_ptr_,
39  void* e_ptr_,
40  index_t k_batch_,
41  index_t M_,
42  index_t N_,
43  index_t K_,
44  const std::array<index_t, NumATensor>& stride_As_,
45  const std::array<index_t, NumBTensor>& stride_Bs_,
46  const std::array<index_t, NumDTensor>& stride_Ds_,
47  index_t stride_E_,
49  : as_ptr(as_ptr_),
50  bs_ptr(bs_ptr_),
51  ds_ptr(ds_ptr_),
52  e_ptr(e_ptr_),
53  M(M_),
54  N(N_),
55  K(K_),
56  stride_As(stride_As_),
57  stride_Bs(stride_Bs_),
58  stride_Ds(stride_Ds_),
59  stride_E(stride_E_),
60  k_batch(k_batch_),
61  async_input_scheduler(async_input_scheduler_)
62  {
63  }
64 
65  const std::array<const void*, NumATensor> as_ptr;
66  const std::array<const void*, NumBTensor> bs_ptr;
67  const std::array<const void*, NumDTensor> ds_ptr;
68  union
69  {
70  void* e_ptr;
71  void* c_ptr;
72  };
76  const std::array<index_t, NumATensor> stride_As;
77  const std::array<index_t, NumBTensor> stride_Bs;
78  const std::array<index_t, NumDTensor> stride_Ds;
79  union
80  {
83  };
84 
87 };
88 
90 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
92 {
94  const std::array<const void*, NumATensor> as_ptr;
96  const std::array<const void*, NumBTensor> bs_ptr;
98  const std::array<const void*, NumDTensor> ds_ptr;
100  void* e_ptr;
109  std::array<index_t, NumATensor> stride_As;
112  std::array<index_t, NumBTensor> stride_Bs;
115  std::array<index_t, NumDTensor> stride_Ds;
122 };
123 
160 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
162 {
166 
167  static constexpr bool ADataTypeIsTuple =
169  static constexpr bool BDataTypeIsTuple =
171  static constexpr bool DDataTypeIsTuple =
173  static constexpr bool ALayoutIsTuple =
175  static constexpr bool BLayoutIsTuple =
177  static constexpr bool DLayoutIsTuple =
179 
186 
190 
194 
198 
199  using DsDataType =
203 
206 
209 
210  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
211 
212  // Detect persistent kernel support to select appropriate entry point
214  {
215  template <typename T>
216  using has_persistent_type = decltype(T::UsePersistentKernel);
217 
218  static constexpr bool value = []() {
220  return GemmPipeline::UsePersistentKernel;
221  else
222  return false;
223  }();
224  };
226 
227  // Detect custom output offset support for advanced partitioning schemes
229  {
230  template <typename T, typename KernelArgs>
232  decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
233 
234  static constexpr bool value = []() {
236  return true;
237  else
238  return false;
239  }();
240  };
241  static constexpr bool has_tile_partitioner_output_offset =
243 
244  static constexpr auto I0 = number<0>();
245  static constexpr auto I1 = number<1>();
246  static constexpr auto I2 = number<2>();
247  static constexpr auto I3 = number<3>{};
248 
249  static constexpr index_t NumATensor = AsDataType::size();
250  static constexpr index_t NumBTensor = BsDataType::size();
251  static constexpr index_t NumDTensor = DsDataType::size();
252 
255 
256  static_assert(AsLayout::size() == AsDataType::size(),
257  "The size of AsLayout and AsDataType should be the same");
258 
259  static_assert(BsLayout::size() == BsDataType::size(),
260  "The size of BsLayout and BsDataType should be the same");
261 
262  static_assert(DsLayout::size() == DsDataType::size(),
263  "The size of DsLayout and DsDataType should be the same");
264 
265  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
266 
267  using KernelArgs =
268  UniversalGemmKernelArgs<AsLayout::size(), BsLayout::size(), DsLayout::size()>;
269 
270  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
271  {
272  // clang-format off
273  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
274  // clang-format on
275  }
276 
277  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
278  {
279  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
280  }
281 
288  CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
289  {
291  const auto kernel = kentry<1, Kernel, KernelArgs>;
292  int occupancy;
294  hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
295 
296  const int grid_size = get_available_compute_units(s) * occupancy;
297  return dim3(grid_size, 1, 1);
298  }
299 
300  CK_TILE_HOST static auto BlockSize()
301  {
302  if(ck_tile::is_wave32())
303  {
304  return dim3(kBlockSize / 2);
305  }
306  else
307  {
308  return dim3(kBlockSize);
309  }
310  }
311 
312  CK_TILE_HOST static constexpr KernelArgs
314  {
315  return KernelArgs{hostArgs.as_ptr,
316  hostArgs.bs_ptr,
317  hostArgs.ds_ptr,
318  hostArgs.e_ptr,
319  hostArgs.M,
320  hostArgs.N,
321  hostArgs.K,
322  hostArgs.stride_As,
323  hostArgs.stride_Bs,
324  hostArgs.stride_Ds,
325  hostArgs.stride_E,
326  hostArgs.k_batch,
327  hostArgs.async_input_scheduler};
328  }
329 
331  {
332  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
333  }
334 
336  {
337  // Balances K-dimension work across batches to maximize parallelism while minimizing
338  // load imbalance. Uses ceil division to distribute remainder work evenly.
339  __device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
340  {
341  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
342  const index_t num_all = amd_wave_read_first_lane(
343  kargs.K / K1); // num of all loops not including potential tail
344  index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
345  num_full = num_full == 0 ? kargs.k_batch : num_full;
346 
347  const index_t num_full_iters =
349  const index_t full_k_read = num_full_iters * K1;
350  const index_t partial_k_read = (num_full_iters - 1) * K1;
351 
352  static_for<0, NumATensor, 1>{}([&](auto index) {
353  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
354  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
355  {
356  as_k_split_offset[index] =
357  amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
358  std::max(k_id - num_full, 0) * partial_k_read);
359  }
360  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
361  {
362  as_k_split_offset[index] =
363  amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
364  std::max(k_id - num_full, 0) * partial_k_read) *
365  kargs.stride_As[index]);
366  }
367  });
368 
369  static_for<0, NumBTensor, 1>{}([&](auto index) {
370  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
371  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
372  {
373  bs_k_split_offset[index] =
374  amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
375  std::max(k_id - num_full, 0) * partial_k_read) *
376  kargs.stride_Bs[index]);
377  }
378  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
379  {
380  bs_k_split_offset[index] =
381  amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
382  std::max(k_id - num_full, 0) * partial_k_read);
383  }
384  });
385 
386  if(k_id == kargs.k_batch - 1)
387  {
388  splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
389  std::max(k_id - num_full, 0) * partial_k_read;
390  }
391  else if(k_id < num_full)
392  {
393  splitted_k = full_k_read;
394  }
395  else
396  {
397  splitted_k = partial_k_read;
398  }
399  }
400 
401  std::array<index_t, NumATensor> as_k_split_offset;
402  std::array<index_t, NumBTensor> bs_k_split_offset;
404  };
405 
406  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
407  {
408  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
410  {
411  if(kargs.k_batch != 1)
412  {
413  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
414  {
415  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
416  }
417  return false;
418  }
419  }
420 
421  if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
422  {
423  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
424  {
425  CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
426  }
427  return false;
428  }
429 
430  const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
431  : GemmPipeline::template GetVectorSizeA<false>();
432  bool AsTensorIsValid = {true};
433  static_for<0, NumATensor, 1>{}([&](auto index) {
434  using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
435  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
436  {
437  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
438  GemmPipeline::kPadK == false)
439  {
440  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
441  {
443  "Can't support K that is not a multiple of k_batch * KPerBlock "
444  "without padding!");
445  }
446  AsTensorIsValid = false;
447  }
448  if(kargs.K % vectorSizeA != 0)
449  {
450  const auto remainder = kargs.K % vectorSizeA;
451  constexpr ck_tile::index_t APackedSize =
453  const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
454  // oob can support to dword level
455  if(remainder_in_bytes % 4 == 0)
456  {
457  AsTensorIsValid = true;
458  }
459  else
460  {
461  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
462  {
463  CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
464  }
465  AsTensorIsValid = false;
466  }
467  }
468  }
469  else
470  {
471  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
472  {
473  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
474  {
476  "Can't support M that is not a multiple of MPerBlock without padding!");
477  }
478  AsTensorIsValid = false;
479  }
480  if(kargs.M % vectorSizeA != 0)
481  {
482  const auto remainder = kargs.M % vectorSizeA;
483  constexpr ck_tile::index_t APackedSize =
485  const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
486  // oob can support to dword level
487  if(remainder_in_bytes % 4 == 0)
488  {
489 
490  AsTensorIsValid = true;
491  }
492  else
493  {
494  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
495  {
496  CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
497  }
498  AsTensorIsValid = false;
499  }
500  }
501  }
502  });
503 
504  bool BsTensorIsValid = {true};
505  const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
506  : GemmPipeline::template GetVectorSizeB<false>();
507  static_for<0, NumBTensor, 1>{}([&](auto index) {
508  using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
509  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
510  {
511  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
512  {
513  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
514  {
516  "Can't support N that is not a multiple of NPerBlock without padding!");
517  }
518  BsTensorIsValid = false;
519  }
520  if(kargs.N % vectorSizeB != 0)
521  {
522  const auto remainder = kargs.N % vectorSizeB;
523  constexpr ck_tile::index_t BPackedSize =
525  const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
526  // oob can support to dword level
527  if(remainder_in_bytes % 4 == 0)
528  {
529  BsTensorIsValid = true;
530  }
531  else
532  {
533  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
534  {
535  CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
536  }
537  BsTensorIsValid = false;
538  }
539  }
540  else
541  {
542  if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
543  GemmPipeline::kPadK == false)
544  {
545  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
546  {
548  "Can't support K that is not a multiple of k_batch * KPerBlock "
549  "without padding!");
550  }
551  BsTensorIsValid = false;
552  }
553  if(kargs.K % vectorSizeB != 0)
554  {
555  const auto remainder = kargs.K % vectorSizeB;
556  constexpr ck_tile::index_t BPackedSize =
558  const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
559  // oob can support to dword level
560  if(remainder_in_bytes % 4 == 0)
561  {
562  BsTensorIsValid = true;
563  }
564  else
565  {
566  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
567  {
569  "K is not a multiple of vector load size for B tensor!");
570  }
571  BsTensorIsValid = false;
572  }
573  }
574  }
575  }
576  });
577 
578  bool DTensorIsValid = {true};
579  static_for<0, NumDTensor, 1>{}([&](auto index) {
580  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
581  if(std::is_same_v<DiLayout, CLayout> == false)
582  {
583  DTensorIsValid = false;
584  }
585  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
586  {
587  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
588  {
589  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
590  {
591  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
592  "NPerBlock without padding!");
593  }
594  DTensorIsValid = false;
595  }
596  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
597  {
598  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
599  {
600  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
601  }
602  DTensorIsValid = false;
603  }
604  }
605  else
606  {
607  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
608  {
609  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
610  {
611  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
612  "MPerBlock without padding!");
613  }
614  DTensorIsValid = false;
615  }
616  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
617  {
618  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
619  {
620  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
621  }
622  DTensorIsValid = false;
623  }
624  }
625  });
626 
627  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
628  {
629  if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
630  {
631  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
632  {
634  "Can't support N that is not a multiple of NPerBlock without padding!");
635  }
636  return false;
637  }
638  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
639  {
640  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
641  {
642  CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
643  }
644  return false;
645  }
646  }
647  else
648  {
649  if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
650  {
651  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
652  {
654  "Can't support M that is not a multiple of MPerBlock without padding!");
655  }
656  return false;
657  }
658  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
659  {
660  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
661  {
662  CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
663  }
664  return false;
665  }
666  }
667 
668  // Verify async scheduler parameters to prevent division-by-zero and invalid memory access
669  if(kargs.async_input_scheduler.chunk_signals != nullptr)
670  {
672  {
673  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
674  {
675  CK_TILE_ERROR("tiles_per_chunk_m must be positive when chunk_signals is set!");
676  }
677  return false;
678  }
679  if(kargs.async_input_scheduler.num_chunks == 0)
680  {
681  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
682  {
683  CK_TILE_ERROR("num_chunks must be positive when chunk_signals is set!");
684  }
685  return false;
686  }
687  }
688 
689  return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
690  }
691 
692  CK_TILE_DEVICE static auto
693  MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
694  const KernelArgs& kargs,
695  const index_t k_size,
696  const index_t i_m)
697  {
698  // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
699  const auto& as_tensor_view = generate_tuple(
700  [&](auto i) {
701  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
702  using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
703  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
704  {
705  return make_naive_tensor_view<address_space_enum::global>(
706  static_cast<const AiDataType*>(as_ptr[i]),
707  make_tuple(kargs.M, k_size),
708  make_tuple(kargs.stride_As[i], 1),
709  number<GemmPipeline::GetVectorSizeA()>{},
710  number<1>{});
711  }
712  else
713  {
714  return make_naive_tensor_view<address_space_enum::global>(
715  static_cast<const AiDataType*>(as_ptr[i]),
716  make_tuple(k_size, kargs.M),
717  make_tuple(kargs.stride_As[i], 1),
718  number<GemmPipeline::GetVectorSizeA()>{},
719  number<1>{});
720  }
721  },
723 
724  // Step 2: Create padded views (from MakeGemmPadViews)
725  const auto& as_pad_view = generate_tuple(
726  [&](auto i) {
727  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
728  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
729  {
730  return pad_tensor_view(as_tensor_view[i],
734  }
735  else
736  {
737  return pad_tensor_view(as_tensor_view[i],
741  }
742  },
744 
745  // Step 3: Create tile windows (from MakeGemmTileWindows)
746  const auto& as_block_window = generate_tuple(
747  [&](auto i) {
748  using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
749  if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
750  {
751  return make_tile_window(as_pad_view[i],
754  {i_m, 0});
755  }
756  else
757  {
758  return make_tile_window(as_pad_view[i],
761  {0, i_m});
762  }
763  },
765 
766  return as_block_window;
767  }
768 
769  CK_TILE_DEVICE static auto
770  MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
771  const KernelArgs& kargs,
772  const index_t k_size,
773  const index_t i_n)
774  {
775  // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
776  const auto& bs_tensor_view = generate_tuple(
777  [&](auto i) {
778  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
779  using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
780  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
781  {
782  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
783  {
784  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
785  const index_t K0 = k_size / K1;
786  constexpr index_t VectorSizeB =
787  std::min(K1, GemmPipeline::GetVectorSizeB());
788  const auto b_k0_n_k1_desc =
790  make_tuple(kargs.N * K1, K1, I1),
792  number<1>{});
793  const auto b_n_k_desc = transform_tensor_descriptor(
794  b_k0_n_k1_desc,
799  return make_tensor_view<address_space_enum::global>(
800  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
801  }
802  else
803  {
804  return make_naive_tensor_view<address_space_enum::global>(
805  bs_ptr[i],
806  make_tuple(k_size, kargs.N),
807  make_tuple(kargs.stride_Bs[i], 1),
808  number<GemmPipeline::GetVectorSizeB()>{},
809  number<1>{});
810  }
811  }
812  else
813  {
814  if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
815  {
816  constexpr index_t K1 = GemmPipeline::GetSmemPackB();
817  const index_t K0 = k_size / K1;
818  constexpr index_t VectorSizeB =
819  std::min(K1, GemmPipeline::GetVectorSizeB());
820  const auto b_k0_n_k1_desc =
822  make_tuple(kargs.N * K1, K1, I1),
824  number<1>{});
825  const auto b_n_k_desc = transform_tensor_descriptor(
826  b_k0_n_k1_desc,
831  return make_tensor_view<address_space_enum::global>(
832  static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
833  }
834  else
835  {
836  if constexpr(GemmPipeline::Preshuffle)
837  {
838  index_t kFlatK =
839  GemmPipeline::BlockGemmShape::flatKPerWarp *
840  (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
841  index_t kFlatN = kargs.N * kargs.K / kFlatK;
842 
843  return make_naive_tensor_view<address_space_enum::global>(
844  bs_ptr[i],
845  make_tuple(kFlatN, kFlatK),
846  make_tuple(kFlatK, 1),
847  number<GemmPipeline::GetVectorSizeB()>{},
848  number<1>{});
849  }
850  else
851  {
852  return make_naive_tensor_view<address_space_enum::global>(
853  bs_ptr[i],
854  make_tuple(kargs.N, k_size),
855  make_tuple(kargs.stride_Bs[i], 1),
856  number<GemmPipeline::GetVectorSizeB()>{},
857  number<1>{});
858  }
859  }
860  }
861  },
863 
864  // Step 2: Create padded views (from MakeGemmPadViews)
865  const auto& bs_pad_view = generate_tuple(
866  [&](auto i) {
867  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
868  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
869  {
870  return pad_tensor_view(bs_tensor_view[i],
874  }
875  else
876  {
877  return pad_tensor_view(bs_tensor_view[i],
881  }
882  },
884 
885  // Step 3: Create tile windows (from MakeGemmTileWindows)
886  const auto& bs_block_window = generate_tuple(
887  [&](auto i) {
888  using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
889  if constexpr(GemmPipeline::Preshuffle)
890  {
891  return make_tile_window(
892  bs_pad_view[i],
895  {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
896  0});
897  }
898  else
899  {
900  if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
901  {
902  return make_tile_window(bs_pad_view[i],
905  {i_n, 0});
906  }
907  else
908  {
909  return make_tile_window(bs_pad_view[i],
912  {0, i_n});
913  }
914  }
915  },
917 
918  return bs_block_window;
919  }
920 
921  CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
922  const KernelArgs& kargs,
923  const index_t i_m,
924  const index_t i_n)
925  {
926  // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
927  const auto& ds_tensor_view = generate_tuple(
928  [&](auto i) {
929  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
930  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
931  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
932  {
933  return make_naive_tensor_view<address_space_enum::global>(
934  static_cast<const DDataType_*>(ds_ptr[i]),
935  make_tuple(kargs.M, kargs.N),
936  make_tuple(kargs.stride_Ds[i], 1),
937  number<EpiloguePipeline::GetVectorSizeD(i)>{},
938  number<1>{});
939  }
940  else
941  {
942  return make_naive_tensor_view<address_space_enum::global>(
943  static_cast<const DDataType_*>(ds_ptr[i]),
944  make_tuple(kargs.N, kargs.M),
945  make_tuple(kargs.stride_Ds[i], 1),
946  number<EpiloguePipeline::GetVectorSizeD(i)>{},
947  number<1>{});
948  }
949  },
951 
952  // Step 2: Create padded views (from MakeGemmPadViews)
953  const auto& ds_pad_view = generate_tuple(
954  [&](auto i) {
955  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
956  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
957  {
958  return pad_tensor_view(ds_tensor_view[i],
962  }
963  else
964  {
965  return pad_tensor_view(ds_tensor_view[i],
969  }
970  },
972 
973  // Step 3: Create tile windows (from MakeGemmTileWindows)
974  const auto& ds_block_window = generate_tuple(
975  [&](auto i) {
976  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
977  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
978  {
979  return make_tile_window(ds_pad_view[i],
982  {i_m, i_n});
983  }
984  else
985  {
986  return make_tile_window(ds_pad_view[i],
989  {i_n, i_m});
990  }
991  },
993 
994  return ds_block_window;
995  }
996 
997  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
999  const KernelArgs& kargs,
1000  const index_t i_m,
1001  const index_t i_n)
1002  {
1003  // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
1004  const auto& e_tensor_view = [&]() {
1005  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1006  {
1007  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
1008  e_ptr,
1009  make_tuple(kargs.M, kargs.N),
1010  make_tuple(kargs.stride_E, 1),
1011  number<EpiloguePipeline::GetVectorSizeC()>{},
1012  number<1>{});
1013  }
1014  else
1015  {
1016  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
1017  e_ptr,
1018  make_tuple(kargs.M, kargs.N),
1019  make_tuple(1, kargs.stride_E),
1020  number<1>{},
1021  number<1>{});
1022  }
1023  }();
1024 
1025  // Step 2: Create padded view (from MakeGemmPadViews)
1026  const auto& e_pad_view = [&]() {
1027  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
1028  {
1029  return pad_tensor_view(e_tensor_view,
1033  }
1034  else
1035  {
1036  return pad_tensor_view(e_tensor_view,
1040  }
1041  }();
1042 
1043  // Step 3: Create tile window (from MakeGemmTileWindows)
1044  auto e_block_window = make_tile_window(
1045  e_pad_view,
1047  {i_m, i_n});
1048 
1049  return e_block_window;
1050  }
1051 
/// @brief Executes the GEMM main loop and epilogue for a single output tile.
///
/// Builds the A/B/D tile windows for this block, runs the GEMM pipeline
/// cooperatively across the whole workgroup, then writes the accumulator tile
/// through the epilogue. With k_batch > 1 each split-K batch accumulates its
/// partial result into the output with atomic adds.
///
/// @param as_ptr   A-tensor base pointers, pre-offset for this split-K batch.
/// @param bs_ptr   B-tensor base pointers, pre-offset for this split-K batch.
/// @param ds_ptr   Auxiliary D-tensor pointers consumed by the epilogue.
/// @param e_ptr    Output (E/C) tensor pointer.
/// @param smem_ptr LDS scratch buffer shared by GEMM and epilogue pipelines.
/// @param kargs    Kernel arguments (problem shape, strides, k_batch).
/// @param splitk_batch_offset Carries splitted_k, the K extent of this batch.
/// @param block_idx_m Row coordinate (in elements) of this output tile.
/// @param block_idx_n Column coordinate (in elements) of this output tile.
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
                                   const std::array<const BDataType*, NumBTensor>& bs_ptr,
                                   const std::array<const void*, NumDTensor>& ds_ptr,
                                   EDataType* e_ptr,
                                   void* smem_ptr,
                                   const KernelArgs& kargs,
                                   const SplitKBatchOffset& splitk_batch_offset,
                                   const index_t block_idx_m,
                                   const index_t block_idx_n)
{
    // Create block windows using specialized methods
    const auto& as_block_window =
        MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
    const auto& bs_block_window =
        MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
    const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);

    // Main-loop iteration count for the K range owned by this batch.
    const index_t num_loop =
        amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

    // Run GEMM cooperatively by whole workgroup.
    const auto& c_block_tile = GemmPipeline{}.template operator()(
        as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr);

    const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
    // Run Epilogue Pipeline
    if(k_batch == 1)
    {
        // Single batch owns the whole K range: plain stores suffice.
        auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
            e_ptr, kargs, block_idx_m, block_idx_n);
        EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
    }
    else
    {
        // Split-K: each batch atomically accumulates its partial result.
        auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
            e_ptr, kargs, block_idx_m, block_idx_n);
        EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
    }
}
1105 
1106  CK_TILE_DEVICE static auto
1108  {
1109  index_t iM, iN;
1110 
1111  // Regular launch: use 1D block indexing
1112  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1113  const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1114  iM = tile_m;
1115  iN = tile_n;
1116 
1117  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1118  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1119 
1120  return make_tuple(i_m, i_n);
1121  }
1122 
1123  // Helper functions
1125  {
1126  // For 1D regular launch
1127  return amd_wave_read_first_lane(get_block_id());
1128  }
1129 
1131  {
1132  // For 1D regular launch
1134  }
1135 
1136  // Helper to get total number of tiles, handling both dim3 and index_t return types
1137  template <typename... Args>
1138  CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t
1139  {
1140  auto grid_size = TilePartitioner::GridSize(std::forward<Args>(args)...);
1141 
1142  using GridSizeType = decltype(grid_size);
1143 
1144  if constexpr(std::is_same_v<GridSizeType, dim3>)
1145  {
1146  // GridSize returns dim3: compute total tiles as x * y * z
1147  return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z);
1148  }
1149  else
1150  {
1151  // GridSize returns scalar (index_t): use directly
1152  return amd_wave_read_first_lane(grid_size);
1153  }
1154  }
1155 
1156  // Non-persistent kernel entry point
1157  template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1159  {
1160  const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1161  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1162  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1163  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1164 
1165  const SplitKBatchOffset splitk_batch_offset(kargs);
1166 
1167  // options
1168  std::array<const ADataType*, NumATensor> as_ptr;
1169  static_for<0, NumATensor, 1>{}([&](auto i) {
1170  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1171  splitk_batch_offset.as_k_split_offset[i];
1172  });
1173 
1174  std::array<const BDataType*, NumBTensor> bs_ptr;
1175  static_for<0, NumBTensor, 1>{}([&](auto i) {
1176  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1177  splitk_batch_offset.bs_k_split_offset[i];
1178  });
1179 
1180  // Calculate output offset from tile partitioner and apply to output pointer
1181  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1183  {
1184  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
1185  e_ptr += output_offset;
1186  }
1187 
1188  // allocate LDS
1189  __shared__ char smem_ptr[GetSmemSize()];
1190 
1191  RunGemm(
1192  as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
1193  }
1194 
1195  // Persistent kernel entry point
1196  template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1198  {
1199  const auto grid_size = amd_wave_read_first_lane(get_grid_size());
1200  const auto num_tiles =
1201  amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
1202  const auto num_work = amd_wave_read_first_lane(num_tiles * kargs.k_batch);
1203  auto block_id = amd_wave_read_first_lane(get_block_id());
1204 
1205  while(block_id < num_work)
1206  {
1207  s_waitcnt_barrier();
1208  const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
1209  const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1210  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1211  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1212 
1213  // Synchronize with producer to ensure input data is ready before processing tile
1214  if(kargs.async_input_scheduler.chunk_signals != nullptr)
1215  {
1216  const auto tiles_per_chunk =
1218  const auto tile_idx_pivot =
1220  const auto num_chunks =
1222  if(tiles_per_chunk > 0 && num_chunks > 0)
1223  {
1224  // Pivot allows rotating chunk assignments for load balancing
1225  const auto chunk_idx = amd_wave_read_first_lane(
1226  ((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
1228  chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
1229  }
1230  }
1231 
1232  // Get the SplitK offset for this block
1233  const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
1234  const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
1235 
1236  std::array<const ADataType*, NumATensor> as_ptr;
1237  static_for<0, NumATensor, 1>{}([&](auto i) {
1238  as_ptr[i] = static_cast<const ADataType*>(kargs.as_ptr[i]) +
1239  splitk_batch_offset.as_k_split_offset[i];
1240  });
1241 
1242  std::array<const BDataType*, NumBTensor> bs_ptr;
1243  static_for<0, NumBTensor, 1>{}([&](auto i) {
1244  bs_ptr[i] = static_cast<const BDataType*>(kargs.bs_ptr[i]) +
1245  splitk_batch_offset.bs_k_split_offset[i];
1246  });
1247 
1248  // Calculate output offset from tile partitioner and apply to output pointer
1249  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
1251  {
1252  const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
1253  e_ptr += output_offset;
1254  }
1255 
1256  // allocate LDS
1257  __shared__ char smem_ptr[GetSmemSize()];
1258  // Run the GEMM
1259 
1260  RunGemm(as_ptr,
1261  bs_ptr,
1262  kargs.ds_ptr,
1263  e_ptr,
1264  smem_ptr,
1265  kargs,
1266  splitk_batch_offset,
1267  i_m,
1268  i_n);
1269 
1270  // Advance to the next work item
1271  block_id += grid_size;
1272  if(block_id >= num_work)
1273  {
1274  break;
1275  }
1276  }
1277  }
1278 };
1279 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1691
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
__device__ index_t get_grid_size()
Definition: get_id.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Scheduler for persistent GEMM kernels with asynchronous input streaming.
Definition: persistent_async_input_scheduler.hpp:27
uint32_t tiles_per_chunk_m
Number of M-dimension tiles grouped into each chunk. Grouping tiles balances synchronization overhead...
Definition: persistent_async_input_scheduler.hpp:31
int32_t tile_idx_pivot_m
Pivot offset for rotating the chunk assignment. Allows shifting which tiles map to which chunks,...
Definition: persistent_async_input_scheduler.hpp:41
uint32_t * chunk_signals
Device pointer to array of signal values (uint32_t), one per chunk. Producer sets signals to coordina...
Definition: persistent_async_input_scheduler.hpp:36
uint32_t num_chunks
Number of signal chunks allocated. Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)....
Definition: persistent_async_input_scheduler.hpp:46
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:34
void * c_ptr
Definition: universal_gemm_kernel.hpp:71
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:78
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:77
index_t K
Definition: universal_gemm_kernel.hpp:75
void * e_ptr
Definition: universal_gemm_kernel.hpp:70
index_t M
Definition: universal_gemm_kernel.hpp:73
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:67
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:65
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:76
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_, PersistentAsyncInputScheduler async_input_scheduler_=PersistentAsyncInputScheduler{})
Definition: universal_gemm_kernel.hpp:35
index_t N
Definition: universal_gemm_kernel.hpp:74
index_t stride_E
Definition: universal_gemm_kernel.hpp:81
PersistentAsyncInputScheduler async_input_scheduler
Definition: universal_gemm_kernel.hpp:86
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:66
index_t stride_C
Definition: universal_gemm_kernel.hpp:82
index_t k_batch
Definition: universal_gemm_kernel.hpp:85
Definition: universal_gemm_kernel.hpp:336
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:401
index_t splitted_k
Definition: universal_gemm_kernel.hpp:403
__device__ SplitKBatchOffset(const KernelArgs &kargs, const index_t k_id=blockIdx.z)
Definition: universal_gemm_kernel.hpp:339
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:402
Definition: universal_gemm_kernel.hpp:214
static constexpr bool value
Definition: universal_gemm_kernel.hpp:218
decltype(T::UsePersistentKernel) has_persistent_type
Definition: universal_gemm_kernel.hpp:216
decltype(T::GetOutputOffset(std::declval< KernelArgs >(), std::declval< index_t >())) has_get_output_offset_t
Definition: universal_gemm_kernel.hpp:232
static constexpr bool value
Definition: universal_gemm_kernel.hpp:234
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:92
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:100
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition: universal_gemm_kernel.hpp:112
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:98
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition: universal_gemm_kernel.hpp:109
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
index_t k_batch
Definition: universal_gemm_kernel.hpp:119
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:104
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition: universal_gemm_kernel.hpp:118
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:106
PersistentAsyncInputScheduler async_input_scheduler
Persistent async input scheduler for chunk-based tile scheduling.
Definition: universal_gemm_kernel.hpp:121
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:96
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition: universal_gemm_kernel.hpp:115
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:102
The Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:162
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1158
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: universal_gemm_kernel.hpp:164
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:270
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: universal_gemm_kernel.hpp:163
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1197
static CK_TILE_DEVICE auto GetGridSize() -> index_t
Definition: universal_gemm_kernel.hpp:1130
static constexpr bool BDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:169
static CK_TILE_DEVICE auto MakeBBlockWindows(const std::array< const BDataType *, NumBTensor > &bs_ptr, const KernelArgs &kargs, const index_t k_size, const index_t i_n)
Definition: universal_gemm_kernel.hpp:770
static constexpr auto I2
Definition: universal_gemm_kernel.hpp:246
static constexpr bool BLayoutIsTuple
Definition: universal_gemm_kernel.hpp:175
remove_cvref_t< typename GemmPipeline::BElementWise > BElementWise
Definition: universal_gemm_kernel.hpp:208
static constexpr index_t NumATensor
Definition: universal_gemm_kernel.hpp:249
static constexpr bool ALayoutIsTuple
Definition: universal_gemm_kernel.hpp:173
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition: universal_gemm_kernel.hpp:253
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::AsLayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > >> AsLayout
Definition: universal_gemm_kernel.hpp:182
static constexpr auto I3
Definition: universal_gemm_kernel.hpp:247
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > >> DsDataType
Definition: universal_gemm_kernel.hpp:202
static constexpr bool ADataTypeIsTuple
Definition: universal_gemm_kernel.hpp:167
static constexpr bool has_tile_partitioner_output_offset
Definition: universal_gemm_kernel.hpp:241
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::AsDataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > >> AsDataType
Definition: universal_gemm_kernel.hpp:193
static constexpr index_t NumDTensor
Definition: universal_gemm_kernel.hpp:251
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:1066
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition: universal_gemm_kernel.hpp:268
static constexpr bool DDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:171
static CK_TILE_DEVICE auto GetTileCoordinates(const KernelArgs &kargs) -> tuple< index_t, index_t >
Definition: universal_gemm_kernel.hpp:1107
static CK_TILE_DEVICE auto MakeCBlockWindows(EDataType *e_ptr, const KernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:998
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:225
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: universal_gemm_kernel.hpp:204
static constexpr auto I1
Definition: universal_gemm_kernel.hpp:245
static CK_TILE_DEVICE auto MakeDBlockWindows(const std::array< const void *, NumDTensor > &ds_ptr, const KernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:921
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:277
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:300
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition: universal_gemm_kernel.hpp:254
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Calculate grid size that maximizes hardware utilization for persistent kernels.
Definition: universal_gemm_kernel.hpp:288
static constexpr index_t NumBTensor
Definition: universal_gemm_kernel.hpp:250
static constexpr auto I0
Definition: universal_gemm_kernel.hpp:244
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:406
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > >> DsLayout
Definition: universal_gemm_kernel.hpp:189
static CK_TILE_HOST_DEVICE auto GetNumTiles(Args &&... args) -> index_t
Definition: universal_gemm_kernel.hpp:1138
static constexpr bool DLayoutIsTuple
Definition: universal_gemm_kernel.hpp:177
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: universal_gemm_kernel.hpp:165
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BsDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > >> BsDataType
Definition: universal_gemm_kernel.hpp:197
static CK_TILE_DEVICE auto MakeABlockWindows(const std::array< const ADataType *, NumATensor > &as_ptr, const KernelArgs &kargs, const index_t k_size, const index_t i_m)
Definition: universal_gemm_kernel.hpp:693
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:330
remove_cvref_t< typename GemmPipeline::AElementWise > AElementWise
Definition: universal_gemm_kernel.hpp:207
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BsLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > >> BsLayout
Definition: universal_gemm_kernel.hpp:185
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:313
static CK_TILE_DEVICE auto GetBlockId() -> index_t
Definition: universal_gemm_kernel.hpp:1124
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:210
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: universal_gemm_kernel.hpp:205
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: stream_config.hpp:30
Definition: tuple.hpp:192
Definition: workgroup_barrier.hpp:12
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset=0)
Definition: workgroup_barrier.hpp:30
#define CK_TILE_ENV(name)
Definition: env.hpp:145