template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK)
#endif
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            karg.p_max_token_id,
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_ds_grid,
            karg.p_c_grid,
            karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
            p_shared,
            karg,
            karg.a_element_op,
            karg.b_element_op,
            karg.c_element_op);
    }
#endif
}
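// The wrapper only does three things: reserve the LDS workspace sized by the
// gridwise GEMM, derive this block's split-K pointer offsets from blockIdx.z,
// and forward the offset-adjusted grid pointers into GridwiseGemm::Run.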
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK)
#endif
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
    {
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            karg.p_max_token_id,
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_ds_grid,
            karg.p_c_grid,
            karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
            p_shared,
            p_shared1,
            karg,
            karg.a_element_op,
            karg.b_element_op,
            karg.c_element_op);
    }
#endif
}
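// Same wrapper as kernel_moe_gemm, but with two LDS buffers so the pipeline
// can ping-pong A tiles between p_shared and p_shared1 (see Run_2Lds below).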
template <typename ALayout,
          /* ... layouts and data types ... */
          typename AccDataType,
          typename CShuffleDataType,
          /* ... */
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          /* ... block tile sizes ... */
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          /* ... */
          index_t ActivationOperation = 0,
          bool NSwizzle               = false,
          bool IsInputGemm            = true,
          bool IsSplitK               = false,
          bool MulRoutedWeight        = true,
          typename ComputeTypeA       = CDataType,
          typename ComputeTypeB       = ComputeTypeA,
          typename LDSTypeA           = ADataType,
          typename LDSTypeB           = BDataType,
          bool NonTemporalLoadB       = false>
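// The trailing boolean knobs select the MoE flavor at compile time: NSwizzle
// remaps block ids along N, IsInputGemm fuses the gate+up projections,
// IsSplitK splits K across blockIdx.z, MulRoutedWeight multiplies the
// accumulators by the routed top-k weight (carried in D0), and NonTemporalLoadB
// requests non-temporal (WAVE_NT1) loads for the weight stream.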
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
    CDEShuffleBlockTransferScalarPerVectors{}[I0];

/* ... */

// inside MakeDsGridPointer(): a typed null pointer per D tensor
return static_cast<const DDataType*>(nullptr);

/* ... */

// inside CalculateGridSize(M, N, K, KBatch):
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
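// With NSwizzle enabled the whole M x N block grid is flattened onto gridDim.x
// (gridx = nblock * mblock, gridy = 1) so the kernel can remap block ids for
// better expert locality; otherwise the grid is the plain (nblock, mblock)
// rectangle.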
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
    return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
}

__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
    return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
}

__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
    return K_Batch == 1 ? K : K_Batch * KPerBlock;
}

__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
    return /* ... */
                   : K_Batch * KPerBlock;
}
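// Worked example (assuming AK1Value == 8): for K = 4096 and K_Batch == 1,
// CalculateAK0Padded returns 4096 / 8 = 512. With K_Batch > 1 the result is
// derived from K_Batch * KPerBlock instead, so every split-K slice observes
// the same AK0 regardless of the raw K extent.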
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
    /* ... reshape the (K0, MN, K1) block descriptor into the per-wave
           (MNXdlPerWave, MNWaves, MNPerXdl, K) mma tile view ... */
}

__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
    IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
{
    const auto a_grid_desc_mraw_kraw = [&]() {
        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        { /* ... naive (M, K) descriptor with strides (StrideA, 1) ... */ }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
        { /* ... naive (M, K) descriptor with strides (1, StrideA) ... */ }
    }();

    if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
    {
        // pad both M and K
        const auto a_grid_desc_m_k = /* ... right-pad M to MPad and K to KPad ... */;
        /* ... split K into (AK0, AK1) ... */
        return a_grid_desc_ak0_m_ak1;
    }
    else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                      GemmSpec == GemmSpecialization::MNPadding)
    {
        // pad M only
        const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
            a_grid_desc_mraw_kraw,
            /* ... */);
        return a_grid_desc_ak0_m_ak1;
    }
    else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                      GemmSpec == GemmSpecialization::NKPadding)
    {
        // pad K only
        const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
            a_grid_desc_mraw_kraw,
            /* ... */);
        return a_grid_desc_ak0_m_ak1;
    }
    else
    {
        // no padding
        const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
            a_grid_desc_mraw_kraw,
            /* ... */);
        return a_grid_desc_ak0_m_ak1;
    }
}

__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
    constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t WaveSize = BlockSize / (MWave * NWave);
    /* ... */
    return make_naive_tensor_descriptor(
        /* ... lengths ... */,
        make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
}
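// B is consumed "preshuffled": the expert weights are rearranged offline into a
// wave-major (N0, K0, ...) tiling whose innermost stride is 1 (I1 above), so
// each wave can stream its K-slices directly into VGPRs without an LDS stage.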
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
    index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
    const auto b_grid_desc_nraw_kraw = [&]() {
        /* ... naive (N, K) descriptor for row-/column-major BLayout ... */
    }();

    static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                    GemmSpec != GemmSpecialization::Default),
                  "pk_i4_t does not support padding");

    if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
    {
        // pad both N and K
        const auto b_grid_desc_n_k = /* ... right-pad N to NPad and K to KPad ... */;
        /* ... split K into (BK0, BK1) ... */
        return b_grid_desc_bk0_n_bk1;
    }
    else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                      GemmSpec == GemmSpecialization::MNPadding)
    {
        // pad N only
        const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
            b_grid_desc_nraw_kraw,
            /* ... */);
        return b_grid_desc_bk0_n_bk1;
    }
    else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                      GemmSpec == GemmSpecialization::MKPadding)
    {
        // pad K only
        const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
            b_grid_desc_nraw_kraw,
            /* ... */);
        return b_grid_desc_bk0_n_bk1;
    }
    else
    {
        // no padding
        const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
            b_grid_desc_nraw_kraw,
            /* ... */);
        return b_grid_desc_bk0_n_bk1;
    }
}
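// All descriptor builders follow the same pattern: build the raw (row, col)
// descriptor for the layout, then right-pad M/N/K according to GemmSpec so the
// padded extents divide the block tile exactly; e.g. with
// GemmSpecialization::MNKPadding and MPerBlock = 128, M = 1000 is padded to
// MPad = 1024 before the (K0, MN, K1) split.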
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
    constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

    return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}

template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
    return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
template <typename ELayout>
__host__ __device__ static auto MakeCGridDescriptor_M_N(
    IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
{
    const auto c_grid_desc_mraw_nraw = [&]() {
        /* ... naive (M, N) descriptor for row-/column-major ELayout ... */
    }();

    /* ... apply GemmSpec M/N padding and return ... */
}

template <typename DLayout>
__host__ __device__ static auto MakeDGridDescriptor_M_N(
    index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
    const auto c_grid_desc_mraw_nraw = [&]() {
        /* ... naive (M, N) descriptor for row-/column-major DLayout ... */
    }();

    /* ... apply GemmSpec M/N padding and return ... */
}

__host__ static __device__ auto MakeDsGridDescriptor_M_N(
    index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
    return generate_tuple(
        [&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
        },
        Number<NumDTensor>{});
}

template <typename DsGridDesc>
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
    const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
    return generate_tuple(
        [&](auto i) {
            return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n[i], MBlock, NBlock);
        },
        Number<NumDTensor>{});
}
struct Problem
{
    __host__ __device__ Problem(index_t NumTokens_,
                                index_t TopK_,
                                index_t M_,
                                index_t N_,
                                index_t K_,
                                index_t StrideA_,
                                index_t StrideB_,
                                std::array<index_t, NumDTensor> StrideDs_,
                                index_t StrideC_,
                                index_t KBatch_)
        /* ... member initializers, including the derived MPadded / NPadded /
               KRead / KPadded / AK0 / BK0 / MBlock / NBlock ... */
    {
    }

    __host__ void Print() const
    {
        std::cout << "problem {"
                  << "NumTokens:" << NumTokens << ", "
                  << "TopK:" << TopK << ", "
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  /* ... strides and padded extents ... */
                  << "KRead:" << KRead << ", "
                  << "KP:" << KPadded << ", "
                  << "AK0:" << AK0 << ", "
                  << "BK0:" << BK0 << ", "
                  << "MBlock: " << MBlock << ", "
                  << "NBlock: " << NBlock << "}" << std::endl;
    }

    /* ... members: NumTokens, TopK, M, N, K, StrideA, StrideB, StrideDs,
           StrideC, KBatch, MPadded, NPadded, KRead, KPadded, AK0, BK0,
           MBlock, NBlock ... */
};
struct Argument
{
    __host__ Argument(const index_t* p_sorted_token_ids_,
                      const index_t* p_sorted_expert_ids_,
                      const index_t* p_max_token_id_,
                      const ADataType* p_a_grid_,
                      const BDataType* p_b_grid_,
                      std::array<const void*, NumDTensor> p_ds_grid_,
                      CDataType* p_c_grid_,
                      /* ... NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_ ... */
                      std::array<index_t, NumDTensor> StrideDs_,
                      /* ... StrideC_, scale pointers, k_batch_ ... */
                      AElementwiseOperation a_element_op_,
                      BElementwiseOperation b_element_op_,
                      CElementwiseOperation c_element_op_)
        /* ... member initializers ... */
    {
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

            p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
        });
    }

    /* ... */
};
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
    { /* ... a_k_split_offset for row-major A ... */ }
    else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
    { /* ... a_k_split_offset for column-major A ... */ }

    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
    { /* ... b_k_split_offset for row-major B ... */ }
    else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
    { /* ... b_k_split_offset for column-major B ... */ }

    /* ... ascale_k_split_offset / bscale_k_split_offset ... */
}
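// Used by the kernel wrappers above: slice k_id (= blockIdx.z) advances the
// A/B and scale base pointers to the start of its own K chunk, expressed in
// elements of the respective layout (hence the per-layout if-constexpr
// branches).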
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
    constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t WaveSize = BlockSize / (MWave * NWave);

    if constexpr(ABlockLdsExtraM)
    {
        /* ... plain (AK0, M, AK1) descriptor with extra-M padding ... */
    }
    else if constexpr(/* ... xor-permuted variant ... */)
    {
        constexpr auto a_lds_block_desc = /* ... */;
        /* ... xor-with-modulo transform over the (K, M) folds ... */
        return a_lds_block_desc_permuted;
    }
    else
    {
        // fold-based layout
        constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto M1 = MPerBlock / M0;

        constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
        constexpr auto KThreadRead      = WaveSize / MPerXdl;
        constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

        constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                   ? 1
                                   : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
        constexpr auto KThreadReadPerm =
            (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                : KThreadRead;

        // 1 <= mpair <= M0
        constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                   ? 1
                                   : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                          ? M0
                                          : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));

        constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
            make_tuple(/* ... */,
                       Number<kfold * M0 / mpair>{},
                       /* ... */));

        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
            a_lds_block_desc,
            /* ... xor-with-modulo transform ... */);

        constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
            a_lds_block_desc_permuted,
            /* ... unmerge back into thread-write order ... */);

        constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
            a_lds_block_desc_unmerged,
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(
                           /* ... */,
                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           /* ... */)),
                       /* ... */),
            /* ... */);

        return a_lds_block_desc_ak0_m_ak1;
    }
}
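// The kfold/mpair folding above packs multiple K-slices (and M-pairs) into
// each 128-byte LDS window so that the wide stores from global and the
// transposed per-wave MFMA reads land in distinct banks, avoiding conflicts
// without the extra-M padding used by the ABlockLdsExtraM variant.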
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        /* ... packed (1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                       1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) ... */;

    return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}

using BlockwiseGemmPipe = remove_cvref_t<
    decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<
             BlkGemmPipelineVer,
             BlkGemmPipeSched,
             BlockSize,
             /* ... data types, block descriptors, mma tile descriptors ... */
             ABlockTransferSrcScalarPerVector,
             BBlockTransferSrcScalarPerVector,
             MPerBlock,
             NPerBlock,
             KPerBlock,
             ScaleBlockM,
             ScaleBlockN,
             ScaleBlockK,
             MPerXdl,
             NPerXdl,
             MXdlPerWave,
             NXdlPerWave,
             KPack,
             IsInputGemm && !IsSplitK>())>;

static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
{
    // LDS for A: aligned element space of the block descriptor
    constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
        a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

    // LDS for the C shuffle stage
    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    constexpr auto c_block_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

    return math::max(a_block_space_size_aligned * sizeof(LDSTypeA),
                     c_block_size * sizeof(CShuffleDataType));
}
static constexpr __host__ bool CheckValidity(const Argument& karg)
{
    static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                      (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                  "Invalid tuning param!");

    /* ... skipped for GemmSpec variants that pad M ... */
    if(!(karg.M % MPerBlock == 0))
    {
        std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                  << std::endl;
        return false;
    }

    /* ... skipped for GemmSpec variants that pad N ... */
    if(!(karg.N % NPerBlock == 0))
    {
        std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                  << std::endl;
        return false;
    }

    /* ... skipped for GemmSpec variants that pad K ... */
    {
        auto K_t = karg.KBatch * KPerBlock;
        if(!(karg.K % K_t == 0))
        {
            std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                      << karg.K << " " << __FILE__ << ":" << __LINE__
                      << ", in function: " << __func__ << std::endl;
            return false;
        }
    }

    // every split-K slice must still have work left
    {
        /* ... KReadVec / KReadPadSplited derived from AK1, BK1 and KBatch ... */
        auto K_t = karg.KBatch * KReadVec;
        /* ... */
        if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
        {
            return false;
        }
    }

    // A reads are vectorized along K for row-major A, along M otherwise.
    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
    {
        if(karg.K % ABlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg K (" << karg.K
                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
    }
    else
    {
        if(karg.M % ABlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg M (" << karg.M
                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
    }

    // B reads are vectorized along N for row-major B, along K otherwise.
    if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
    {
        if(karg.N % BBlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg N (" << karg.N
                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
    }
    else
    {
        if(karg.K % BBlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg K (" << karg.K
                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
    }

    // C writes are vectorized along N for row-major C, along M otherwise.
    /* ... */
    {
        std::cout << "Arg N (" << karg.N
                  << ") value is not a multiple of "
                     "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
                  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        return false;
    }
    /* ... */
    {
        std::cout << "Arg M (" << karg.M
                  << ") value is not a multiple of "
                     "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
                  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
        return false;
    }

    // the selected pipeline needs enough K iterations to cover its prefetch stages
    const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
    if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
    {
        return false;
    }

    return true;
}
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
{
    const index_t num_loop = K / KPerBlock;

    return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
}

__host__ static constexpr __device__ TailNumber CalculateKBlockLoopTailNum(index_t K)
{
    const index_t num_loop = K / KPerBlock;

    return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
}
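// Example (illustrative): K = 2048 with KPerBlock = 256 gives num_loop = 8;
// whether that counts as a hot loop, and whether the tail is Odd or Even, is
// delegated to the selected pipeline, since prefetch depth differs per
// BlkGemmPipelineVer.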
template <typename CGridDesc>
static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
    const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
        c_grid_desc_m_n,
        /* ... unmerge M into (MBlock, MPerBlock) and N into (NBlock, NPerBlock) ... */);

    return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum>
static __device__ void Run(const index_t* p_sorted_token_ids,
                           const index_t* p_sorted_expert_ids,
                           const index_t* p_max_token_id,
                           const ADataType* p_a_grid,
                           const BDataType* p_b_grid,
                           DsGridPointer& p_ds_grid,
                           CDataType* p_c_grid,
                           const AScaleType* p_a_scale_grid,
                           const BScaleType* p_b_scale_grid,
                           void* p_shared,
                           const Problem& problem,
                           AElementwiseOperation a_element_op,
                           BElementwiseOperation b_element_op,
                           CElementwiseOperation c_element_op)
{
#if defined(__gfx942__) || defined(__gfx950__)
    constexpr auto b_coherence_flag = NonTemporalLoadB
                                          ? AmdBufferCoherenceEnum::WAVE_NT1
                                          : /* ... default coherence ... */;
#else
    /* ... */
#endif

    /* ... A grid / scale-grid descriptors ... */
    const auto b_grid_desc_bpreshuffled = /* ... MakeBGridDescriptor_Preshuffled(...) ... */;
    const auto c_grid_desc_m_n          = MakeCGridDescriptor_M_N<CLayout>(
        /* ... M, MPadded ... */
        problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
        problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
        /* ... StrideC ... */);
    /* ... */
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            c_grid_desc_m_n, problem.MBlock, problem.NBlock);

    // sorted-token bookkeeping: p_max_token_id[0] holds the max valid token id
    const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
    const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
    if(expert_block_id * MPerBlock >= max_token_id)
        return;
    const index_t expert_id =
        __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
    const auto block_mn = [&]() -> std::pair<int, int> {
        if constexpr(NSwizzle)
        {
            const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
            /* ... prefix_block derived from ecnt_prefix and problem.NBlock ... */
            const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
            const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
            const index_t bid_new        = blockIdx.x - prefix_block;
            const index_t nid            = __builtin_amdgcn_readfirstlane(
                bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
            const index_t mid =
                __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
            return {nid, mid};
        }
        else
        {
            return {blockIdx.x, blockIdx.y};
        }
    }();

    const index_t block_n_id = block_mn.first;
    const index_t block_m_id = block_mn.second;

    const index_t token0 =
        __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
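    // Each entry of p_sorted_token_ids packs the token row in its low 24 bits
    // and the top-k slot in the high bits (see the `& 0xffffff` / `>> 24`
    // decoding below); token0 is used only for the out-of-range early exit.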
    constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
    constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
    constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
    constexpr auto AKThreads  = AK0Threads * AK1Threads;
    constexpr auto AMRepeats  = MPerBlock / AMThreads;
    const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

    if(token_pos >= max_token_id || token0 >= problem.NumTokens)
        return;
    StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
    static_for<0, AMRepeats, 1>{}([&](auto m0) {
        const index_t fused_token = p_sorted_token_ids[token_pos + m0];
        index_t token_offset      = fused_token & 0xffffff;
        if constexpr(!IsInputGemm)
        {
            // output GEMM: expand each token into its top-k duplicate row
            token_offset = token_offset * problem.TopK + (fused_token >> 24);
        }
        gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
    });
    const index_t expert_stride =
        __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
    const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
        /* ... per-expert extent of the block-scale tensor ... */);
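    // For the fused input GEMM each expert's weight stores the gate and up
    // halves back to back, hence the factor (IsInputGemm ? 2 : 1) in
    // expert_stride; the per-expert scale stride mirrors that layout.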
    const index_t n_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

    const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
    const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
        /* ... p_b_grid advanced to this expert ... */
        b_grid_desc_bpreshuffled.GetElementSpaceSize());

    const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
    const auto b_scale_grid_buf =
        make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
    // A: blockwise gather copy, global -> LDS
    auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
        ThisThreadBlock,
        AElementwiseOperation,
        /* ... */
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        /* ... */
        decltype(a_grid_desc_ak0_m_ak1),
        decltype(a_block_desc_ak0_m_ak1),
        ABlockTransferSrcAccessOrder,
        /* ... */
        ABlockTransferSrcVectorDim,
        /* ... */
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        /* ... */
        AThreadTransferSrcResetCoordinateAfterRun,
        /* ... */
        BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                            /* ... */
                                            a_block_desc_ak0_m_ak1,
                                            /* ... */
                                            gather_offsets);

    // B: preshuffled weights are read directly into VGPRs
    auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());

    auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
        /* ... */
        decltype(b_grid_desc_bpreshuffled),
        decltype(b_block_desc_bk0_n_bk1),
        /* ... */
        BBlockTransferSrcScalarPerVector,
        BThreadTransferSrcResetCoordinateAfterRun,
        true>(b_grid_desc_bpreshuffled,
              /* ... starting coordinate ... */);

    // A: LDS staging buffer
    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
    // blockwise GEMM pipeline
    static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
    auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
    auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
    decltype(c_thread_buf) c_thread_buf_up;

    const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
        /* ... */ ? (a_grid_desc_ak0_m_ak1.GetLength(I0) *
                     a_grid_desc_ak0_m_ak1.GetLength(I2)) /
                        KPerBlock
                  : /* ... */);

    // block-scale slices consumed per pipeline iteration
    constexpr index_t ScaleSliceSizeM = MXdlPerWave;
    /* ... ScaleSliceSizeN, ScaleSliceSizeK, scale thread descriptors ... */

    constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
    constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
    /* ... a_thread_offset ... */

    // gather offsets for the per-token A scales
    const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
    if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
        return;
    static_for<0, ScaleSliceSizeM, 1>{}([&](auto m0) {
        const index_t fused_token =
            p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
        index_t token_offset = fused_token & 0xffffff;
        if constexpr(!IsInputGemm)
        {
            token_offset = token_offset * problem.TopK + (fused_token >> 24);
        }
        scale_gather_offsets(m0) = /* ... scale-row offset from token_offset ... */;
    });

    auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
        /* ... */
        decltype(a_scale_grid_desc_am_ak),
        decltype(a_scale_thread_desc),
        /* ... */>(a_scale_grid_desc_am_ak,
                   /* ... origin built from scale_gather_offsets ... */);

    auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
        /* ... */
        decltype(b_scale_grid_desc_bn_ak),
        decltype(b_scale_thread_desc),
        /* ... */>(b_scale_grid_desc_bn_ak,
                   make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

    constexpr auto a_scale_thread_slice_copy_step = /* ... */;
    constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
    if constexpr(IsInputGemm && !IsSplitK)
    {
        // gate+up: the second half of this expert's weight and scales feeds a
        // second accumulator (c_thread_buf_up)
        const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
        const auto b_grid_buf_up =
            make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
                /* ... p_b_grid_up advanced to this expert ... */
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
        auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
            /* ... */
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            /* ... */
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  /* ... */);
        const BScaleType* p_b_scale_grid_up =
            p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
        const auto b_scale_grid_buf_up =
            make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
                p_b_scale_grid_up + expert_id * expert_scale_stride,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
            /* ... */
            decltype(b_scale_grid_desc_bn_ak),
            decltype(b_scale_thread_desc),
            /* ... */>(b_scale_grid_desc_bn_ak,
                       /* ... */);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
            a_grid_desc_ak0_m_ak1,
            a_block_desc_ak0_m_ak1,
            /* ... a_blockwise_copy, a_grid_buf, a_block_buf ... */
            a_block_slice_copy_step,
            // B (gate and up halves)
            b_grid_desc_bpreshuffled,
            b_block_desc_bk0_n_bk1,
            /* ... b_blockwise_copy ... */
            b_blockwise_copy_up,
            /* ... b_grid_buf, b_grid_buf_up, b_block_buf ... */
            b_block_slice_copy_step,
            // C
            c_scale_thread_desc,
            /* ... c_thread_buf, c_thread_buf_up ... */
            // A scales
            a_scale_grid_desc_am_ak,
            a_scale_thread_desc,
            a_scale_thread_copy,
            /* ... a_scale_grid_buf ... */
            a_scale_thread_slice_copy_step,
            // B scales
            b_scale_grid_desc_bn_ak,
            b_scale_thread_desc,
            b_scale_thread_copy,
            b_scale_thread_copy_up,
            /* ... b_scale_grid_buf ... */
            b_scale_grid_buf_up,
            b_scale_thread_slice_copy_step,
            /* ... */
            num_k_block_main_loop);
    }
    else
    {
        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
            a_grid_desc_ak0_m_ak1,
            a_block_desc_ak0_m_ak1,
            /* ... a_blockwise_copy, a_grid_buf, a_block_buf ... */
            a_block_slice_copy_step,
            b_grid_desc_bpreshuffled,
            b_block_desc_bk0_n_bk1,
            /* ... b_blockwise_copy, b_grid_buf, b_block_buf ... */
            b_block_slice_copy_step,
            c_scale_thread_desc,
            /* ... c_thread_buf ... */
            a_scale_grid_desc_am_ak,
            a_scale_thread_desc,
            a_scale_thread_copy,
            /* ... a_scale_grid_buf ... */
            a_scale_thread_slice_copy_step,
            b_scale_grid_desc_bn_ak,
            b_scale_thread_desc,
            b_scale_thread_copy,
            /* ... b_scale_grid_buf ... */
            b_scale_thread_slice_copy_step,
            /* ... */
            num_k_block_main_loop);
    }
    // shuffle C and write out
    static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0);

    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
    /* ... */

    constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
        blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
        blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
    constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
    constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

    static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
    static_assert(M0 * M1 * M2 == MPerBlock);
    static_assert(N4 == 4 || N4 == 8);
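    // N4 is the innermost mfma accumulator vector width; the xdlops shapes
    // supported here produce groups of 4 or 8 values per thread.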
    // executed for every accumulator element (inside static_for loops over the
    // m0/m1/m2 and n0..n4 tile indices):
    float topk_weight;
    if constexpr(MulRoutedWeight)
    {
        const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
        topk_weight         = p_ds_grid[I0][m_pos]; // D0 carries the routed top-k weights
    }
    /* ... */
    const auto cidx = blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
        /* ... thread-local (m, n) coordinate ... */);

    if constexpr(IsInputGemm && !IsSplitK)
    {
        // fused gate/up epilogue; the two branches differ only in the
        // activation applied to the gate half (silu_and_mul vs gelu_and_mul)
        {
            float gate = c_thread_buf[cidx];
            float up   = c_thread_buf_up[cidx];
            if constexpr(MulRoutedWeight)
            {
                gate = gate * topk_weight;
                up   = up * topk_weight;
            }
            /* ... apply activation to gate ... */
            c_thread_buf(cidx) = gate * up;
        }
        /* ... or, for the other ActivationOperation: ... */
        {
            float gate = c_thread_buf[cidx];
            float up   = c_thread_buf_up[cidx];
            if constexpr(MulRoutedWeight)
            {
                gate = gate * topk_weight;
                up   = up * topk_weight;
            }
            /* ... apply activation to gate ... */
            c_thread_buf(cidx) = gate * up;
        }
    }
    else
    {
        if constexpr(MulRoutedWeight)
        {
            c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
        }
    }
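    // For the fused input GEMM both accumulators hold the same tokens' gate
    // and up projections, so the activation and elementwise product happen
    // entirely in registers; no intermediate global round trip is needed
    // between the two projections.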
    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

    auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<CShuffleDataType*>(p_shared),
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
        /* ... unmerge into the (M0..M2, N0..N4) xdlops view ... */);

    // this thread's accumulator origin within the block tile
    const auto c_thread_mtx_on_block =
        blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
        make_single_stage_tensor_adaptor(
            /* ... merge (M0, M1, M2) ... */);

    const auto m_thread_data_on_block_idx =
        m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_block));

    const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
        make_single_stage_tensor_adaptor(
            /* ... merge (N0, N1, N2, N3, N4) ... */);

    const auto n_thread_data_on_block_idx =
        n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_block));

    // VGPR -> LDS shuffle, one (CShuffleMXdlPerWavePerShuffle x
    // CShuffleNXdlPerWavePerShuffle) wave tile per pass
    auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
        /* ... */
        decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
        decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
        /* ... */
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 /* ... */>,
        /* ... */>(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                   make_multi_index(/* ... */
                                    m_thread_data_on_block_idx[I1],
                                    n_thread_data_on_block_idx[I1],
                                    m_thread_data_on_block_idx[I2],
                                    n_thread_data_on_block_idx[I2],
                                    n_thread_data_on_block_idx[I3],
                                    n_thread_data_on_block_idx[I4]),
                   /* ... */);
    using EDataType = CDataType;

    const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
        /* ... M, MPadded ... */
        problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
        problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
        /* ... StrideDs ... */);

    const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            ds_grid_desc_m_n, problem.MBlock, problem.NBlock);

    const auto ds_grid_buf = generate_tuple(
        [&](auto i) {
            /* ... */
            const DDataType* ptr_ = p_ds_grid[i];
            /* ... */
            return make_dynamic_buffer<AddressSpaceEnum::Global>(
                ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
        },
        Number<NumDTensor>{});

    const auto c_ds_desc_refs = concat_tuple_of_reference(
        tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
        generate_tie(
            [&](auto i) -> const auto& {
                return ds_grid_desc_mblock_mperblock_nblock_nperblock[i];
            },
            Number<NumDTensor>{}));

    const auto c_ds_buf_refs = concat_tuple_of_reference(
        tie(c_shuffle_block_buf),
        generate_tie([&](auto i) -> const auto& { return ds_grid_buf[i]; },
                     Number<NumDTensor>{}));

    const auto idx_c_ds_block_begin = container_concat(
        /* ... block-origin indices for C and each D ... */);

    const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
        c_grid_desc_mblock_mperblock_nblock_nperblock;

    using CDEBlockTransferCluster =
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
    const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
    constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1; // (both branches identical here)

    // LDS-to-global scatter copy, fusing C with the D tensors
    auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
        /* ... */
        decltype(c_ds_desc_refs),
        decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
        CElementwiseOperation,
        /* ... */
        Sequence<1,
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 1,
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
        CDEBlockTransferCluster,
        /* ... */
        CDEShuffleBlockTransferScalarPerVectors,
        /* ... */>(c_ds_desc_refs,
                   idx_c_ds_block_begin,
                   tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                   /* ... */);

    auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
    constexpr auto sfc_c_vgpr = SpaceFillingCurve<
        /* ... */
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 /* ... */>,
        /* ... */>{};

    constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

    constexpr auto sfc_cde_block = SpaceFillingCurve<
        /* ... */
        Sequence<1,
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 1,
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

    static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

    constexpr auto EMThreads =
        CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
    constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
    constexpr auto ENThreads =
        CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
    // stream the accumulators out slice-by-slice along the space-filling curve
    static_for<0, num_access, 1>{}([&](auto access_id) {
        // rebuild the scatter offsets for the token rows this access writes
        auto dstidx = sfc_cde_block.GetIndex(access_id);
        const index_t c_token_pos =
            block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
        static_for<0, EMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scatter_offsets(m0) =
                token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
        });

        block_sync_lds();
        c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                      sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                      c_thread_buf,
                                      c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                      c_shuffle_block_buf);

        block_sync_lds();

        cde_block_copy_lds_and_global.Run(
            c_ds_desc_refs,
            c_ds_buf_refs,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            /* ... destination buffer ... */
            scatter_offsets);

        if constexpr(access_id < num_access - 1)
        {
            constexpr auto cde_lds_and_global_step =
                sfc_cde_block.GetForwardStep(access_id);

            // move on Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                    c_ds_desc_refs, i + I1, cde_lds_and_global_step);
            });

            // move on E
            cde_block_copy_lds_and_global.MoveDstSliceWindow(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                cde_lds_and_global_step);
        }
    });
}
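// Run_2Lds below mirrors Run above; the only difference is double buffering:
// A tiles ping-pong between p_shared and p_shared1 and the B VGPR staging is
// duplicated, letting the pipeline prefetch the next K block while the MFMAs
// consume the current one.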
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum>
static __device__ void Run_2Lds(const index_t* p_sorted_token_ids,
                                const index_t* p_sorted_expert_ids,
                                const index_t* p_max_token_id,
                                const ADataType* p_a_grid,
                                const BDataType* p_b_grid,
                                DsGridPointer& p_ds_grid,
                                CDataType* p_c_grid,
                                const AScaleType* p_a_scale_grid,
                                const BScaleType* p_b_scale_grid,
                                void* p_shared,
                                void* p_shared1,
                                const Problem& problem,
                                AElementwiseOperation a_element_op,
                                BElementwiseOperation b_element_op,
                                CElementwiseOperation c_element_op)
{
#if defined(__gfx942__) || defined(__gfx950__)
    constexpr auto b_coherence_flag = NonTemporalLoadB
                                          ? AmdBufferCoherenceEnum::WAVE_NT1
                                          : /* ... default coherence ... */;
#else
    /* ... */
#endif

    /* ... grid descriptors, identical to Run() ... */
    const auto b_grid_desc_bpreshuffled = /* ... MakeBGridDescriptor_Preshuffled(...) ... */;
    const auto c_grid_desc_m_n          = MakeCGridDescriptor_M_N<CLayout>(
        /* ... M, MPadded ... */
        problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
        problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
        /* ... StrideC ... */);
    /* ... */
    const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            c_grid_desc_m_n, problem.MBlock, problem.NBlock);

    const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
    const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
    if(expert_block_id * MPerBlock >= max_token_id)
        return;
    const index_t expert_id =
        __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
    const auto block_mn = [&]() -> std::pair<int, int> {
        if constexpr(NSwizzle)
        {
            const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
            /* ... prefix_block derived from ecnt_prefix and problem.NBlock ... */
            const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
            const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
            const index_t bid_new        = blockIdx.x - prefix_block;
            const index_t nid            = __builtin_amdgcn_readfirstlane(
                bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
            const index_t mid =
                __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
            return {nid, mid};
        }
        else
        {
            return {blockIdx.x, blockIdx.y};
        }
    }();

    const index_t block_n_id = block_mn.first;
    const index_t block_m_id = block_mn.second;

    const index_t token0 =
        __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
    constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
    constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
    constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
    constexpr auto AKThreads  = AK0Threads * AK1Threads;
    constexpr auto AMRepeats  = MPerBlock / AMThreads;
    const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

    if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
       token0 >= problem.NumTokens)
        return;
    StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
    static_for<0, AMRepeats, 1>{}([&](auto m0) {
        const index_t fused_token = p_sorted_token_ids[token_pos + m0];
        index_t token_offset      = fused_token & 0xffffff;
        if constexpr(!IsInputGemm)
        {
            token_offset = token_offset * problem.TopK + (fused_token >> 24);
        }
        gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
    });
    const index_t expert_stride =
        __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
    const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
        /* ... per-expert extent of the block-scale tensor ... */);
    const index_t n_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

    const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
    const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
        /* ... p_b_grid advanced to this expert ... */
        b_grid_desc_bpreshuffled.GetElementSpaceSize());

    const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
    const auto b_scale_grid_buf =
        make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
    // A: blockwise gather copy, global -> LDS (same configuration as in Run())
    auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
        ThisThreadBlock,
        AElementwiseOperation,
        /* ... */
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        /* ... */
        decltype(a_grid_desc_ak0_m_ak1),
        decltype(a_block_desc_ak0_m_ak1),
        ABlockTransferSrcAccessOrder,
        /* ... */
        ABlockTransferSrcVectorDim,
        /* ... */
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        /* ... */
        AThreadTransferSrcResetCoordinateAfterRun,
        /* ... */
        BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                            /* ... */
                                            a_block_desc_ak0_m_ak1,
                                            /* ... */
                                            gather_offsets);

    // B: double-buffered direct-to-VGPR staging
    auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());
    auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());
    auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

    auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
        /* ... */
        decltype(b_grid_desc_bpreshuffled),
        decltype(b_block_desc_bk0_n_bk1),
        /* ... */
        BBlockTransferSrcScalarPerVector,
        BThreadTransferSrcResetCoordinateAfterRun,
        true>(b_grid_desc_bpreshuffled,
              /* ... starting coordinate ... */);

    // A: double-buffered LDS staging (p_shared / p_shared1)
    auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
    auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
    auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
    // blockwise GEMM pipeline
    static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
    auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
    auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
    decltype(c_thread_buf) c_thread_buf_up;

    const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
        /* ... */ ? (a_grid_desc_ak0_m_ak1.GetLength(I0) *
                     a_grid_desc_ak0_m_ak1.GetLength(I2)) /
                        KPerBlock
                  : /* ... */);

    constexpr index_t ScaleSliceSizeM = MXdlPerWave;
    /* ... ScaleSliceSizeN, ScaleSliceSizeK, scale thread descriptors ... */

    constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXdl);
    constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXdl);
    constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
    /* ... a_thread_offset ... */

    // gather offsets for the per-token A scales
    const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
    if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
        return;
    static_for<0, ScaleSliceSizeM, 1>{}([&](auto m0) {
        const index_t fused_token =
            p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
        index_t token_offset = fused_token & 0xffffff;
        if constexpr(!IsInputGemm)
        {
            token_offset = token_offset * problem.TopK + (fused_token >> 24);
        }
        scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
                                   /* ... scale row stride ... */;
    });

    auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
        /* ... */
        decltype(a_scale_grid_desc_am_ak),
        decltype(a_scale_thread_desc),
        /* ... */>(a_scale_grid_desc_am_ak,
                   /* ... origin built from scale_gather_offsets ... */);

    auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
        /* ... */
        decltype(b_scale_grid_desc_bn_ak),
        decltype(b_scale_thread_desc),
        /* ... */>(b_scale_grid_desc_bn_ak,
                   make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

    constexpr auto a_scale_thread_slice_copy_step = /* ... */;
    constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
    if constexpr(IsInputGemm && !IsSplitK)
    {
        // gate+up: second accumulator fed by the second weight half
        const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
        const auto b_grid_buf_up =
            make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
                /* ... p_b_grid_up advanced to this expert ... */
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
        auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
            /* ... */
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            /* ... */
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  /* ... */);
        const BScaleType* p_b_scale_grid_up =
            p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
        const auto b_scale_grid_buf_up =
            make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
            /* ... */
            decltype(b_scale_grid_desc_bn_ak),
            decltype(b_scale_thread_desc),
            /* ... */>(b_scale_grid_desc_bn_ak,
                       /* ... */);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
            a_grid_desc_ak0_m_ak1,
            a_block_desc_ak0_m_ak1,
            /* ... a_blockwise_copy, a_grid_buf, a_block_bufs ... */
            a_block_slice_copy_step,
            b_grid_desc_bpreshuffled,
            b_block_desc_bk0_n_bk1,
            /* ... b_blockwise_copy ... */
            b_blockwise_copy_up,
            /* ... b_grid_buf, b_grid_buf_up, b_block_bufs ... */
            b_block_slice_copy_step,
            c_scale_thread_desc,
            /* ... c_thread_buf, c_thread_buf_up ... */
            a_scale_grid_desc_am_ak,
            a_scale_thread_desc,
            a_scale_thread_copy,
            /* ... a_scale_grid_buf ... */
            a_scale_thread_slice_copy_step,
            b_scale_grid_desc_bn_ak,
            b_scale_thread_desc,
            b_scale_thread_copy,
            b_scale_thread_copy_up,
            /* ... b_scale_grid_buf ... */
            b_scale_grid_buf_up,
            b_scale_thread_slice_copy_step,
            /* ... */
            num_k_block_main_loop);
    }
    else
    {
        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
            a_grid_desc_ak0_m_ak1,
            a_block_desc_ak0_m_ak1,
            /* ... a_blockwise_copy, a_grid_buf, a_block_bufs ... */
            a_block_slice_copy_step,
            b_grid_desc_bpreshuffled,
            b_block_desc_bk0_n_bk1,
            /* ... b_blockwise_copy, b_grid_buf, b_block_bufs ... */
            b_block_slice_copy_step,
            c_scale_thread_desc,
            /* ... c_thread_buf ... */
            a_scale_grid_desc_am_ak,
            a_scale_thread_desc,
            a_scale_thread_copy,
            /* ... a_scale_grid_buf ... */
            a_scale_thread_slice_copy_step,
            b_scale_grid_desc_bn_ak,
            b_scale_thread_desc,
            b_scale_thread_copy,
            /* ... b_scale_grid_buf ... */
            b_scale_thread_slice_copy_step,
            /* ... */
            num_k_block_main_loop);
    }
    // shuffle C and write out
    static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0);

    constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
    /* ... */

    constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
        blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
        blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
    constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
    constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

    static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
    static_assert(M0 * M1 * M2 == MPerBlock);
    static_assert(N4 == 4 || N4 == 8);
    // executed for every accumulator element, as in Run():
    float topk_weight;
    if constexpr(MulRoutedWeight)
    {
        const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
        topk_weight         = p_ds_grid[I0][m_pos]; // D0 carries the routed top-k weights
    }
    /* ... */
    const auto cidx = blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
        /* ... thread-local (m, n) coordinate ... */);

    if constexpr(IsInputGemm && !IsSplitK)
    {
        // fused gate/up epilogue; the two branches differ only in the
        // activation applied to the gate half (silu_and_mul vs gelu_and_mul)
        {
            float gate = c_thread_buf[cidx];
            float up   = c_thread_buf_up[cidx];
            if constexpr(MulRoutedWeight)
            {
                gate = gate * topk_weight;
                up   = up * topk_weight;
            }
            /* ... apply activation to gate ... */
            c_thread_buf(cidx) = gate * up;
        }
        /* ... or, for the other ActivationOperation: ... */
        {
            float gate = c_thread_buf[cidx];
            float up   = c_thread_buf_up[cidx];
            if constexpr(MulRoutedWeight)
            {
                gate = gate * topk_weight;
                up   = up * topk_weight;
            }
            /* ... apply activation to gate ... */
            c_thread_buf(cidx) = gate * up;
        }
    }
    else
    {
        if constexpr(MulRoutedWeight)
        {
            c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
        }
    }
    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

    auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<CShuffleDataType*>(p_shared),
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

    constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
        /* ... unmerge into the (M0..M2, N0..N4) xdlops view ... */);

    // this thread's accumulator origin within the block tile
    const auto c_thread_mtx_on_block =
        blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

    const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
    const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

    const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
        make_single_stage_tensor_adaptor(
            /* ... merge (M0, M1, M2) ... */);

    const auto m_thread_data_on_block_idx =
        m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
            make_multi_index(m_thread_data_on_block));

    const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
        make_single_stage_tensor_adaptor(
            /* ... merge (N0, N1, N2, N3, N4) ... */);

    const auto n_thread_data_on_block_idx =
        n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
            make_multi_index(n_thread_data_on_block));

    // VGPR -> LDS shuffle, one wave tile per pass
    auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
        /* ... */
        decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
        decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
        /* ... */
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 /* ... */>,
        /* ... */>(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                   make_multi_index(/* ... */
                                    m_thread_data_on_block_idx[I1],
                                    n_thread_data_on_block_idx[I1],
                                    m_thread_data_on_block_idx[I2],
                                    n_thread_data_on_block_idx[I2],
                                    n_thread_data_on_block_idx[I3],
                                    n_thread_data_on_block_idx[I4]),
                   /* ... */);
    using EDataType = CDataType;

    /* ... ds_grid_desc_m_n built as in Run() ... */
    const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            ds_grid_desc_m_n, problem.MBlock, problem.NBlock);

    const auto ds_grid_buf = generate_tuple(
        [&](auto i) {
            return make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
        },
        Number<NumDTensor>{});

    const auto c_ds_desc_refs = concat_tuple_of_reference(
        tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
        generate_tie(
            [&](auto i) -> const auto& {
                return ds_grid_desc_mblock_mperblock_nblock_nperblock[i];
            },
            Number<NumDTensor>{}));

    const auto c_ds_buf_refs = concat_tuple_of_reference(
        tie(c_shuffle_block_buf),
        generate_tie([&](auto i) -> const auto& { return ds_grid_buf[i]; },
                     Number<NumDTensor>{}));

    const auto idx_c_ds_block_begin = container_concat(
        /* ... block-origin indices for C and each D ... */);

    const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
        c_grid_desc_mblock_mperblock_nblock_nperblock;

    using CDEBlockTransferCluster =
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
    const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
    constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1; // (both branches identical here)

    // LDS-to-global scatter copy, fusing C with the D tensors
    auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
        /* ... */
        decltype(c_ds_desc_refs),
        decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
        CElementwiseOperation,
        /* ... */
        Sequence<1,
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 1,
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
        CDEBlockTransferCluster,
        /* ... */
        CDEShuffleBlockTransferScalarPerVectors,
        /* ... */>(c_ds_desc_refs,
                   idx_c_ds_block_begin,
                   tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                   /* ... */);

    auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
    constexpr auto sfc_c_vgpr = SpaceFillingCurve<
        /* ... */
        Sequence<CShuffleMXdlPerWavePerShuffle,
                 CShuffleNXdlPerWavePerShuffle,
                 /* ... */>,
        /* ... */>{};

    constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

    constexpr auto sfc_cde_block = SpaceFillingCurve<
        /* ... */
        Sequence<1,
                 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                 1,
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

    static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

    constexpr auto EMThreads =
        CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
    constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
    constexpr auto ENThreads =
        CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
    // stream the accumulators out slice-by-slice along the space-filling curve
    static_for<0, num_access, 1>{}([&](auto access_id) {
        // rebuild the scatter offsets for the token rows this access writes
        auto dstidx = sfc_cde_block.GetIndex(access_id);
        const index_t c_token_pos =
            block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
        static_for<0, EMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scatter_offsets(m0) =
                token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
        });

        block_sync_lds();
        c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                      sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                      c_thread_buf,
                                      c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                      c_shuffle_block_buf);

        block_sync_lds();

        cde_block_copy_lds_and_global.Run(
            c_ds_desc_refs,
            c_ds_buf_refs,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            /* ... destination buffer ... */
            scatter_offsets);

        if constexpr(access_id < num_access - 1)
        {
            constexpr auto cde_lds_and_global_step =
                sfc_cde_block.GetForwardStep(access_id);

            // move on Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                    c_ds_desc_refs, i + I1, cde_lds_and_global_step);
            });

            // move on E
            cde_block_copy_lds_and_global.MoveDstSliceWindow(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                cde_lds_and_global_step);
        }
    });
}