/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/host/tensor_shuffle_utils.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/host/tensor_shuffle_utils.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/host/tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 #include "device_prop.hpp"
6 #include <stdexcept>
7 
8 namespace ck_tile {
9 template <typename T>
10 auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
11 {
12  if(t->get_lengths().size() != 2)
13  {
14  throw std::runtime_error("Host tensor is not rank 2 tensor.");
15  }
16  int m_ = t->get_lengths()[0];
17  int aqk_ = t->get_lengths()[1];
18 
19  if(aqk_ % block_aq_k != 0)
20  {
21  throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
22  }
23  ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
24  std::copy(t->begin(), t->end(), t_view.begin());
25  return ck_tile::reference_permute(t_view, {1, 0, 2});
26 }
27 
28 template <typename T>
29 auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
30 {
31  const auto& lengths = t->get_lengths();
32  const size_t rank = lengths.size();
33 
34  // Validate block_bq_k divisibility based on rank
35  int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1;
36 
37  if(bqk_dim < 0)
38  {
39  throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
40  std::to_string(rank));
41  }
42 
43  if(bqk_dim % block_bq_k != 0)
44  {
45  throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
46  }
47 
48  // For TilePermuteN
49  if(rank == 5)
50  {
51  // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk]
52  ck_tile::HostTensor<T> t_view({static_cast<int>(lengths[0]),
53  static_cast<int>(lengths[1]),
54  static_cast<int>(lengths[2]),
55  static_cast<int>(lengths[3]),
56  bqk_dim / block_bq_k,
57  block_bq_k});
58  std::copy(t->begin(), t->end(), t_view.begin());
59  return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5});
60  }
61  else // rank == 2
62  {
63  // Handle 2D tensor: [bqk, n]
64  int n_ = lengths[1];
65  ck_tile::HostTensor<T> t_view({n_, bqk_dim / block_bq_k, block_bq_k});
66  std::copy(t->begin(), t->end(), t_view.begin());
67  return ck_tile::reference_permute(t_view, {1, 0, 2});
68  }
69 }
70 
// Shuffle a rank-2 B tensor (indexed below as [k, n]: k_ = lengths[0],
// n_ = lengths[1]) into a warp-tile-blocked layout by copying it into a
// multi-dimensional view and reordering with reference_permute.
//
// NOTE(review): this doc listing dropped two source lines -- the branch
// condition before "79 {" (original line 78) and the inner condition before
// "96 {" (original line 95). The outer one presumably dispatches on the GPU
// target (is_gfx11_supported()/is_gfx12_supported() from the included
// device_prop.hpp appear in this file's symbol index); the inner one
// presumably chooses how K_Warp_Tile is split per lane. Confirm against the
// original header before relying on this listing.
71 template <typename GemmConfig, typename T>
72 auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
73 {
74  assert(t.get_lengths().size() == 2);
75  int n_ = t.get_lengths()[1];
76  int k_ = t.get_lengths()[0];
77 
// First branch (condition lost in extraction): view as
// [n/N_Warp_Tile, N_Warp_Tile, k/K_Warp_Tile, kABK0PerLane, 2, 8]
// and permute with {0, 2, 4, 1, 3, 5}.
 79  {
 80  constexpr int divisor = 2;
 81  constexpr int kABK1PerLane = 8;
 82  int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
 83  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
 84  gemmConfig.N_Warp_Tile,
 85  k_ / gemmConfig.K_Warp_Tile,
 86  kABK0PerLane,
 87  divisor,
 88  kABK1PerLane});
 89  std::copy(t.begin(), t.end(), t_view.begin());
 90  return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
 91  }
 92  else
 93  {
// Second branch: divisor stays 1 in one sub-case (condition lost in
// extraction), otherwise K_Warp_Tile is split by warp_size / N_Warp_Tile;
// the assert documents that this sub-case expects wave64.
 94  int divisor = 1;
 96  {
 97  divisor = 1;
 98  }
 99  else
 100  {
 101  assert(is_wave32() == false);
 102  divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
 103  }
// View as [n/N_Warp_Tile, N_Warp_Tile, k/K_Warp_Tile, divisor,
// K_Warp_Tile/divisor] and permute with {0, 2, 3, 1, 4}.
 104  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
 105  gemmConfig.N_Warp_Tile,
 106  k_ / gemmConfig.K_Warp_Tile,
 107  divisor,
 108  gemmConfig.K_Warp_Tile / divisor});
 109  std::copy(t.begin(), t.end(), t_view.begin());
 110  return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
 111  }
 112 }
