/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp Source File
gridwise_ab_transfer_wave_tiles.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
8 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
12 template <typename ABLayout,
13  typename ABMajorLayout,
14  typename LDSTypeAB,
15  index_t BlockSize,
16  index_t MNPerBlock,
17  index_t KPerBlock,
18  index_t MNPerWmma,
19  index_t KPack,
20  index_t ABK1Value,
21  index_t WaveSize>
23 {
// Reports whether this transfer strategy stages tiles through LDS
// (shared memory). The wave-tiles path always does.
__device__ static constexpr bool IsLDSNeeded()
{
    return true;
}
25 
// Compile-time guard: the wave-tile transfer path does not handle the
// packed-int4 element type.
26  static_assert(!(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>),
27  "wave tile transfer method does not support pk_i4_t");
// Commonly used compile-time index constants.
28  static constexpr auto I0 = Number<0>{};
29  static constexpr auto I1 = Number<1>{};
30  static constexpr auto I2 = Number<2>{};
31  static constexpr auto I3 = Number<3>{};
32 
// Number of subtile rows a wmma tile is split into; the descriptors
// below divide each tile into 2 subtiles (see "Divide each tile in
// 2 16x8 subtile" in MakeGridDescriptor).
33  static constexpr index_t MNKRow = 2;
34 
// NOTE(review): a declaration was lost in doc extraction here — the
// Doxygen member index shows a `ThisThreadBlock<BlockSize>` alias at
// original line 35; confirm against the repository source.
36 
37  // Tiles distribution for global memory loading
38  // Notes: support for not power of 2 needs to be reviewed later on
39  // The tiles are distributed along the non-contiguous matrix dimension
40  // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64
41  // MRepeat = 1, KRepeat = 4
42  // -------------
43  // |W0| | | |
44  // -------------
45  // |W1| | | |
46  // -------------
47  // |W2| | | |
48  // -------------
49  // |W3| | | |
50  // -------------
51  // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64
52  // MRepeat = 4, KRepeat = 1
53  // -------------
54  // |W0|W1|W2|W3|
55  // -------------
56  // | | | | |
57  // -------------
58  // | | | | |
59  // -------------
60  // | | | | |
61  // -------------
// Waves launched per block.
62  static constexpr index_t NumberOfWaves = BlockSize / WaveSize;
// Waves assigned along MN when MN is the major (tile-distribution)
// dimension: use min(tile count, wave count) when it divides the MN
// tile count evenly; otherwise fall back to 2 (if the tile count is
// even) or 1.
63  static constexpr index_t MNMajorWaves_ =
64  MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0
65  ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves)
66  : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1);
// Same rule applied along K (tile count = KPerBlock / KPack).
67  static constexpr index_t KMajorWaves_ =
68  KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0
69  ? std::min(KPerBlock / KPack, NumberOfWaves)
70  : (KPerBlock / KPack % 2 == 0 ? 2 : 1);
71 
// True when the tensor layout differs from the major layout, i.e. the
// transfer must transpose while loading.
72  static constexpr bool ABDoTranspose = !is_same_v<ABLayout, ABMajorLayout>;
73 
// NOTE(review): the initializer of MNWaves_ and the declaration of
// KWaves_ (original lines 75-76 per the Doxygen index) were lost in
// doc extraction — presumably they select between the *MajorWaves_
// values based on ABDoTranspose; confirm against the repository source.
74  static constexpr index_t MNWaves_ =
// Per-wave repeat counts covering the whole block tile.
77  static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
78  static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
79 
79 
// Applies right-pad transforms to a 2-D (MN x K) base descriptor so the
// selected dimensions reach MNPad / KPad; the PadMN/PadK flags choose
// which dimensions are padded, and with both false the descriptor is
// returned unchanged. The two trailing unnamed index_t parameters are
// accepted but unused.
// NOTE(review): the transform_tensor_descriptor call lines and their
// dimension-mapping sequences inside the padded branches were lost in
// doc extraction — consult the repository source before editing.
80  template <bool PadMN, bool PadK, typename GridDescriptorBase>
81  __host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc,
82  index_t sizeMN,
83  index_t MNPad,
84  index_t sizeK,
85  index_t KPad,
86  index_t,
87  index_t)
88  {
89  if constexpr(PadMN && PadK)
90  {
91  // pad both MN and K
93  base_desc,
94  make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
95  make_right_pad_transform(sizeK, KPad - sizeK)),
98  }
99  else if constexpr(PadMN && !PadK)
100  {
101  // pad MN, but not K
103  base_desc,
104  make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN),
108  }
109  else if constexpr(!PadMN && PadK)
110  {
111  // pad K, but not MN
113  base_desc,
115  make_right_pad_transform(sizeK, KPad - sizeK)),
118  }
119  else
120  {
121  // not pad MN or K
122  return base_desc;
123  }
124  }
125 
// Builds the global-memory descriptor used by the wave-tile loads:
// pads the base MN x K descriptor (padding is rejected at compile time
// when a transpose is required), splits it into wmma tiles, and then
// splits each tile into the MNTiles-KTiles-MNKRow-LaneLocal-VectorSize
// layout. Separate non-transpose / transpose variants keep the global
// indices identical for both layouts.
// NOTE(review): many interior lines (the transform tuples and Sequence
// mappings) were lost in doc extraction — consult the repository source
// before editing this function.
126  template <bool PadMN, bool PadK, typename GridDescriptorBase>
127  __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
128  index_t sizeMN,
129  index_t MNPad,
130  index_t sizeK,
131  index_t KPad,
132  index_t,
133  index_t)
134  {
135  // Notes: padding is currently not supported with transpose
136  static_assert(!((PadMN || PadK) && ABDoTranspose),
137  "padding is currently not supported with transpose");
138 
// Effective grid extents after optional padding.
139  const index_t MN_grid = !PadMN ? sizeMN : MNPad;
140  const index_t K_grid = !PadK ? sizeK : KPad;
141 
142  const auto base_desc_padded =
143  PadGridDescriptor<PadMN, PadK>(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0);
144 
145  // Divide the base descriptor MN_K into tiles
146  const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
147  base_desc_padded,
148  make_tuple(
155 
156  // The distinction is needed to get the same global indices for both layouts
157  // Divide each tile in 2 16x8 subtile
158  // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
159  // MNKRow = 0-1
160  // LaneLocal = 0-15
161  // VectorSize must be 8
162  if constexpr(!ABDoTranspose)
163  {
// Non-transpose: K dimension of each tile is unmerged into
// (MNKRow, KPack / MNKRow).
164  const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
166  ab_grid_desc_mntiles_ktiles,
173  make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
176 
177  // Freeze VectorSize to first element of the loading chunk (for convenience)
178  // Swap MNPerWmma and MNKRow for consistency with transpose descriptor
180  ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
181  make_tuple(
188  make_tuple(
190  make_tuple(
192  }
193  else
194  {
// Transpose: MN dimension of each tile is unmerged into
// (MNKRow, MNPerWmma / MNKRow) instead.
195  const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
197  ab_grid_desc_mntiles_ktiles,
203  make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
207 
208  // Freeze VectorSize to first element of the loading chunk (for convenience)
210  ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
211  make_tuple(
218  make_tuple(
220  make_tuple(
222  }
223  }
224 
// Builds the LDS-side block descriptor matching the grid descriptor's
// MNTiles-KTiles-MNKRow-LaneLocal-VectorSize layout: lanes within a
// tile are stored contiguously in chunks of 8 elements, tiles stored
// K-dimension first.
// NOTE(review): the naive-descriptor lengths/strides and the final
// freeze transform (original lines 232-241 and 246-254) were lost in
// doc extraction — consult the repository source before editing.
225  __device__ static constexpr auto GetBlockDescriptor()
226  {
227  // LDS memory layouts:
228  // lanes within tiles stored contiguously in chunks of 8 elements
229  // tiles are then stored first in K dimension
230  // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
231  const auto a_grid_desc_mraw_kraw = [&]() {
235  Number<MNKRow>{},
242  I1));
243  }();
244 
245  // Freeze VectorSize to first element of the chunk (for convenience)
247  a_grid_desc_mraw_kraw,
255  }
256 
// Maps the flat thread id within the block to a (wave-MN, wave-K) index
// pair via a single-stage tensor adaptor.
// NOTE(review): the merge-transform tuple defining the adaptor (original
// lines 262-264) was lost in doc extraction — presumably it merges the
// MN/K wave counts declared above; confirm against the repository source.
257  __device__ static auto GetWaveIdx()
258  {
259  const index_t thread_id = ThisThreadBlock::GetThreadId();
260 
261  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
265 
266  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
267  }
268 
// Maps the hardware lane id to a (subtile-row, lane-within-subtile)
// index pair for addressing the LDS-side block descriptor. The lanes
// per subtile depend on whether the transfer transposes (KPack lanes)
// or not (MNPerWmma lanes).
// NOTE(review): the Sequence dimension mappings of the adaptor (original
// lines 277-278) were lost in doc extraction — confirm against the
// repository source.
269  __device__ static auto GetBlockLaneIdx()
270  {
271  const index_t lane_id = __lane_id();
272 
273  constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma;
274 
275  constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor(
276  make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))),
279 
280  return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
281  }
282 
// Maps the hardware lane id to a (subtile-row, lane-column) index pair
// for addressing the grid descriptor. Subtile columns scale with the
// element size (4 / sizeof(ABDataType)), and the row/column decomposition
// order flips between the transpose and non-transpose layouts so both
// produce the same global indices.
// NOTE(review): SubTilesCol would be 0 for an 8-byte ABDataType
// (integer division), causing a divide-by-zero in LanesPerSubTile —
// presumably element types are <= 4 bytes here; confirm.
// NOTE(review): the adaptor construction (original lines 297-299) was
// lost in doc extraction — confirm against the repository source.
283  template <typename ABDataType>
284  __device__ static auto GetGridLaneIdx()
285  {
286  const index_t lane_id = __lane_id();
287 
288  constexpr index_t SubTilesRow = MNKRow;
289  constexpr index_t SubTilesCol = 4 / sizeof(ABDataType);
290  constexpr index_t LanesPerSubTile =
291  ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol;
// Decomposition order differs per layout so the final swizzle below
// can emit a consistent (row, column) pair.
292  constexpr auto dims_tuple = ABDoTranspose
293  ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile)
294  : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile);
295 
296  constexpr auto laneid_to_grid_lane_idx_adaptor =
300 
301  const auto indices =
302  laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
303 
304  if constexpr(!ABDoTranspose)
305  {
306  return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]);
307  }
308  else
309  {
310  return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]);
311  }
312  }
313 
// Constructs the global->LDS transfer object for this block: derives the
// wave index, the grid-side and block-side lane indices, and hands the
// per-wave starting multi-indices to ThreadGroupTransferGlobal.
// Currently restricted to a single global buffer and a single AB tensor
// (both enforced by static_asserts below).
// NOTE(review): original line 332 (presumably the ABDataType alias drawn
// from ABsDataType) and template arguments at lines 351-353 were lost in
// doc extraction — confirm against the repository source.
314  template <typename GridDescriptor,
315  typename BlockDescriptor,
316  typename ABsDataType,
317  typename ABElementwiseOperation,
318  index_t GlobalBufferNum>
319  __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
320  BlockDescriptor& block_descriptor,
321  ABElementwiseOperation& ab_element_op,
322  const index_t block_mn_id,
323  const index_t)
324  {
325  // Note: GlobalBufferNum is currently not used but it will be needed
326  // once we add other pipelines. It is currently needed only for
327  // consistency with the thread tiles approach
328  static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
329  constexpr index_t NumABTensor = ABsDataType::Size();
330  static_assert(NumABTensor == 1, "multiAB currently not supported");
331 
333 
// Position of this wave inside the block's wave grid.
334  const auto wave_idx = GetWaveIdx();
335  index_t wave_idK = wave_idx[I1];
336  index_t wave_idMN = wave_idx[I0];
337 
// Lane coordinates for the global (grid) descriptor.
338  const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
339  index_t lane_group_grid = grid_lane_id[I0];
340  index_t lane_local_id_grid = grid_lane_id[I1];
341 
// Lane coordinates for the LDS (block) descriptor.
342  const auto block_lane_id = GetBlockLaneIdx();
343  index_t lane_group_block = block_lane_id[I0];
344  index_t lane_local_id_block = block_lane_id[I1];
345 
346  return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
347  BlockDescriptor,
348  ABDataType,
349  ABDataType,
350  ABElementwiseOperation,
354  ABK1Value,
355  ABDoTranspose>(
356  grid_descriptor[I0],
357  block_descriptor,
// Grid origin: offset the MN-tile coordinate by this block's id.
358  make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
359  wave_idK,
360  lane_group_grid,
361  lane_local_id_grid),
362  make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
363  ab_element_op);
364  }
365 
// Block descriptor used by the pipelines to read LDS into registers,
// shaped to stay compatible with the existing (thread-tiles) pipeline
// implementation.
// NOTE(review): most of the naive-descriptor lengths/strides (original
// lines 372-384) were lost in doc extraction — consult the repository
// source before editing.
366  template <index_t MNRepeat, index_t MNWaves>
367  __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
368  {
369  // This is a block descriptor used to read LDS memory into register
370  // It's defined in a way consistent with the existing implementation to
371  // avoid changes in the pipelines
374  Number<KPerBlock / KPack>{},
375  Number<MNWaves>{},
376  Number<MNKRow>{},
379  make_tuple(I0,
385  I1));
386  }
387 
// Slice-window step applied to the grid descriptor when advancing to
// the next K chunk (consumed by MoveSrcSliceWindow). Only the K-tiles
// dimension moves; MN and lane coordinates stay fixed.
__device__ static constexpr auto GetBlockStep()
{
    // One step covers the full K extent handled per block iteration.
    constexpr index_t k_tile_step = KWaves_ * KRepeat_;
    return make_multi_index(I0, k_tile_step, I0, I0);
}
393 
// Total K extent represented by a grid descriptor: the K-tiles length
// (dimension 1) multiplied by the elements per tile (KPack).
template <typename GridDescriptor>
__device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
{
    const index_t num_k_tiles = grid_desc.GetLength(I1);
    return num_k_tiles * KPack;
}
399 
// Wraps the raw LDS pointer in a dynamic buffer view of `size` elements.
template <typename LDSType, typename IndexType>
__device__ static auto GetBuffer(LDSType* p_shared_AB, const IndexType& size)
{
    auto lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(p_shared_AB, size);
    return lds_buffer;
}
405 };
406 
407 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition: matrix_padder.hpp:199
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: gridwise_ab_transfer_wave_tiles.hpp:23
static __device__ auto GetWaveIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:257
static constexpr __device__ bool IsLDSNeeded()
Definition: gridwise_ab_transfer_wave_tiles.hpp:24
static constexpr index_t MNRepeat_
Definition: gridwise_ab_transfer_wave_tiles.hpp:78
static __device__ auto GetBuffer(LDSType *p_shared_AB, const IndexType &size)
Definition: gridwise_ab_transfer_wave_tiles.hpp:401
static __device__ auto GetGridLaneIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:284
static constexpr __device__ auto GetBlockDescriptor()
Definition: gridwise_ab_transfer_wave_tiles.hpp:225
static __device__ auto GetBlockTransfer(GridDescriptor &grid_descriptor, BlockDescriptor &block_descriptor, ABElementwiseOperation &ab_element_op, const index_t block_mn_id, const index_t)
Definition: gridwise_ab_transfer_wave_tiles.hpp:319
static constexpr auto I2
Definition: gridwise_ab_transfer_wave_tiles.hpp:30
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_ab_transfer_wave_tiles.hpp:35
static constexpr index_t KWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:76
__host__ static constexpr __device__ auto MakeWmmaTileDescriptor()
Definition: gridwise_ab_transfer_wave_tiles.hpp:367
static constexpr __device__ index_t GetKDimension(const GridDescriptor &grid_desc)
Definition: gridwise_ab_transfer_wave_tiles.hpp:395
static constexpr index_t KMajorWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:67
static constexpr index_t MNMajorWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:63
static constexpr auto I1
Definition: gridwise_ab_transfer_wave_tiles.hpp:29
static constexpr auto I3
Definition: gridwise_ab_transfer_wave_tiles.hpp:31
static constexpr index_t MNKRow
Definition: gridwise_ab_transfer_wave_tiles.hpp:33
static constexpr auto I0
Definition: gridwise_ab_transfer_wave_tiles.hpp:28
static constexpr bool ABDoTranspose
Definition: gridwise_ab_transfer_wave_tiles.hpp:72
__host__ static __device__ auto MakeGridDescriptor(GridDescriptorBase &base_desc, index_t sizeMN, index_t MNPad, index_t sizeK, index_t KPad, index_t, index_t)
Definition: gridwise_ab_transfer_wave_tiles.hpp:127
static constexpr index_t MNWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:74
__host__ static __device__ auto PadGridDescriptor(GridDescriptorBase &base_desc, index_t sizeMN, index_t MNPad, index_t sizeK, index_t KPad, index_t, index_t)
Definition: gridwise_ab_transfer_wave_tiles.hpp:81
static constexpr __device__ auto GetBlockStep()
Definition: gridwise_ab_transfer_wave_tiles.hpp:388
static constexpr index_t KRepeat_
Definition: gridwise_ab_transfer_wave_tiles.hpp:77
static constexpr index_t NumberOfWaves
Definition: gridwise_ab_transfer_wave_tiles.hpp:62
static __device__ auto GetBlockLaneIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:269
Definition: sequence.hpp:43
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
Definition: thread_group_tensor_slice_transfer_global.hpp:26
Definition: integral_constant.hpp:20
Definition: data_type.hpp:187