14 throw std::runtime_error(
"Host tensor is not rank 2 tensor.");
19 if(aqk_ % block_aq_k != 0)
21 throw std::runtime_error(
"shuffle_aq needs a aqk of multiple times of block_aq_k.");
32 const size_t rank = lengths.size();
35 int bqk_dim = (
rank == 5) ? lengths[4] : (
rank == 2) ? lengths[0] : -1;
39 throw std::runtime_error(
"shuffle_bq expects either rank-2 or rank-5 tensor, got rank " +
40 std::to_string(
rank));
43 if(bqk_dim % block_bq_k != 0)
45 throw std::runtime_error(
"shuffle_bq needs bqk dimension to be a multiple of block_bq_k.");
53 static_cast<int>(lengths[1]),
54 static_cast<int>(lengths[2]),
55 static_cast<int>(lengths[3]),
71 template <
typename GemmConfig,
typename T>
80 constexpr
int divisor = 2;
81 constexpr
int kABK1PerLane = 8;
82 int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
84 gemmConfig.N_Warp_Tile,
85 k_ / gemmConfig.K_Warp_Tile,
101 assert(is_wave32() ==
false);
105 gemmConfig.N_Warp_Tile,
106 k_ / gemmConfig.K_Warp_Tile,
108 gemmConfig.K_Warp_Tile / divisor});
114 template <
typename GemmConfig,
typename T>
120 template <
typename GemmConfig,
typename T>
127 constexpr
int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
131 GemmConfig::N_Warp_Tile / group_n,
138 template <
typename GemmConfig,
typename T>
144 int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
147 constexpr
int divisor = 2;
148 constexpr
int kABK1PerLane = 8;
149 int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
152 gemmConfig.N_Warp_Tile,
154 k_ / gemmConfig.K_Warp_Tile,
170 assert(is_wave32() ==
false);
175 gemmConfig.N_Warp_Tile,
177 k_ / gemmConfig.K_Warp_Tile,
179 gemmConfig.K_Warp_Tile / divisor});
185 template <
typename GemmConfig,
typename T>
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: cluster_descriptor.hpp:13
auto shuffle_bq(const ck_tile::HostTensor< T > *t, int block_bq_k)
Definition: tensor_shuffle_utils.hpp:29
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition: tensor_shuffle_utils.hpp:10
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:139
bool is_gfx11_supported()
Definition: device_prop.hpp:55
auto bq_permuteN(const ck_tile::HostTensor< T > &t, index_t group_n)
Definition: tensor_shuffle_utils.hpp:121
auto shuffle_b(const ck_tile::HostTensor< T > &t, const GemmConfig &gemmConfig)
Definition: tensor_shuffle_utils.hpp:72
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition: reference_permute.hpp:19
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
Data::iterator end()
Definition: host_tensor.hpp:588
Data::iterator begin()
Definition: host_tensor.hpp:586