/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/device/matrix_padder.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/device/matrix_padder.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/device/matrix_padder.hpp Source File
matrix_padder.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
// Builds a padded view of `desc`: every dimension whose flag in `DoPads` is true is right-padded
// up to the next multiple of the matching entry of `tile_lengths`; every other dimension is
// passed through unchanged.
//
// desc         - tensor descriptor to pad.
// tile_lengths - per-dimension tile length, read as tile_lengths[idim].
// DoPads       - compile-time per-dimension flags, read as DoPads::At(idim).
// The static_assert requires DoPads, TileLengths and TensorDesc to agree on the dimension count.
// Returns the result of transform_tensor_descriptor; dimension ids are kept (upper == lower).
15 template <typename TensorDesc,
16  typename TileLengths, // Tuple<...>
17  typename DoPads> // Sequence<bool, bool, ...>
18 __host__ __device__ constexpr auto
19 PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
20 {
21  constexpr index_t num_dim = DoPads::Size();
22 
23  static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
24  "wrong! inconsistent # of dimensions");
25 
26  // transforms: one per dimension, either a right-pad or a pass-through
27  const auto transforms = generate_tuple(
28  [&](auto idim) {
29  const auto MRaw = desc.GetLength(idim);
30 
31  const auto MPerTile = tile_lengths[idim];
32 
// round the raw length up to a whole number of tiles
33  const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
34 
// number of elements appended on the right to reach M
35  const auto MPad = M - MRaw;
36 
37  const bool DoPadM = DoPads::At(idim);
38 
// compile-time selection between padding this dimension and leaving it as-is
39  const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
40  make_pass_through_transform(MRaw));
41 
42  return MTransform;
43  },
44  Number<num_dim>{});
45 
46  // lower dimension Id
47  const auto lower_dimss =
48  generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
49 
50  // upper dimension Id
51  const auto upper_dimss = lower_dimss;
52 
53  return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
54 }
55 
56 // M/N/K/OPerTileType could be index_t or Number<>
// Padder holding four tile lengths (M, N, K, O). Each Pad*Descriptor_* method forwards a
// 2-D descriptor to PadTensorDescriptor together with the two tile lengths and the two
// compile-time pad flags that apply to its dimensions.
// NOTE(review): listing line 62 (the struct declaration) is absent from this listing, as are
// the initializers of PadM/PadN/PadK/PadO (listing lines 66-69, 71-74, 76-79, 81-84), so the
// struct name and the exact GemmSpec -> flag mapping cannot be confirmed from here.
57 template <GemmSpecialization GemmSpec,
58  typename MPerTileType,
59  typename NPerTileType,
60  typename KPerTileType,
61  typename OPerTileType>
63 {
64  // TODO: hard to scale; use mask instead
65  static constexpr bool PadM =
70  static constexpr bool PadN =
75  static constexpr bool PadK =
80  static constexpr bool PadO =
85 
86  // A[M, K]
// dim 0 uses MPerTile_/PadM, dim 1 uses KPerTile_/PadK
87  template <typename ADesc_MRaw_KRaw>
88  __host__ __device__ constexpr auto
89  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
90  {
91  return PadTensorDescriptor(
92  a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
93  }
94 
95  // B[K, N]
// NOTE(review): the descriptor is padded in (N, K) order: dim 0 uses NPerTile_/PadN and
// dim 1 uses KPerTile_/PadK, matching the BDesc_NRaw_KRaw parameter name rather than the
// "B[K, N]" wording above -- confirm which layout the comment intends.
96  template <typename BDesc_NRaw_KRaw>
97  __host__ __device__ constexpr auto
98  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
99  {
100  return PadTensorDescriptor(
101  b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
102  }
103 
104  // D0[M, N]
// dim 0 uses MPerTile_/PadM, dim 1 uses NPerTile_/PadN
// NOTE(review): the method name ends in _N_K although the dimensions handled are (M, N).
105  template <typename D0Desc_MRaw_NRaw>
106  __host__ __device__ constexpr auto
107  PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const
108  {
109  return PadTensorDescriptor(
110  d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
111  }
112 
113  // B1[Gemm1N, Gemm1K] = B1[O, N]
// dim 0 uses OPerTile_/PadO, dim 1 uses NPerTile_/PadN
114  template <typename B1Desc_NRaw_KRaw>
115  __host__ __device__ constexpr auto
116  PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
117  {
118  return PadTensorDescriptor(
119  b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
120  }
121 
122  // C[M, Gemm1N] = C[M, O]
// dim 0 uses MPerTile_/PadM, dim 1 uses OPerTile_/PadO
123  template <typename CDesc_MRaw_NRaw>
124  __host__ __device__ constexpr auto
125  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
126  {
127  return PadTensorDescriptor(
128  c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
129  }
130 
// tile lengths used as the rounding granularity for each dimension
131  MPerTileType MPerTile_;
132  NPerTileType NPerTile_;
133  KPerTileType KPerTile_;
134  OPerTileType OPerTile_;
135 };
136 
137 // M/N/KPerTileType could be index_t or Number<>
// Padder holding three tile lengths (M, N, K); the pad flags are derived at compile time from
// GemmSpec. Each method forwards a 2-D descriptor to PadTensorDescriptor with the two tile
// lengths and two pad flags that apply to its dimensions.
// NOTE(review): listing line 142 (the struct declaration) is absent from this listing;
// presumably this is GemmPadder, since MatrixPadder below derives from
// GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> -- confirm against the header.
138 template <GemmSpecialization GemmSpec,
139  typename MPerTileType,
140  typename NPerTileType,
141  typename KPerTileType>
143 {
// NOTE(review): each initializer below continues on a listing line that is absent here
// (146, 149, 152), so only the first two alternatives of each condition are visible.
144  static constexpr bool PadM =
145  (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
147  static constexpr bool PadN =
148  (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
150  static constexpr bool PadK =
151  (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
153 
// A descriptor: dim 0 uses MPerTile_/PadM, dim 1 uses KPerTile_/PadK
154  template <typename ADesc_MRaw_KRaw>
155  __host__ __device__ constexpr auto
156  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
157  {
158  return PadTensorDescriptor(
159  a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
160  }
161 
// B descriptor: dim 0 uses NPerTile_/PadN, dim 1 uses KPerTile_/PadK
162  template <typename BDesc_NRaw_KRaw>
163  __host__ __device__ constexpr auto
164  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
165  {
166  return PadTensorDescriptor(
167  b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
168  }
169 
// C descriptor: dim 0 uses MPerTile_/PadM, dim 1 uses NPerTile_/PadN
170  template <typename CDesc_MRaw_NRaw>
171  __host__ __device__ constexpr auto
172  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
173  {
174  return PadTensorDescriptor(
175  c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
176  }
177 
// tile lengths used as the rounding granularity for each dimension
178  MPerTileType MPerTile_;
179  NPerTileType NPerTile_;
180  KPerTileType KPerTile_;
181 };
182 
183 // Alias of GemmPadder (adds no members of its own); slated for deprecation
184 template <GemmSpecialization GemmSpec,
185  typename MPerTileType,
186  typename NPerTileType,
187  typename KPerTileType>
188 struct MatrixPadder : GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
189 {
190 };
191 
192 // Takes a MatrixPadder and a raw C descriptor and returns the descriptor produced by
193 // MatrixPadder::PadCDescriptor_M_N; used to get the output descriptor at runtime for codegen.
// matrix_padder - padder whose tile lengths and pad flags are applied.
// conv_desc     - descriptor handed to PadCDescriptor_M_N.
194 template <GemmSpecialization GemmSpec,
195  typename MPerTileType,
196  typename NPerTileType,
197  typename KPerTileType,
198  typename CDesc_MRaw_NRaw>
199 auto grid_desc(MatrixPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> matrix_padder,
200  CDesc_MRaw_NRaw conv_desc)
201 {
202  auto res = matrix_padder.PadCDescriptor_M_N(conv_desc);
203  return res;
204 }
205 // M/N/KPerTileType could be index_t or Number<>
// Padder variant whose pad flags (PadM, PadN, PadK) are supplied directly as bool template
// parameters instead of being derived from a GemmSpecialization. Each method forwards a 2-D
// descriptor to PadTensorDescriptor with the matching tile lengths and flags.
// NOTE(review): listing line 212 (the struct declaration) is absent from this listing, so the
// struct name cannot be confirmed from here.
206 template <bool PadM,
207  bool PadN,
208  bool PadK,
209  typename MPerTileType,
210  typename NPerTileType,
211  typename KPerTileType>
213 {
// A descriptor: dim 0 uses MPerTile_/PadM, dim 1 uses KPerTile_/PadK
214  template <typename ADesc_MRaw_KRaw>
215  __host__ __device__ constexpr auto
216  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
217  {
218  return PadTensorDescriptor(
219  a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
220  }
221 
// B descriptor: dim 0 uses NPerTile_/PadN, dim 1 uses KPerTile_/PadK
222  template <typename BDesc_NRaw_KRaw>
223  __host__ __device__ constexpr auto
224  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
225  {
226  return PadTensorDescriptor(
227  b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
228  }
229 
// C descriptor: dim 0 uses MPerTile_/PadM, dim 1 uses NPerTile_/PadN
230  template <typename CDesc_MRaw_NRaw>
231  __host__ __device__ constexpr auto
232  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
233  {
234  return PadTensorDescriptor(
235  c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
236  }
237 
// tile lengths used as the rounding granularity for each dimension
238  MPerTileType MPerTile_;
239  NPerTileType NPerTile_;
240  KPerTileType KPerTile_;
241 };
242 
243 // M/N/KPerTileType could be index_t or Number<>
// Padder variant with bool pad flags as template parameters that spells out each of the four
// pad/no-pad combinations with `if constexpr` instead of delegating to PadTensorDescriptor.
// When neither flag of a method is set, the input descriptor is returned as-is.
// NOTE(review): listing line 250 (the struct declaration) is absent from this listing, and so
// are most argument lines of the transform_tensor_descriptor calls below (the gutter numbers
// jump inside every padded branch). The comments describe only what is visible.
244 template <bool PadM,
245  bool PadN,
246  bool PadK,
247  typename MPerTileType,
248  typename NPerTileType,
249  typename KPerTileType>
251 {
// compile-time dimension indices; only I0 and I1 are referenced in the visible code
252  static constexpr auto I0 = Number<0>{};
253  static constexpr auto I1 = Number<1>{};
254  static constexpr auto I2 = Number<2>{};
255  static constexpr auto I3 = Number<3>{};
256 
// A descriptor: dim 0 is M (tile MPerTile_), dim 1 is K (tile KPerTile_)
257  template <typename ADesc_MRaw_KRaw>
258  __host__ __device__ constexpr auto
259  PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
260  {
261  const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
262  const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
263 
// raw lengths rounded up to a whole number of tiles
264  const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
265  const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
266 
// amount of right padding needed to reach the rounded lengths
267  const auto MPad = M - MRaw;
268  const auto KPad = K - KRaw;
269 
270  if constexpr(PadM && PadK)
271  {
272  // pad both M and K
// NOTE(review): listing lines 274 and 276-277 are absent; only the K right-pad is visible
273  return transform_tensor_descriptor(a_desc_mraw_kraw,
275  make_right_pad_transform(KRaw, KPad)),
278  }
279  else if constexpr(PadM && (!PadK))
280  {
281  // pad M, but not K
// NOTE(review): listing lines 282 and 284-286 are absent; the call is not visible here
283  a_desc_mraw_kraw,
287  }
288  else if constexpr((!PadM) && PadK)
289  {
290  // pad K, but not M
// NOTE(review): listing lines 291 and 293-295 are absent; the call is not visible here
292  a_desc_mraw_kraw,
296  }
297  else
298  {
299  // not pad M or K
300  return a_desc_mraw_kraw;
301  }
302  }
303 
// B descriptor: dim 0 is N (tile NPerTile_), dim 1 is K (tile KPerTile_)
304  template <typename BDesc_NRaw_KRaw>
305  __host__ __device__ constexpr auto
306  PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
307  {
308  const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
309  const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
310 
// raw lengths rounded up to a whole number of tiles
311  const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
312  const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
313 
// amount of right padding needed to reach the rounded lengths
314  const auto NPad = N - NRaw;
315  const auto KPad = K - KRaw;
316 
317  if constexpr(PadN && PadK)
318  {
319  // pad both N and K
// NOTE(review): listing lines 321 and 323-324 are absent; only the K right-pad is visible
320  return transform_tensor_descriptor(b_desc_nraw_kraw,
322  make_right_pad_transform(KRaw, KPad)),
325  }
326  else if constexpr(PadN && (!PadK))
327  {
328  // pad N, but not K
// NOTE(review): listing lines 329 and 331-333 are absent; the call is not visible here
330  b_desc_nraw_kraw,
334  }
335  else if constexpr((!PadN) && PadK)
336  {
337  // pad K, but not N
// NOTE(review): listing lines 338 and 340-342 are absent; the call is not visible here
339  b_desc_nraw_kraw,
343  }
344  else
345  {
346  // not pad N or K
347  return b_desc_nraw_kraw;
348  }
349  }
350 
// C descriptor: dim 0 is M (tile MPerTile_), dim 1 is N (tile NPerTile_)
351  template <typename CDesc_MRaw_NRaw>
352  __host__ __device__ constexpr auto
353  PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
354  {
355  const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
356  const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
357 
// raw lengths rounded up to a whole number of tiles
358  const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
359  const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
360 
// amount of right padding needed to reach the rounded lengths
361  const auto MPad = M - MRaw;
362  const auto NPad = N - NRaw;
363 
364  if constexpr(PadM && PadN)
365  {
366  // pad M and N
// NOTE(review): listing lines 368 and 370-371 are absent; only the N right-pad is visible
367  return transform_tensor_descriptor(c_desc_mraw_nraw,
369  make_right_pad_transform(NRaw, NPad)),
372  }
373  else if constexpr(PadM && (!PadN))
374  {
375  // pad M, but not N
// NOTE(review): listing lines 376 and 378-380 are absent; the call is not visible here
377  c_desc_mraw_nraw,
381  }
382  else if constexpr((!PadM) && PadN)
383  {
384  // pad N, but not M
// NOTE(review): listing lines 385 and 387-389 are absent; the call is not visible here
386  c_desc_mraw_nraw,
390  }
391  else
392  {
393  // not pad M or N
394  return c_desc_mraw_nraw;
395  }
396  }
397 
// tile lengths used as the rounding granularity for each dimension
398  MPerTileType MPerTile_;
399  NPerTileType NPerTile_;
400  KPerTileType KPerTile_;
401 };
402 } // namespace device
403 } // namespace tensor_operation
404 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition: matrix_padder.hpp:199
__host__ constexpr __device__ auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition: matrix_padder.hpp:19
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:270
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
int32_t index_t
Definition: ck.hpp:301
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: matrix_padder.hpp:63
KPerTileType KPerTile_
Definition: matrix_padder.hpp:133
OPerTileType OPerTile_
Definition: matrix_padder.hpp:134
static constexpr bool PadM
Definition: matrix_padder.hpp:65
MPerTileType MPerTile_
Definition: matrix_padder.hpp:131
__host__ constexpr __device__ auto PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw &d0_desc_mraw_nraw) const
Definition: matrix_padder.hpp:107
static constexpr bool PadN
Definition: matrix_padder.hpp:70
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:125
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:89
NPerTileType NPerTile_
Definition: matrix_padder.hpp:132
static constexpr bool PadO
Definition: matrix_padder.hpp:80
__host__ constexpr __device__ auto PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw &b1_desc_nraw_kraw) const
Definition: matrix_padder.hpp:116
static constexpr bool PadK
Definition: matrix_padder.hpp:75
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:98
Definition: matrix_padder.hpp:213
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:216
MPerTileType MPerTile_
Definition: matrix_padder.hpp:238
NPerTileType NPerTile_
Definition: matrix_padder.hpp:239
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:224
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:232
KPerTileType KPerTile_
Definition: matrix_padder.hpp:240
Definition: matrix_padder.hpp:143
NPerTileType NPerTile_
Definition: matrix_padder.hpp:179
MPerTileType MPerTile_
Definition: matrix_padder.hpp:178
static constexpr bool PadK
Definition: matrix_padder.hpp:150
KPerTileType KPerTile_
Definition: matrix_padder.hpp:180
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:156
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:164
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:172
static constexpr bool PadM
Definition: matrix_padder.hpp:144
static constexpr bool PadN
Definition: matrix_padder.hpp:147
Definition: matrix_padder.hpp:251
KPerTileType KPerTile_
Definition: matrix_padder.hpp:400
__host__ constexpr __device__ auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition: matrix_padder.hpp:306
static constexpr auto I2
Definition: matrix_padder.hpp:254
MPerTileType MPerTile_
Definition: matrix_padder.hpp:398
__host__ constexpr __device__ auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition: matrix_padder.hpp:353
static constexpr auto I3
Definition: matrix_padder.hpp:255
static constexpr auto I0
Definition: matrix_padder.hpp:252
NPerTileType NPerTile_
Definition: matrix_padder.hpp:399
__host__ constexpr __device__ auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition: matrix_padder.hpp:259
static constexpr auto I1
Definition: matrix_padder.hpp:253
Definition: matrix_padder.hpp:189