/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File
gemm_quant_pipeline_problem.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #include <string>
11 
12 namespace ck_tile {
13 
14 template <typename ADataType_,
15  typename AQDataType_,
16  typename BDataType_,
17  typename BQDataType_,
18  typename CDataType_,
19  typename BlockGemmShape_,
20  typename Traits_,
21  typename AQuantGroupSize_,
22  typename BQuantGroupSize_,
23  bool TransposeC_,
24  typename ComputeDataType_ = BDataType_,
26  bool HasHotLoop_ = true,
27  TailNumber TailNum_ = TailNumber::Full>
29  BDataType_,
30  CDataType_,
31  BlockGemmShape_,
32  Traits_,
33  ComputeDataType_>
34 {
35  using Base = GemmPipelineProblemBase<ADataType_,
36  BDataType_,
37  CDataType_,
38  BlockGemmShape_,
39  Traits_,
40  ComputeDataType_>;
41 
42  using Traits = typename Base::Traits;
43 
44  using typename Base::ADataType;
45  using typename Base::BDataType;
46  using typename Base::CDataType;
47  using typename Base::ComputeDataType;
50 
53  std::conditional_t<!std::is_void_v<AQuantGroupSize_>, AQuantGroupSize_, BQuantGroupSize_>;
55  std::conditional_t<!std::is_void_v<BQuantGroupSize_>, BQuantGroupSize_, AQuantGroupSize_>;
56  // Unified alias for 1D quantization usage, so that users are not forced to choose between AQuantGroupSize and BQuantGroupSize.
58 
59  using typename Base::ALayout;
60  using typename Base::BLayout;
61  using typename Base::CLayout;
62 
63  static constexpr bool TransposeC = TransposeC_;
64  static constexpr bool PreshuffleB = Traits::PreshuffleB;
65  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
66  using Base::kBlockSize;
67 
68  using Base::kPadK;
69  using Base::kPadM;
70  using Base::kPadN;
71 
73 
76 
77  static constexpr auto Scheduler = Scheduler_;
78  static constexpr auto HasHotLoop = HasHotLoop_;
79  static constexpr auto TailNum = TailNum_;
80 
81  static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
82  static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
83  static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
84  static_assert(BlockGemmShape::kK % BQuantGroupSize::kK == 0);
85 
86  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
87  {
88  // clang-format off
89  return concat('_', "gemm_quant_problem",
91  concat('x', kPadM, kPadN, kPadK),
92  Scheduler,
93  AQuantGroupSize::GetName(),
94  BQuantGroupSize::GetName());
95  // clang-format on
96  }
97 
98  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
99  {
100  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
101  return VectorLoadSize / sizeof(AQDataType);
102  }
103 
104  static constexpr index_t VectorSizeAQ = []() {
105  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
106  return kPadK ? 1 : GetAlignmentAQ();
107  }();
108 
109  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
110  {
111  return VectorLoadSize / sizeof(BQDataType);
112  }
113 
114  static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
115 };
116 
117 template <typename ADataType_,
118  typename AQDataType_,
119  typename BDataType_,
120  typename CDataType_,
121  typename BlockGemmShape_,
122  typename Traits_,
123  typename QuantGroupSize_,
124  bool TransposeC_,
125  typename ComputeDataType_ = BDataType_,
127  bool HasHotLoop_ = true,
128  TailNumber TailNum_ = TailNumber::Full>
130  AQDataType_,
131  BDataType_,
132  void, // no BQDataType for AQuant
133  CDataType_,
134  BlockGemmShape_,
135  Traits_,
136  QuantGroupSize_,
137  void,
138  TransposeC_,
139  ComputeDataType_,
140  Scheduler_,
141  HasHotLoop_,
142  TailNum_>;
143 
144 template <typename ADataType_,
145  typename BDataType_,
146  typename BQDataType_,
147  typename CDataType_,
148  typename BlockGemmShape_,
149  typename Traits_,
150  typename QuantGroupSize_,
151  typename ComputeDataType_ = ADataType_,
153  bool HasHotLoop_ = true,
154  TailNumber TailNum_ = TailNumber::Full>
156  void, // no AQDataType for BQuant
157  BDataType_,
158  BQDataType_,
159  CDataType_,
160  BlockGemmShape_,
161  Traits_,
162  void,
163  QuantGroupSize_,
164  false, // no TransposeC
165  ComputeDataType_,
166  Scheduler_,
167  HasHotLoop_,
168  TailNum_>;
169 
170 template <typename ADataType_,
171  typename AQDataType_,
172  typename BDataType_,
173  typename BQDataType_,
174  typename CDataType_,
175  typename BlockGemmShape_,
176  typename Traits_,
177  typename AQuantGroupSize_,
178  typename BQuantGroupSize_,
179  bool TransposeC_,
180  typename ComputeDataType_ = ADataType_,
182  bool HasHotLoop_ = true,
183  TailNumber TailNum_ = TailNumber::Full>
185  AQDataType_,
186  BDataType_,
187  BQDataType_,
188  CDataType_,
189  BlockGemmShape_,
190  Traits_,
191  AQuantGroupSize_,
192  BQuantGroupSize_,
193  TransposeC_,
194  ComputeDataType_,
195  Scheduler_,
196  HasHotLoop_,
197  TailNum_>;
198 
199 template <typename ADataType_,
200  typename BDataType_,
201  typename CDataType_,
202  typename AccDataType_,
203  typename BlockGemmShape_,
204  typename Traits_,
205  bool TransposeC_ = false,
206  typename ComputeDataType_ = BDataType_,
208  bool HasHotLoop_ = true,
209  TailNumber TailNum_ = TailNumber::Full>
211  GemmQuantPipelineProblemBase<ADataType_,
212  AccDataType_,
213  BDataType_,
214  AccDataType_,
215  CDataType_,
216  BlockGemmShape_,
217  Traits_,
218  void,
219  QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
220  TransposeC_,
221  ComputeDataType_,
222  Scheduler_,
223  HasHotLoop_,
224  TailNum_>;
225 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
Definition: gemm_pipeline_problem.hpp:25
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:34
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:69
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:67
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:41
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:30
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:26
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:70
Definition: gemm_quant_pipeline_problem.hpp:34
static constexpr auto HasHotLoop
Definition: gemm_quant_pipeline_problem.hpp:78
typename Base::BlockGemmShape BlockGemmShape
Definition: gemm_quant_pipeline_problem.hpp:51
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
remove_cvref_t< BQDataType_ > BQDataType
Definition: gemm_quant_pipeline_problem.hpp:49
static CK_TILE_HOST const std::string GetName()
Definition: gemm_quant_pipeline_problem.hpp:86
static constexpr bool DoubleSmemBuffer
Definition: gemm_quant_pipeline_problem.hpp:65
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentBQ()
Definition: gemm_quant_pipeline_problem.hpp:109
remove_cvref_t< typename Traits::BQLayout > BQLayout
Definition: gemm_quant_pipeline_problem.hpp:75
remove_cvref_t< typename Traits::AQLayout > AQLayout
Definition: gemm_quant_pipeline_problem.hpp:74
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
static constexpr index_t VectorSizeBQ
Definition: gemm_quant_pipeline_problem.hpp:114
static constexpr auto Scheduler
Definition: gemm_quant_pipeline_problem.hpp:77
BQuantGroupSize QuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:57
static constexpr index_t VectorSizeAQ
Definition: gemm_quant_pipeline_problem.hpp:104
static constexpr bool PreshuffleB
Definition: gemm_quant_pipeline_problem.hpp:64
static constexpr bool TransposeC
Definition: gemm_quant_pipeline_problem.hpp:63
remove_cvref_t< AQDataType_ > AQDataType
Definition: gemm_quant_pipeline_problem.hpp:48
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentAQ()
Definition: gemm_quant_pipeline_problem.hpp:98
std::conditional_t<!std::is_void_v< AQuantGroupSize_ >, AQuantGroupSize_, BQuantGroupSize_ > AQuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:53
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
typename Base::Traits Traits
Definition: gemm_quant_pipeline_problem.hpp:42
static constexpr auto TailNum
Definition: gemm_quant_pipeline_problem.hpp:79
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
std::conditional_t<!std::is_void_v< BQuantGroupSize_ >, BQuantGroupSize_, AQuantGroupSize_ > BQuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:55
Definition: gemm_group_quant_utils.hpp:470