113 
// Convenience overload: forwards to the two-argument shuffle_b with a
// default-constructed GemmConfig.
// NOTE(review): the signature line (original line 115) was lost in this doc
// extraction -- presumably `auto shuffle_b(const ck_tile::HostTensor<T>& t)`,
// given the body; confirm against the original header.
114 template <typename GemmConfig, typename T>
116 {
117  return shuffle_b(t, GemmConfig{});
118 }
119 
// Permute a rank-2 BQ tensor ([bqk, n]: bqk_ = lengths[0], n_ = lengths[1])
// for the permuted-N B layout: views it as
// [n/(N_Tile/group_n), N_Warp, N_Warp_Tile/group_n, NRepeat, bqk] and
// reorders with {0, 3, 1, 2, 4} so NRepeat precedes the warp dimensions.
// NOTE(review): the signature line (original line 121) was lost in this doc
// extraction; per this page's own symbol index it is
// `auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)` --
// confirm against the original header.
120 template <typename GemmConfig, typename T>
122 {
123  assert(t.get_lengths().size() == 2);
124 
125  int n_ = t.get_lengths()[1];
126  int bqk_ = t.get_lengths()[0];
// Number of N-repeats per warp within an N tile.
127  constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
128 
129  ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
130  GemmConfig::N_Warp,
131  GemmConfig::N_Warp_Tile / group_n,
132  NRepeat,
133  bqk_});
134  std::copy(t.begin(), t.end(), t_view.begin());
135  return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
136 }
137 
// Shuffle a rank-2 B tensor ([k, n]: k_ = lengths[0], n_ = lengths[1]) into
// the permuted-N warp-tile-blocked layout: like shuffle_b, but the N side is
// additionally decomposed into [n/N_Tile, N_Warp, N_Warp_Tile, NRepeat].
//
// NOTE(review): this doc listing dropped two source lines -- the branch
// condition before "146 {" (original line 145) and the inner condition
// before "165 {" (original line 164). They presumably mirror the two lost
// conditions in shuffle_b (GPU-target dispatch and the per-lane K split
// choice); confirm against the original header before relying on this
// listing.
138 template <typename GemmConfig, typename T>
139 auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
140 {
141  assert(t.get_lengths().size() == 2);
142  int n_ = t.get_lengths()[1];
143  int k_ = t.get_lengths()[0];
// Number of N-repeats per warp within an N tile.
144  int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
// First branch (condition lost in extraction): 8-D view ending in
// [kABK0PerLane, 2, 8], permuted with {0, 3, 1, 4, 6, 5, 2, 7}.
 146  {
 147  constexpr int divisor = 2;
 148  constexpr int kABK1PerLane = 8;
 149  int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
 150  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
 151  gemmConfig.N_Warp,
 152  gemmConfig.N_Warp_Tile,
 153  NRepeat,
 154  k_ / gemmConfig.K_Warp_Tile,
 155  kABK0PerLane,
 156  divisor,
 157  kABK1PerLane});
 158  std::copy(t.begin(), t.end(), t_view.begin());
 159  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
 160  }
 161  else
 162  {
// Second branch: divisor stays 1 in one sub-case (condition lost in
// extraction), otherwise K_Warp_Tile is split by warp_size / N_Warp_Tile;
// the assert documents that this sub-case expects wave64.
 163  int divisor = 1;
 165  {
 166  divisor = 1;
 167  }
 168  else
 169  {
 170  assert(is_wave32() == false);
 171  divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
 172  }
// 7-D view ending in [divisor, K_Warp_Tile/divisor], permuted with
// {0, 3, 1, 4, 5, 2, 6}.
 173  ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
 174  gemmConfig.N_Warp,
 175  gemmConfig.N_Warp_Tile,
 176  NRepeat,
 177  k_ / gemmConfig.K_Warp_Tile,
 178  divisor,
 179  gemmConfig.K_Warp_Tile / divisor});
 180  std::copy(t.begin(), t.end(), t_view.begin());
 181  return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
 182  }
 183 }
184 
// Convenience overload: forwards to the two-argument shuffle_b_permuteN with
// a default-constructed GemmConfig.
// NOTE(review): the signature line (original line 186) was lost in this doc
// extraction -- presumably
// `auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)`, given the
// body; confirm against the original header.
185 template <typename GemmConfig, typename T>
187 {
188  return shuffle_b_permuteN(t, GemmConfig{});
189 }
190 } // namespace ck_tile
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:29
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:10
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:139
bool is_gfx11_supported()
Definition: device_prop.hpp:55
auto bq_permuteN(const ck_tile::HostTensor< T > &t, index_t group_n)
Definition: tensor_shuffle_utils.hpp:121
auto shuffle_b(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:72
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586