/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp Source File
tile_gemm_shape.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename BlockTile_,
12  typename BlockWarps_,
13  typename WarpTile_,
14  bool PermuteA_ = false,
15  bool PermuteB_ = false>
17 {
21 
22  static constexpr index_t NumWarps =
23  reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
24 
25  static constexpr index_t kM = BlockTile::at(number<0>{});
26  static constexpr index_t kN = BlockTile::at(number<1>{});
27  static constexpr index_t kK = BlockTile::at(number<2>{});
28 
29  static constexpr bool PermuteA = PermuteA_;
30  static constexpr bool PermuteB = PermuteB_;
31 
32  static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{});
33  static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});
34  static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{});
35 
36  CK_TILE_HOST static std::string GetName()
37  {
38  // clang-format off
39  return concat('_', "tile_gemm_shape",
40  concat('x', kM, kN, kK, NumWarps),
41  concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
42  concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
43  // clang-format on
44  }
45 };
46 
47 template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
49 {
50 #if CK_TILE_USE_WMMA
51  return 16;
52 #else
53 #if defined(CK_GFX950_SUPPORT)
54  constexpr bool is_8bit_float =
55  std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
56  if constexpr(M_Warp_Tile == 32)
57  return is_8bit_float ? 64 : 16;
58  else
59  return is_8bit_float ? 128 : 32;
60 #else
61  if constexpr(M_Warp_Tile == 32)
62  return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
63  else
64  return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
65 #endif
66 #endif
67 }
68 
69 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:44
Definition: cluster_descriptor.hpp:13
constexpr index_t get_k_warp_tile()
Definition: tile_gemm_shape.hpp:48
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:993
Definition: tile_gemm_shape.hpp:17
static constexpr index_t flatNPerWarp
Definition: tile_gemm_shape.hpp:32
static constexpr index_t kK
Definition: tile_gemm_shape.hpp:27
static constexpr bool PermuteB
Definition: tile_gemm_shape.hpp:30
remove_cvref_t< BlockWarps_ > BlockWarps
Definition: tile_gemm_shape.hpp:19
static constexpr index_t kN
Definition: tile_gemm_shape.hpp:26
static constexpr index_t kM
Definition: tile_gemm_shape.hpp:25
static constexpr index_t flatKPerWarp
Definition: tile_gemm_shape.hpp:33
static CK_TILE_HOST std::string GetName()
Definition: tile_gemm_shape.hpp:36
static constexpr index_t flatKPerBlock
Definition: tile_gemm_shape.hpp:34
static constexpr bool PermuteA
Definition: tile_gemm_shape.hpp:29
remove_cvref_t< BlockTile_ > BlockTile
Definition: tile_gemm_shape.hpp:18
remove_cvref_t< WarpTile_ > WarpTile
Definition: tile_gemm_shape.hpp:20
static constexpr index_t NumWarps
Definition: tile_gemm_shape.hpp:22
Definition: integral_constant.hpp:13
Definition: math.hpp:98