/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp Source File
epilogue_direct_store.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 
8 namespace ck {
9 
// NOTE(review): this listing is a lossy Doxygen text extraction. Several original
// source lines are MISSING from the dump, including the struct declaration itself
// (original line 17 — the struct's name is not visible here) and parts of the
// make_single_stage_tensor_adaptor / transform_tensor_descriptor /
// ThreadwiseTensorSliceTransfer_v1r3 argument lists (original lines ~79-82, 89-92,
// 100-113, 120-121). Comments below describe only what the surviving lines show;
// reconstruct missing code from the repository before editing.
//
// Direct-store GEMM epilogue: the accumulator values held in a per-thread buffer
// are written straight to the E output tensor in global memory. IsLDSNeeded()
// returns false, so no LDS (shared-memory) staging buffer is requested by this
// epilogue. The Ds inputs appear unused here (DsGridPointer and the Ds descriptor
// parameters are anonymous) — presumably this variant supports no D-tensor fusion;
// TODO confirm against callers.
10 template <typename DsDataType,
11  typename EDataType,
12  typename AccDataType,
13  index_t MRepeat,
14  index_t NRepeat,
15  typename CDEElementwiseOperation,
16  typename BlockwiseGemmPipe>
// NOTE(review): original line 17 (the struct declaration, e.g. "struct ... ") is
// missing from this extraction; only the opening brace survives.
18 {
// Compile-time index constants used to address descriptor dimensions 0..6.
19  static constexpr auto I0 = Number<0>{};
20  static constexpr auto I1 = Number<1>{};
21  static constexpr auto I2 = Number<2>{};
22  static constexpr auto I3 = Number<3>{};
23  static constexpr auto I4 = Number<4>{};
24  static constexpr auto I5 = Number<5>{};
25  static constexpr auto I6 = Number<6>{};
26 
// This epilogue never needs an LDS staging buffer (it stores directly to global).
27  __device__ static constexpr bool IsLDSNeeded() { return false; }
28 
// Writes the per-thread accumulator buffer c_thread_buf to the E tensor in
// global memory for the block tile identified by (block_m_id, block_n_id),
// applying cde_element_op during the copy.
//
// Template params:
//   EGlobalMemoryDataOperation - how values are combined with existing global
//                                data (e.g. set/atomic add) by the transfer.
// Params:
//   c_thread_buf  - per-thread accumulator values (AccDataType) to store.
//   (DsGridPointer, void*, Ds descriptor) - accepted but unused here.
//   p_e_grid      - destination E tensor base pointer in global memory.
//   e_grid_desc_mblock_mperblock_nblock_nperblock - E tensor descriptor.
//   cde_element_op - elementwise op applied to each value before store.
//   block_m_id / block_n_id - this workgroup's block-tile coordinates.
29  template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
30  typename CThreadBuf,
31  typename DsGridPointer,
32  typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
33  typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
34  __device__ static void Run(CThreadBuf& c_thread_buf,
35  DsGridPointer,
36  EDataType* p_e_grid,
37  void*,
38  const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&,
39  const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
40  e_grid_desc_mblock_mperblock_nblock_nperblock,
41  CDEElementwiseOperation& cde_element_op,
42  const index_t& block_m_id,
43  const index_t& block_n_id)
44  {
// Wrap the raw E pointer as a dynamic global-memory buffer sized from the
// E grid descriptor.
45  auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
46  p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
47 
48  // C mapping in single thread.
49  constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
50  BlockwiseGemmPipe::
51  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
52 
53  // C mapping in single block
54  constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
55  BlockwiseGemmPipe::
56  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
57 
// Extract the per-dimension lengths of the 7-d block descriptor
// (MRepeat, MWave, MSubGroup, NRepeat, NWave, NThreadPerSubGroup, MAccVgprs);
// dims 1,2,4,5,6 are read below.
58  constexpr auto MWave =
59  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
60  .GetLength(I1);
61  constexpr auto MSubGroup =
62  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
63  .GetLength(I2);
64  constexpr auto NWave =
65  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
66  .GetLength(I4);
67  constexpr auto NThreadPerSubGroup =
68  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
69  .GetLength(I5);
70  constexpr auto MAccVgprs =
71  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
72  .GetLength(I6);
73 
// This thread's (m, n) origin within the block's C tile, as computed by the
// blockwise GEMM pipeline.
74  // origin
75  const auto c_thread_mtx_on_block =
76  BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);
77 
// Adaptor that splits the flat per-thread M coordinate into
// (MRepeat, MWave, MSubGroup, MAccVgprs) components via a merge transform.
// NOTE(review): the enclosing make_single_stage_tensor_adaptor(...) call is
// truncated in this extraction (original lines 79, 81-82 are missing).
78  const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
80  make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
83 
// Decompose the thread's flat M index into the 4 M-axis components.
84  const auto m_thread_data_on_grid_idx =
85  m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
86  make_multi_index(c_thread_mtx_on_block[I0]));
87 
// Analogous adaptor for the N axis: (NRepeat, NWave, NThreadPerSubGroup).
// NOTE(review): the enclosing make_single_stage_tensor_adaptor(...) call is
// truncated here as well (original lines 89, 91-92 are missing).
88  const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
90  make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
93 
// Decompose the thread's flat N index into the 3 N-axis components.
94  const auto n_thread_data_on_grid_idx =
95  n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
96  make_multi_index(c_thread_mtx_on_block[I1]));
97 
// Reshape the 4-d E grid descriptor (MBlock, MPerBlock, NBlock, NPerBlock)
// into the thread-copy layout: the block indices are frozen to
// block_m_id / block_n_id, and the per-block dims are presumably unmerged into
// the wave/subgroup/vgpr components above — TODO confirm; the
// transform_tensor_descriptor(...) argument lists are heavily truncated in this
// extraction (original lines 100, 103, 105, 108-112 are missing).
98  // E grid descriptor
99  const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
101  e_grid_desc_mblock_mperblock_nblock_nperblock,
102  make_tuple(make_freeze_transform(block_m_id),
104  Number<MWave>{},
106  Number<MAccVgprs>{})),
107  make_freeze_transform(block_n_id),
111  make_tuple(
113 
// Thread-level copy engine: reads AccDataType from the thread descriptor,
// converts to EDataType, applies cde_element_op, and writes to the transformed
// grid descriptor using EGlobalMemoryDataOperation. Vector size is NRepeat per
// the original inline comment. NOTE(review): some template arguments (original
// lines 120-121) are missing from this extraction.
114  auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
115  AccDataType,
116  EDataType,
117  decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
118  decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
119  CDEElementwiseOperation,
122  3,
123  NRepeat, // VectorSize
124  EGlobalMemoryDataOperation,
125  1,
// Starting coordinate: this thread's decomposed (M, N) components interleaved
// into the 7-d destination layout.
126  false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
127  make_multi_index(m_thread_data_on_grid_idx[I0],
128  m_thread_data_on_grid_idx[I1],
129  m_thread_data_on_grid_idx[I2],
130  n_thread_data_on_grid_idx[I0],
131  n_thread_data_on_grid_idx[I1],
132  n_thread_data_on_grid_idx[I2],
133  m_thread_data_on_grid_idx[I3]),
134  cde_element_op};
135 
// Perform the actual store: whole thread slice, origin (0,...,0), from
// c_thread_buf into the global E buffer.
136  c_thread_copy.Run(
137  c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
138  make_tuple(I0, I0, I0, I0, I0, I0, I0),
139  c_thread_buf,
140  c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
141  e_grid_buf);
142  }
143 };
144 
145 } // namespace ck
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
InMemoryDataOperationEnum
Definition: ck.hpp:279
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: epilogue_direct_store.hpp:18
static constexpr auto I4
Definition: epilogue_direct_store.hpp:23
static constexpr auto I5
Definition: epilogue_direct_store.hpp:24
static constexpr auto I1
Definition: epilogue_direct_store.hpp:20
static constexpr __device__ bool IsLDSNeeded()
Definition: epilogue_direct_store.hpp:27
static constexpr auto I0
Definition: epilogue_direct_store.hpp:19
static constexpr auto I6
Definition: epilogue_direct_store.hpp:25
static constexpr auto I2
Definition: epilogue_direct_store.hpp:21
static constexpr auto I3
Definition: epilogue_direct_store.hpp:22
static __device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer, EDataType *p_e_grid, void *, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition: epilogue_direct_store.hpp:34
Definition: sequence.hpp:43
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20