/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File
cshuffle_epilogue.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck_tile/core.hpp"
12 
13 #include <type_traits>
14 
15 namespace ck_tile {
16 
17 template <typename AsDataType_,
18  typename BsDataType_,
19  typename DsDataType_,
20  typename AccDataType_,
21  typename ODataType_,
22  typename DsLayout_,
23  typename ELayout_,
24  typename CDElementwise_,
25  index_t kM_,
26  index_t kN_,
27  index_t MWave_,
28  index_t NWave_,
29  index_t MPerXdl_,
30  index_t NPerXdl_,
31  index_t KPerXdl_,
32  bool isCTransposed_,
33  index_t kNumWaveGroups_ = 1,
34  bool FixedVectorSize_ = false,
35  index_t VectorSizeC_ = 1,
36  bool TiledMMAPermuteN_ = false,
37  index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
38  bool DoubleSmemBuffer_ = false>
40 {
49  static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
50  static constexpr index_t kMPerBlock = kM_;
51  static constexpr index_t kNPerBlock = kN_;
52  static constexpr index_t MWave = MWave_;
53  static constexpr index_t NWave = NWave_;
54  static constexpr index_t MPerXdl = MPerXdl_;
55  static constexpr index_t NPerXdl = NPerXdl_;
56  static constexpr index_t KPerXdl = KPerXdl_;
57  static constexpr index_t isCTransposed = isCTransposed_;
58  static constexpr bool FixedVectorSize = FixedVectorSize_;
59  static constexpr index_t VectorSizeC = VectorSizeC_;
60  static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
61  static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
62  static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
63  static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
64  static constexpr index_t NumDTensor = DsDataType::size();
65 
66  static_assert(NumDTensor == DsLayout::size(),
67  "The size of DsDataType and DsLayout should be the same");
68 };
69 
70 template <typename Problem_, typename Policy_ = void>
72 {
80 
83 
87 
91 
94 
95  using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
96  std::is_same_v<ADataType, pk_fp4_t>,
97  BDataType,
98  ADataType>;
99  // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
100  using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
101  std::is_same_v<BDataType, pk_fp4_t> ||
102  std::is_same_v<BDataType, pk_fp4_raw_t>,
103  ADataType,
104  BDataType>;
105 
108  static constexpr index_t kBlockSize = Problem::kBlockSize;
109  static constexpr index_t kMPerBlock = Problem::kMPerBlock;
110  static constexpr index_t kNPerBlock = Problem::kNPerBlock;
111  static constexpr index_t MWave = Problem::MWave;
112  static constexpr index_t NWave = Problem::NWave;
113  static constexpr index_t MPerXdl = Problem::MPerXdl;
114  static constexpr index_t NPerXdl = Problem::NPerXdl;
115  static constexpr index_t KPerXdl = Problem::KPerXdl;
116  static constexpr index_t isCTransposed = Problem::isCTransposed;
117  static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
118  static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
119  static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
120  static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
121  static constexpr index_t VectorSizeC = Problem::VectorSizeC;
122  static constexpr index_t MPerIteration = MPerXdl * MWave;
123  static constexpr index_t NPerIteration = NPerXdl * NWave;
124  static constexpr index_t NumDTensor = Problem::NumDTensor;
125  static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
126  static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
127 
129 
131 
132  static_assert(NumDTensor == DsLayout::size(),
133  "The size of DsDataType and DsLayout should be the same");
134 
135  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
136  {
137  // clang-format off
138  return concat('_', "CShuffleEpilogue",
139  concat('x', MWave, NWave),
140  concat('x', MPerXdl, NPerXdl, KPerXdl),
141  VectorSizeC,
142  isCTransposed ? "CTransposed" : "CNotTransposed");
143  // clang-format on
144  }
145 
157  {
158  if constexpr(FixedVectorSize)
159  {
160  return VectorSizeC;
161  }
162  constexpr index_t max_vector_size = 16;
163  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
164  {
165  return std::min(static_cast<int>(NPerIteration),
166  static_cast<int>(max_vector_size / sizeof(ODataType)));
167  }
168  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
169  {
170  return std::min(static_cast<int>(MPerIteration),
171  static_cast<int>(max_vector_size / sizeof(ODataType)));
172  }
173  else
174  {
175  static_assert(false, "Unsupported ELayout!");
176  }
177  }
178 
184  template <index_t I>
186  {
187  constexpr index_t max_vector_size = 16;
188  using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
189  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
190  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
191  {
192  return std::min(static_cast<int>(NPerIteration),
193  static_cast<int>(max_vector_size / sizeof(DiDataType)));
194  }
195  else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
196  {
197  return std::min(static_cast<int>(MPerIteration),
198  static_cast<int>(max_vector_size / sizeof(DiDataType)));
199  }
200  else
201  {
202  static_assert(false, "Unsupported DLayout!");
203  }
204  return max_vector_size / sizeof(DiDataType);
205  }
206 
212  template <index_t m_shuffle_tile, index_t n_shuffle_tile>
214  {
215  constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
216  constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
217 
218  constexpr auto shuffle_tile =
219  m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
220  ? std::make_tuple(1, 1)
221  : std::make_tuple(m_shuffle_tile, n_shuffle_tile);
222 
223  return shuffle_tile;
224  }
225 
234  static constexpr auto shuffle_tile_tuple = [] {
235  constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
236  if constexpr(elem_per_thread <= GetVectorSizeC())
237  {
238  return std::make_tuple(1, 1);
239  }
240  else
241  {
242  constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
243  static_assert(elem_per_thread % GetVectorSizeC() == 0);
244  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
245  {
246  static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
247  (kMPerBlock % num_xdl_shuffles == 0),
248  "kMPerBlock must be divisible by MPerXdl*MWave and "
249  "num_xdl_shuffles for CShuffleEpilogue");
250  return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
251  kMPerBlock / (MPerXdl * MWave)),
252  1>();
253  }
254  else
255  {
256  static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
257  (kNPerBlock % num_xdl_shuffles == 0),
258  "kNPerBlock must be divisible by NPerXdl*NWave and "
259  "num_xdl_shuffles for CShuffleEpilogue");
260  return AlignShuffleTileWithSmem<1,
261  min(num_xdl_shuffles,
262  kNPerBlock / (NPerXdl * NWave))>();
263  }
264  }
265  }();
266  static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
269 
270  static constexpr auto MNPerIterationShuffle = [] {
271  constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
272  constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
273  if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
275  else
276  return std::make_tuple(m_val, n_val);
277  }();
278  static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
279  static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
280 
282  BTypeToUse,
283  AccDataType,
284  MPerXdl,
285  NPerXdl,
286  KPerXdl,
287  isCTransposed>;
288 
289  using CWarpDstr = typename WG::CWarpDstr;
290  using CWarpTensor = typename WG::CWarpTensor;
291  using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
295 
296  template <typename Problem>
298  {
299  // N is contiguous dimension
300  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
301  {
305  }
306  // M is contiguous dimension
307  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
308  {
312  }
313  else
314  {
315  static_assert(false, "Unsupported ELayout!");
316  }
317  }
318 
320  {
321  constexpr auto block_outer_dstr_encoding = [] {
322  if constexpr(BlockedXDLN_PerWarp == 1)
323  {
330  sequence<0, 0>>{};
331  }
332  else
333  {
334 #if defined(__gfx950__)
335  constexpr auto is_950 = true;
336 #else
337  constexpr auto is_950 = false;
338 #endif
339  constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
340  // BlockedLayout
341  // this branch is for original a16w4
342  if constexpr(is_950 || is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
344  {
346  sequence<>,
353  }
354  else
355  {
357  sequence<>,
364  }
365  }
366  }();
367  constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
368  block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
369 
370  return block_dstr_encoding;
371  }
372 
374  {
375  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
376  return lds_block_desc.get_element_space_size() * sizeof(ODataType);
377  }
378 
379  template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
380  CK_TILE_DEVICE void
381  scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
382  {
383  // Check if scales are EmptyScale first (no scaling needed)
384  if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
385  {
386  // No scaling needed - this is a no-op
387  }
388  // Check if scales are scalar AccDataType
389  else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
390  std::is_same_v<ScaleN, AccDataType>)
391  {
392  // Handle scalar scales
393  const AccDataType scale_m = scale_m_window;
394  const AccDataType scale_n = scale_n_window;
395  tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
396  lds_tile);
397  }
398  // Otherwise, assume they are tile windows that can be loaded
399  else
400  {
401  // Load tiles
402  const auto scale_m_tile = load_tile(scale_m_window);
403  const auto scale_n_tile = load_tile(scale_n_window);
404 
405  // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
407  element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
408 
409  // Move scale windows
410  constexpr index_t num_access = SFC::get_num_of_access();
411  if constexpr(iAccess != num_access - 1)
412  {
413  constexpr auto step = SFC::get_forward_step(number<iAccess>{});
414 
415  move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
416  move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
417  }
418  }
419  }
420 
421  template <index_t iAccess, typename OAccTile, typename LdsTile>
422  CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
423  {
424  constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
425 
426  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
427  constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
428  constexpr auto c_warp_y_lengths =
429  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
430  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
431 
432  lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
435  c_warp_y_index_zeros),
437  c_warp_y_lengths));
438  }
439 
440  template <typename LdsTile, typename InLdsWindow>
441  CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
442  {
443  const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
444 
445  store_tile(in_lds_window, c_warptile_in_tensor_casted);
446  }
447 
448  template <typename DramWindows, typename COutTensor>
449  CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
450  {
451  const auto ds_tensor = generate_tuple(
452  [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
453 
454  const auto c_ds_tiles = concat_tuple_of_reference(
455  tie(c_out_tensor, c_out_tensor),
456  generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
457  number<NumDTensor>{}));
458 
460  }
461 
462  template <typename OutDramWindow, typename COutTensor>
463  CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
464  const COutTensor& c_out_tensor)
465  {
466  if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
467  memory_operation_enum::set)
468  {
469  store_tile(out_dram_window, c_out_tensor);
470  }
471  else
472  {
473  update_tile(out_dram_window, c_out_tensor);
474  }
475  }
476 
480  template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
481  CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
482  {
483  constexpr index_t num_access = SFC::get_num_of_access();
484  if constexpr(iAccess != num_access - 1)
485  {
486  constexpr auto step = SFC::get_forward_step(number<iAccess>{});
487 
488  // move the output dram window
489  move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
490 
491  // move windows for each of the D matrices (inputs for element-wise)
492  static_for<0, NumDTensor, 1>{}([&](auto idx) {
493  move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
494  });
495  }
496  }
497 
// Tag type meaning "no scale supplied". It is the default for the ScaleM / ScaleN template
// parameters of operator(), where passing it for both disables scaling.
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
struct EmptyScale {};
502 
503  template <typename, typename = void>
505  {
506  using DataType = float;
507  };
508 
509  template <typename T>
510  struct ScaleDataType<T, std::void_t<typename T::DataType>>
511  {
512  using DataType = typename T::DataType;
513  };
514 
515  template <typename ODramWindow,
516  typename OAccTile,
517  typename DsDramWindows,
518  typename ScaleM = EmptyScale,
519  typename ScaleN = EmptyScale,
520  int EnablePermuateN_ = TiledMMAPermuteN,
521  std::enable_if_t<EnablePermuateN_, int> = 0>
522  CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
523  const OAccTile& o_acc_tile,
524  const DsDramWindows& ds_dram_windows,
525  void* /* p_smem */,
526  const ScaleM& scale_m = {},
527  const ScaleN& scale_n = {})
528  {
529  static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
530 
531  static_assert(MPerXdl % RowsPerLane == 0,
532  "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
533  constexpr int kM0 = MWave;
534  constexpr int kM2 = RowsPerLane;
535  constexpr int kM1 = MPerXdl / kM2;
536 
537  constexpr int kN0 = NWave;
538  constexpr int kN1 = NPerXdl;
539  constexpr int kN2 = NRepeat;
540 
541  using IntrThreadShuffleEncode =
542  tile_distribution_encoding<sequence<>,
543  tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
544  tuple<sequence<1, 2>, sequence<1, 2>>,
545  tuple<sequence<0, 0>, sequence<1, 1>>,
546  sequence<1, 2>,
547  sequence<2, 2>>;
548  constexpr auto dram_tile_distribution =
549  make_static_tile_distribution(IntrThreadShuffleEncode{});
550 
551  auto d_dram_windows = generate_tuple(
552  [&](auto idx) {
553  return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
554  },
555  number<NumDTensor>{});
556 
557  constexpr auto c_warp_y_lengths =
558  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
559  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
560 
561  auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
562  auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
563 
564  // Optional scales (must share the same distribution to match per-thread indexing)
565  constexpr bool has_scales =
567  constexpr bool has_scalar_scales =
568  std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
569 
570  // Tiles to hold row/col scales when present
571  using SMType = typename ScaleDataType<ScaleM>::DataType;
572  using SNType = typename ScaleDataType<ScaleN>::DataType;
573 
574  auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
575  auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
576 
577  // Build windows only if non-scalar scales are provided
578  auto scale_m_window = [&]() {
579  if constexpr(has_scales && !has_scalar_scales)
580  {
581  return make_tile_window(scale_m, dram_tile_distribution);
582  }
583  else
584  {
585  return EmptyScale{};
586  }
587  }();
588  auto scale_n_window = [&]() {
589  if constexpr(has_scales && !has_scalar_scales)
590  {
591  return make_tile_window(scale_n, dram_tile_distribution);
592  }
593  else
594  {
595  return EmptyScale{};
596  }
597  }();
598 
599  static_for<0, MRepeat, 1>{}([&](auto mIter) {
600  // Slice accumulators for this M repeat into the permuted layout
601  shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
602  merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
603  merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
604 
605  // If non-scalar scales provided, load them with identical distribution
606  if constexpr(has_scales && !has_scalar_scales)
607  {
608  sm_tile = load_tile(scale_m_window); // row scales in permuted layout
609  sn_tile = load_tile(scale_n_window); // col scales in permuted layout
610  }
611 
612  // Pack 4 “rows per lane” as you already do
613  static_for<0, NRepeat, 1>{}([&](auto n_idx) {
614  // source indices in shuffle_acc: (n_idx * product(Y) + row)
615  const index_t plane = c_warp_y_lengths.product();
616 
617  // local lambda to fuse scale (if present) and convert
618  static_for<0, kM2, 1>{}([&](auto m_lane) {
619  const int src = n_idx * plane + m_lane; // source row in this N-plane
620  const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
621  AccDataType v = shuffle_acc.get_thread_buffer()[src];
622 
623  if constexpr(has_scalar_scales)
624  {
625  v = static_cast<AccDataType>(v * scale_m * scale_n);
626  }
627  else if constexpr(has_scales && !has_scalar_scales)
628  {
629  const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
630  const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
631  v = static_cast<AccDataType>(v * sm * sn);
632  }
633 
634  c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
635  });
636  });
637 
638  // store/update
639  if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
640  memory_operation_enum::set)
641  {
642  store_tile(out_dram_window, c_out_tensor);
643  }
644  else
645  {
646  update_tile(out_dram_window, c_out_tensor);
647  }
648 
649  // advance output (and any D-tensors) by one MPerXdl*MWave chunk
650  move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
651  static_for<0, NumDTensor, 1>{}([&](auto idx) {
652  move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
653  });
654  });
655  }
656 
657  template <typename ODramWindow,
658  typename OAccTile,
659  typename DsDramWindows,
660  typename ScaleM = EmptyScale,
661  typename ScaleN = EmptyScale,
662  int EnablePermuateN_ = TiledMMAPermuteN,
663  std::enable_if_t<!EnablePermuateN_, int> = 0>
664  CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
665  const OAccTile& o_acc_tile,
666  const DsDramWindows& ds_dram_windows,
667  void* p_smem,
668  const ScaleM& scale_m = {},
669  const ScaleN& scale_n = {})
670  {
671  constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
672 
673  auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
674 
675  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
676  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
677  static_cast<ODataType*>(p_smem), lds_block_desc);
678 
679  auto in_lds_window = make_tile_window(
680  o_lds_block,
681  make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
682  {0, 0},
683  LdsTileDistr);
684 
685  auto out_lds_window = make_tile_window(
686  o_lds_block,
687  make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
688  {0, 0});
689 
690  constexpr index_t num_access = SFC::get_num_of_access();
691 
692  static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
693  "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
694 
695  using TileEncodingPattern =
696  tile_distribution_encoding_pattern_2d<kBlockSize,
699  GetVectorSizeC(),
701  Problem::kNumWaveGroups>;
702  constexpr auto dram_tile_distribution =
703  TileEncodingPattern::make_2d_static_tile_distribution();
704 
705  auto d_dram_windows = generate_tuple(
706  [&](auto idx) {
707  return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
708  },
709  number<NumDTensor>{});
710 
711  constexpr bool has_scales =
712  !std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
713  constexpr bool has_scalar_scales =
714  std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
715  auto scale_m_window = [&]() {
716  if constexpr(has_scalar_scales)
717  {
718  return scale_m;
719  }
720  else if constexpr(has_scales)
721  {
722  return make_tile_window(scale_m, lds_tile.get_tile_distribution());
723  }
724  else
725  {
726  return EmptyScale{};
727  }
728  }();
729  auto scale_n_window = [&]() {
730  if constexpr(has_scalar_scales)
731  {
732  return scale_n;
733  }
734  else if constexpr(has_scales)
735  {
736  return make_tile_window(scale_n, lds_tile.get_tile_distribution());
737  }
738  else
739  {
740  return EmptyScale{};
741  }
742  }();
743 
744  static_for<0, num_access, 1>{}([&](auto iAccess) {
745  block_sync_lds();
746  slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
747 
748  if constexpr(has_scales)
749  {
750  scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
751  }
752 
753  cast_lds_tile(lds_tile, in_lds_window);
754  block_sync_lds();
755 
756  auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
757 
758  apply_d_tensors(d_dram_windows, c_out_tensor);
759  store_to_dram(out_dram_window, c_out_tensor);
760  move_windows<iAccess>(out_dram_window, d_dram_windows);
761  });
762  }
763 };
764 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::warp_gemm_dispatcher::Dispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:178
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition: shuffle_tile.hpp:154
@ thread_raked
Thread raked pattern.
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:837
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:207
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:158
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1037
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
Definition: cshuffle_epilogue.hpp:500
typename T::DataType DataType
Definition: cshuffle_epilogue.hpp:512
Definition: cshuffle_epilogue.hpp:505
float DataType
Definition: cshuffle_epilogue.hpp:506
Definition: cshuffle_epilogue.hpp:72
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:108
CK_TILE_DEVICE void scale_tile(LdsTile &lds_tile, ScaleM &scale_m_window, ScaleN &scale_n_window)
Definition: cshuffle_epilogue.hpp:381
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: cshuffle_epilogue.hpp:86
static constexpr index_t NRepeat
Definition: cshuffle_epilogue.hpp:126
CK_TILE_DEVICE void slice_acc_tile(const OAccTile &o_acc_tile, LdsTile &lds_tile)
Definition: cshuffle_epilogue.hpp:422
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()
Definition: cshuffle_epilogue.hpp:297
static constexpr index_t MRepeat
Definition: cshuffle_epilogue.hpp:125
typename WG::CWarpTensor CWarpTensor
Definition: cshuffle_epilogue.hpp:290
typename WG::CWarpDstrEncoding CWarpDstrEncoding
Definition: cshuffle_epilogue.hpp:291
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition: cshuffle_epilogue.hpp:74
remove_cvref_t< Problem_ > Problem
Definition: cshuffle_epilogue.hpp:73
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:113
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:117
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >||std::is_same_v< BDataType, pk_fp4_t >||std::is_same_v< BDataType, pk_fp4_raw_t >, ADataType, BDataType > BTypeToUse
Definition: cshuffle_epilogue.hpp:104
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeD(number< I > index)
Get the vector store size for Di tensor.
Definition: cshuffle_epilogue.hpp:185
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: cshuffle_epilogue.hpp:77
CK_TILE_DEVICE void store_to_dram(OutDramWindow &out_dram_window, const COutTensor &c_out_tensor)
Definition: cshuffle_epilogue.hpp:463
static constexpr bool ADataTypeIsTuple
Definition: cshuffle_epilogue.hpp:81
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:110
static constexpr index_t BlockedXDLN_PerWarp
Definition: cshuffle_epilogue.hpp:119
static constexpr CK_TILE_HOST_DEVICE auto AlignShuffleTileWithSmem()
Shuffle tile configuration parameters check and alignment.
Definition: cshuffle_epilogue.hpp:213
remove_cvref_t< typename Problem::ELayout > ELayout
Definition: cshuffle_epilogue.hpp:106
static constexpr bool TiledMMAPermuteN
Definition: cshuffle_epilogue.hpp:118
static constexpr bool BDataTypeIsTuple
Definition: cshuffle_epilogue.hpp:82
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: cshuffle_epilogue.hpp:93
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: cshuffle_epilogue.hpp:79
static constexpr CK_TILE_DEVICE auto MakeLdsDistributionEncode()
Definition: cshuffle_epilogue.hpp:319
static constexpr index_t MPerIteration
Definition: cshuffle_epilogue.hpp:122
static constexpr auto MNPerIterationShuffle
Definition: cshuffle_epilogue.hpp:270
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:116
CDElementwise elfunc_
Definition: cshuffle_epilogue.hpp:128
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: cshuffle_epilogue.hpp:373
CK_TILE_DEVICE void apply_d_tensors(DramWindows &d_dram_windows, COutTensor &c_out_tensor)
Definition: cshuffle_epilogue.hpp:449
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:111
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue.hpp:156
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:121
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: cshuffle_epilogue.hpp:78
CK_TILE_DEVICE void move_windows(OutDramWindow &out_dram_window, DDramWindows &d_dram_windows)
Move both the output and D tensors windows for the next access.
Definition: cshuffle_epilogue.hpp:481
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: cshuffle_epilogue.hpp:107
static CK_TILE_HOST const std::string GetName()
Definition: cshuffle_epilogue.hpp:135
static constexpr index_t NPerIterationShuffle
Definition: cshuffle_epilogue.hpp:279
std::conditional_t< std::is_same_v< ADataType, pk_int4_t >||std::is_same_v< ADataType, pk_fp4_t >, BDataType, ADataType > ATypeToUse
Definition: cshuffle_epilogue.hpp:98
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: cshuffle_epilogue.hpp:76
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition: cshuffle_epilogue.hpp:522
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:124
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:115
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition: cshuffle_epilogue.hpp:664
static constexpr bool DoubleSmemBuffer
Definition: cshuffle_epilogue.hpp:120
static constexpr index_t NumMXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:266
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: cshuffle_epilogue.hpp:90
static constexpr index_t NumNXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:267
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition: cshuffle_epilogue.hpp:75
WarpGemmDispatcher< ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed > WG
Definition: cshuffle_epilogue.hpp:287
static constexpr index_t MPerIterationShuffle
Definition: cshuffle_epilogue.hpp:278
CK_TILE_DEVICE void cast_lds_tile(LdsTile &lds_tile, InLdsWindow &in_lds_window)
Definition: cshuffle_epilogue.hpp:441
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc=CDElementwise{})
Definition: cshuffle_epilogue.hpp:130
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition: cshuffle_epilogue.hpp:234
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:112
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:109
static constexpr index_t NPerIteration
Definition: cshuffle_epilogue.hpp:123
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:114
typename WG::CWarpDstr CWarpDstr
Definition: cshuffle_epilogue.hpp:289
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: cshuffle_epilogue.hpp:92
Definition: cshuffle_epilogue.hpp:40
remove_cvref_t< BsDataType_ > BsDataType
Definition: cshuffle_epilogue.hpp:42
remove_cvref_t< ODataType_ > ODataType
Definition: cshuffle_epilogue.hpp:44
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:57
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:52
static constexpr index_t BlockedXDLN_PerWarp
Definition: cshuffle_epilogue.hpp:60
static constexpr bool DoubleSmemBuffer
Definition: cshuffle_epilogue.hpp:61
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:59
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:53
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:56
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:49
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:58
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:50
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:54
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: cshuffle_epilogue.hpp:48
remove_cvref_t< AsDataType_ > AsDataType
Definition: cshuffle_epilogue.hpp:41
static constexpr bool TiledMMAPermuteN
Definition: cshuffle_epilogue.hpp:62
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:55
remove_cvref_t< DsLayout_ > DsLayout
Definition: cshuffle_epilogue.hpp:46
remove_cvref_t< DsDataType_ > DsDataType
Definition: cshuffle_epilogue.hpp:45
remove_cvref_t< ELayout_ > ELayout
Definition: cshuffle_epilogue.hpp:47
remove_cvref_t< AccDataType_ > AccDataType
Definition: cshuffle_epilogue.hpp:43
static constexpr index_t kNumWaveGroups
Definition: cshuffle_epilogue.hpp:63
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:51
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:64
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: unary_element_wise_operation.hpp:509
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
Definition: space_filling_curve.hpp:20
static constexpr CK_TILE_HOST_DEVICE auto get_forward_step(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto get_index(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:158
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: space_filling_curve.hpp:46
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192