8 #include <hip/hip_runtime.h>
22 namespace tensor_operation {
25 template <
typename ALayout,
30 typename AScaleDataType,
32 typename BScaleDataType,
35 typename GemmAccDataType,
36 typename CShuffleDataType,
37 typename AElementwiseOperation,
38 typename BElementwiseOperation,
39 typename CElementwiseOperation,
54 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55 typename ABlockTransferThreadClusterArrangeOrder,
56 typename ABlockTransferSrcAccessOrder,
57 index_t ABlockTransferSrcVectorDim,
58 index_t ABlockTransferSrcScalarPerVector,
59 index_t ABlockTransferDstScalarPerVector_AK1,
61 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62 typename BBlockTransferThreadClusterArrangeOrder,
63 typename BBlockTransferSrcAccessOrder,
64 index_t BBlockTransferSrcVectorDim,
65 index_t BBlockTransferSrcScalarPerVector,
66 index_t BBlockTransferDstScalarPerVector_BK1,
68 index_t CShuffleMXdlPerWavePerShuffle,
69 index_t CShuffleNXdlPerWavePerShuffle,
70 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
71 typename CDEShuffleBlockTransferScalarPerVectors,
75 bool NSwizzle =
false,
76 bool IsInputGemm =
true,
77 bool IsSplitK =
false,
78 bool MulRoutedWeight =
false,
80 typename ComputeTypeA = CDataType,
81 typename ComputeTypeB = ComputeTypeA,
82 typename LDSTypeA = ComputeTypeA,
83 typename LDSTypeB = ComputeTypeB,
84 bool NonTemporalLoadB =
false>
99 AElementwiseOperation,
100 BElementwiseOperation,
101 CElementwiseOperation>
107 template <index_t NXdlPerWave_>
119 AElementwiseOperation,
120 BElementwiseOperation,
121 CElementwiseOperation,
136 ABlockTransferThreadClusterLengths_AK0_M_AK1,
137 ABlockTransferThreadClusterArrangeOrder,
138 ABlockTransferSrcAccessOrder,
139 ABlockTransferSrcVectorDim,
140 ABlockTransferSrcScalarPerVector,
141 ABlockTransferDstScalarPerVector_AK1,
144 BBlockTransferThreadClusterLengths_BK0_N_BK1,
145 BBlockTransferThreadClusterArrangeOrder,
146 BBlockTransferSrcAccessOrder,
147 BBlockTransferSrcVectorDim,
148 BBlockTransferSrcScalarPerVector,
149 BBlockTransferDstScalarPerVector_BK1,
152 CShuffleMXdlPerWavePerShuffle,
153 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
154 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
155 CDEShuffleBlockTransferScalarPerVectors,
193 template <
typename Gr
idwiseGemm>
194 float RunImp(
const typename GridwiseGemm::Argument& arg,
197 if(stream_config.log_level_ > 0)
202 if(!GridwiseGemm::CheckValidity(arg))
204 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
208 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
209 arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
213 index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
215 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
216 const auto RunKernel = [&](
const auto& kernel) {
217 if(stream_config.flush_cache)
220 std::array<std::size_t, NumDTensor> DsSize;
224 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
225 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
226 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
227 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
229 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
231 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
234 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
235 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
239 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() *
sizeof(DDataType);
244 stream_config.rotating_count,
248 rotating_mem.Print();
250 auto run_flush_cache = [&]() {
264 ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
283 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
287 constexpr
auto estimated_reg_a = MPerBlock * KPerBlock *
sizeof(ADataType) / BlockSize /
288 4 * (1 + GridwiseGemm::NWave);
289 constexpr
auto estimated_reg_b = NPerBlock * KPerBlock *
sizeof(BDataType) / BlockSize /
290 4 * (2) * (IsInputGemm ? 2 : 1);
291 constexpr
auto estimated_reg_c = MPerBlock * NPerBlock *
sizeof(GemmAccDataType) /
292 BlockSize / 4 * (IsInputGemm ? 2 : 1);
293 constexpr
auto estimated_reg_total =
294 estimated_reg_a + estimated_reg_b + estimated_reg_c;
296 constexpr
index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
298 constexpr
auto MemoryDataOp = (IsInputGemm && !IsSplitK)
302 if(has_main_k_block_loop)
308 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
331 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
352 throw std::runtime_error(
"todo: only v1 & v2 support now");
361 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
383 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
414 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
427 if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
431 if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
447 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
451 if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
482 const void* p_sorted_expert_ids,
483 const void* p_max_token_id,
486 std::array<const void*, NumDTensor> p_ds,
495 std::array<index_t, NumDTensor> StrideDs,
497 const void* p_a_scale,
498 const void* p_b_scale,
500 AElementwiseOperation a_element_op,
501 BElementwiseOperation b_element_op,
502 CElementwiseOperation c_element_op)
505 static_cast<const index_t*
>(p_sorted_expert_ids),
506 static_cast<const index_t*
>(p_max_token_id),
507 static_cast<const ADataType*
>(p_a),
508 static_cast<const BDataType*
>(p_b),
510 static_cast<CDataType*
>(p_c),
520 static_cast<const AScaleDataType*
>(p_a_scale),
521 static_cast<const BScaleDataType*
>(p_b_scale),
533 std::array<const void*, NumDTensor> p_ds,
540 std::array<ck::index_t, NumDTensor> StrideDs,
542 const void* p_a_scale,
543 const void* p_b_scale,
545 AElementwiseOperation a_element_op,
546 BElementwiseOperation b_element_op,
547 CElementwiseOperation c_element_op)
override
549 return std::make_unique<Argument>(
nullptr,
552 static_cast<const ADataType*
>(p_a),
553 static_cast<const BDataType*
>(p_b),
555 static_cast<CDataType*
>(p_c),
565 static_cast<const AScaleDataType*
>(p_a_scale),
566 static_cast<const BScaleDataType*
>(p_b_scale),
576 return std::make_unique<Invoker>(
Invoker{});
582 auto str = std::stringstream();
584 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
588 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
594 str <<
"DeviceMoeGEmm"
597 << std::string(ALayout::name)[0]
598 << std::string(BLayout::name)[0]
599 << std::string(CLayout::name)[0]
604 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
606 << MPerXDL<<
"x"<<NPerXDL <<
", "
608 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
610 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
611 <<
"BlkGemmPipelineScheduler: "
612 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
613 <<
"BlkGemmPipelineVersion: "
614 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
615 <<
"BlkGemmPipelinePrefetchStages: "
616 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:187
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:87
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:383
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
BlockGemmPipelineVersion
Block GEMM pipeline version enumeration.
Definition: scheduler_enum.hpp:17
@ v2
Memory-optimized pipeline.
@ v3
Compute-optimized pipeline.
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:219
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
@ Interwave
Schedule across multiple wavefronts.
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:301
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:16
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:84
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:113
Definition: stream_config.hpp:9
Definition: gridwise_moe_gemm_blockscale.hpp:674
Definition: gridwise_moe_gemm_blockscale.hpp:179
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:982
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:270
Definition: device_base.hpp:281
Definition: device_gemm_multiple_d_ab_scale.hpp:82
Definition: device_moe_gemm_blockscale.hpp:192
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm_blockscale.hpp:194
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm_blockscale.hpp:411
Definition: device_moe_gemm_blockscale.hpp:102
static constexpr index_t BPackedSize
Definition: device_moe_gemm_blockscale.hpp:181
static constexpr index_t APackedSize
Definition: device_moe_gemm_blockscale.hpp:174
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm_blockscale.hpp:476
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm_blockscale.hpp:531
typename GridwiseGemm64::Argument Argument
Definition: device_moe_gemm_blockscale.hpp:172
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_gemm_blockscale.hpp:104
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm_blockscale.hpp:574
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm_blockscale.hpp:424
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm_blockscale.hpp:481
static auto MakeInvoker()
Definition: device_moe_gemm_blockscale.hpp:528
static constexpr index_t NumDTensor
Definition: device_moe_gemm_blockscale.hpp:106
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm_blockscale.hpp:418
std::string GetTypeString() const override
Definition: device_moe_gemm_blockscale.hpp:580
int GetPreShuffleParameters() override
Definition: device_moe_gemm_blockscale.hpp:188
static constexpr auto NXdlPerWave32
Definition: device_moe_gemm_blockscale.hpp:105
Definition: flush_cache.hpp:174