16 typename ComputeTypeA,
17 typename ComputeTypeB,
19 typename AWmmaTileDesc,
20 typename BWmmaTileDesc,
21 index_t ABlockTransferSrcScalarPerVector,
22 index_t BBlockTransferSrcScalarPerVector,
32 bool TransposeC =
false,
33 bool BSkipLDS =
false>
41 typename ComputeTypeA,
42 typename ComputeTypeB,
44 typename AWmmaTileDesc,
45 typename BWmmaTileDesc,
46 index_t ABlockTransferSrcScalarPerVector,
47 index_t BBlockTransferSrcScalarPerVector,
67 ABlockTransferSrcScalarPerVector,
68 BBlockTransferSrcScalarPerVector,
88 ABlockTransferSrcScalarPerVector,
89 BBlockTransferSrcScalarPerVector,
113 ABlockTransferSrcScalarPerVector,
114 BBlockTransferSrcScalarPerVector,
136 using Base::wmma_gemm;
138 using Base::CalculateCThreadOriginDataIndex;
140 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
141 using Base::GetCThreadBuffer;
143 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
145 using Base::a_block_desc_k0_m0_m1_m2_k1;
146 using Base::b_block_desc_k0_n0_n1_n2_k1;
148 using typename Base::Empty;
156 return num_loop > PrefetchStages;
165 template <
bool HasMainLoop,
169 typename ABlockTransfer,
170 typename AGridBuffer,
171 typename ABlockBuffer,
172 typename ABlockTransferStep,
175 typename BBlockTransfer,
176 typename BGridBuffer,
177 typename BBlockBuffer,
178 typename BBlockTransferStep,
179 typename CThreadBuffer,
180 typename AScaleStruct,
181 typename BScaleStruct,
183 __device__
void Run(
const AGridDesc& a_grid_desc,
184 const ABlockDesc& a_block_desc,
185 ABlockTransfer& a_blockwise_copy,
186 const AGridBuffer& a_grid_buf,
187 ABlockBuffer& a_block_buf,
188 const ABlockTransferStep& a_block_copy_step,
189 const BGridDesc& b_grid_desc,
190 const BBlockDesc& b_block_desc,
191 BBlockTransfer& b_blockwise_copy,
192 const BGridBuffer& b_grid_buf,
193 BBlockBuffer& b_block_buf,
194 const BBlockTransferStep& b_block_copy_step,
195 CThreadBuffer& c_thread_buf,
197 BScaleStruct& b_scale_struct,
199 index_t num_loop_per_scale)
const
201 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
203 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
204 a_thread_desc_.GetElementSpaceSize());
205 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
206 b_thread_desc_.GetElementSpaceSize());
209 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
210 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
212 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
213 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
216 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
219 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
220 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
223 c_thread_buf.Clear();
225 auto blockwise_gemm_func = [&]() {
229 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
235 if constexpr(m0 == I0)
240 b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
252 b_block_desc_k0_n0_n1_n2_k1,
255 b_scale_struct.scale_thread_bufs(
256 I0)[
Number<n0 * BScaleStruct::num_scale_k_block +
257 k0 / BScaleStruct::num_scale_krepeat>{}],
267 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
268 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
270 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
271 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
272 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
273 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
282 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
283 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
284 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
285 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
295 using wmma_input_type_a =
296 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
297 using wmma_input_type_b =
298 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
301 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
303 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
304 b_thread_vec.template AsType<wmma_input_type_b>(),
313 if constexpr(HasMainLoop)
318 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
319 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
321 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
322 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
325 blockwise_gemm_func();
328 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
333 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
334 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
336 constexpr
index_t num_ds_write_inst =
337 HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
339 constexpr
index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
340 HotLoopInstList::B_Buffer_Load_Inst_Num;
342 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
346 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
347 if constexpr(m0 == I0)
350 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
355 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
361 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
365 }
while(i < (num_loop - 1));
372 blockwise_gemm_func();
376 template <
bool HasMainLoop,
380 typename ABlockTransfer,
381 typename AGridBuffer,
382 typename ABlockBuffer,
383 typename ABlockTransferStep,
386 typename BBlockTransfer,
387 typename BGridBuffer,
388 typename BBlockBuffer,
389 typename BBlockTransferStep,
390 typename CThreadBuffer,
391 typename AScaleStruct,
392 typename BScaleStruct,
394 !ck::is_same_v<BScaleStruct, Empty>,
396 __device__
void Run(
const AGridDesc& a_grid_desc,
397 const ABlockDesc& a_block_desc,
398 ABlockTransfer& a_blockwise_copy,
399 const AGridBuffer& a_grid_buf,
400 ABlockBuffer& a_block_buf,
401 const ABlockTransferStep& a_block_copy_step,
402 const BGridDesc& b_grid_desc,
403 const BBlockDesc& b_block_desc,
404 BBlockTransfer& b_blockwise_copy,
405 const BGridBuffer& b_grid_buf,
406 BBlockBuffer& b_block_buf,
407 const BBlockTransferStep& b_block_copy_step,
408 CThreadBuffer& c_thread_buf,
409 AScaleStruct& a_scale_struct,
410 BScaleStruct& b_scale_struct,
412 index_t num_loop_per_scale)
const
414 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
415 static constexpr
auto NumScaleKBlock =
418 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
419 Base::a_thread_desc_.GetElementSpaceSize());
420 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
421 Base::b_thread_desc_.GetElementSpaceSize());
423 using CScaleStruct =
typename Base::template CScale<AScaleStruct, BScaleStruct>;
424 auto c_scale_struct = CScaleStruct{};
427 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
428 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
430 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
431 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
434 a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
435 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
437 c_scale_struct.Load(a_scale_struct, b_scale_struct);
440 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
441 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
444 c_thread_buf.Clear();
446 auto blockwise_gemm_func = [&]() {
450 Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
453 Base::a_thread_desc_,
458 Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
461 Base::b_thread_desc_,
470 c_scale_struct.Clear();
471 static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](
auto k0) {
472 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
473 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
476 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
477 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
479 kscale0 * (KRepeat / NumScaleKBlock) + k0;
480 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
481 a_thread_buf[
Number<Base::a_thread_desc_.CalculateOffset(
490 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
491 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
493 kscale0 * (KRepeat / NumScaleKBlock) + k0;
494 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
495 b_thread_buf[
Number<Base::b_thread_desc_.CalculateOffset(
505 using wmma_input_type_a =
506 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
507 using wmma_input_type_b =
508 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
511 a_thread_vec.template AsType<wmma_input_type_a>(),
512 b_thread_vec.template AsType<wmma_input_type_b>(),
513 c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
517 c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
524 if constexpr(HasMainLoop)
529 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
530 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
532 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
533 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
535 a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
536 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
539 blockwise_gemm_func();
542 c_scale_struct.Load(a_scale_struct, b_scale_struct);
544 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
545 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
548 }
while(i < (num_loop - 1));
555 blockwise_gemm_func();
571 decltype(a_block_desc_k0_m0_m1_m2_k1),
572 decltype(a_thread_desc_),
573 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
582 decltype(b_block_desc_k0_n0_n1_n2_k1),
583 decltype(b_thread_desc_),
584 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
590 AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
591 BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
592 using Base::c_thread_desc_;
598 typename ComputeTypeA,
599 typename ComputeTypeB,
600 typename AccDataType,
601 typename AWmmaTileDesc,
602 typename BWmmaTileDesc,
603 index_t ABlockTransferSrcScalarPerVector,
604 index_t BBlockTransferSrcScalarPerVector,
624 ABlockTransferSrcScalarPerVector,
625 BBlockTransferSrcScalarPerVector,
645 ABlockTransferSrcScalarPerVector,
646 BBlockTransferSrcScalarPerVector,
670 ABlockTransferSrcScalarPerVector,
671 BBlockTransferSrcScalarPerVector,
692 using Base::wmma_gemm;
694 using Base::CalculateCThreadOriginDataIndex;
696 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
697 using Base::GetCThreadBuffer;
699 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
701 using Base::a_block_desc_k0_m0_m1_m2_k1;
702 using Base::b_block_desc_k0_n0_n1_n2_k1;
704 using typename Base::Empty;
715 return num_loop > PrefetchStages;
724 template <
typename AScaleStruct,
typename BScaleStruct>
727 static constexpr
auto KRepeatNoScale = 1;
728 static constexpr
auto NumScaleKBlock =
730 static constexpr
auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
736 static constexpr
index_t KRepeatNoScale = KRepeatPerCluster;
738 static constexpr
index_t KRepeatPerNumScaleKBlock = 1;
741 template <
bool HasMainLoop,
745 typename ABlockTransfer,
746 typename AGridBuffer,
747 typename ABlockBuffer,
748 typename ABlockTransferStep,
751 typename BBlockTransfer,
752 typename BGridBuffer,
753 typename BBlockBuffer,
754 typename BBlockTransferStep,
755 typename CThreadBuffer,
756 typename AScaleStruct,
757 typename BScaleStruct,
759 __device__
void Run(
const AGridDesc& a_grid_desc,
760 const ABlockDesc& a_block_desc,
761 ABlockTransfer& a_blockwise_copy,
762 const AGridBuffer& a_grid_buf,
763 ABlockBuffer& a_block_buf,
764 const ABlockTransferStep& a_block_copy_step,
765 const BGridDesc& b_grid_desc,
766 const BBlockDesc& b_block_desc,
767 BBlockTransfer& b_blockwise_copy,
768 const BGridBuffer& b_grid_buf,
769 BBlockBuffer& b_block_buf,
770 const BBlockTransferStep& b_block_copy_step,
771 CThreadBuffer& c_thread_buf,
773 BScaleStruct& b_scale_struct,
775 index_t num_loop_per_scale)
const
777 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
779 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
780 a_thread_desc_.GetElementSpaceSize());
781 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
782 b_thread_desc_.GetElementSpaceSize());
785 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
786 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
788 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
789 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
792 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
795 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
796 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
799 c_thread_buf.Clear();
801 auto blockwise_gemm_func = [&]() {
805 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
806 make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
816 b_block_desc_k0_n0_n1_n2_k1,
817 make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
828 b_block_desc_k0_n0_n1_n2_k1,
829 make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
831 b_scale_struct.scale_thread_bufs(I0)[
Number<
832 n0 * BScaleStruct::num_scale_k_block +
833 (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
841 __builtin_amdgcn_sched_barrier(0);
848 if constexpr(k0_offset != 0 || KRepeat == 1)
850 __builtin_amdgcn_s_barrier();
851 __builtin_amdgcn_sched_barrier(0);
857 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
858 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
860 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
861 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
862 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
863 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
872 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
873 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
874 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
875 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
885 using wmma_input_type_a =
886 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
887 using wmma_input_type_b =
888 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
891 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
899 if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
900 m0 == MRepeat - 1 && n0 == NRepeat - 1)
902 __builtin_amdgcn_sched_barrier(0);
904 __builtin_amdgcn_sched_barrier(0);
907 a_thread_vec.template AsType<wmma_input_type_a>(),
908 b_thread_vec.template AsType<wmma_input_type_b>(),
910 if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
912 __builtin_amdgcn_sched_barrier(0);
913 __builtin_amdgcn_s_setprio(1);
914 __builtin_amdgcn_sched_barrier(0);
921 __builtin_amdgcn_sched_barrier(0);
922 __builtin_amdgcn_s_setprio(0);
923 __builtin_amdgcn_sched_barrier(0);
928 if constexpr(HasMainLoop)
933 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
934 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
936 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
937 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
940 blockwise_gemm_func();
942 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
947 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
948 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
951 }
while(i < (num_loop - 1));
958 blockwise_gemm_func();
963 static constexpr
auto a_thread_desc_ =
966 Number<KRepeatPerCluster>{},
973 Number<KPack / A_KRow * MRepeat>{},
979 static constexpr
auto b_thread_desc_ =
982 Number<KRepeatPerCluster>{},
989 Number<KPack / B_KRow * NRepeat>{},
998 decltype(a_block_desc_k0_m0_m1_m2_k1),
999 decltype(a_thread_desc_),
1000 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
1009 decltype(b_block_desc_k0_n0_n1_n2_k1),
1010 decltype(b_thread_desc_),
1011 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
1019 using Base::c_thread_desc_;
1025 typename ComputeTypeA,
1026 typename ComputeTypeB,
1027 typename AccDataType,
1028 typename AWmmaTileDesc,
1029 typename BWmmaTileDesc,
1030 index_t ABlockTransferSrcScalarPerVector,
1031 index_t BBlockTransferSrcScalarPerVector,
1051 ABlockTransferSrcScalarPerVector,
1052 BBlockTransferSrcScalarPerVector,
1072 ABlockTransferSrcScalarPerVector,
1073 BBlockTransferSrcScalarPerVector,
1097 ABlockTransferSrcScalarPerVector,
1098 BBlockTransferSrcScalarPerVector,
1112 using Base::WaveSize;
1119 using Base::KRepeat;
1122 using Base::wmma_gemm;
1124 using Base::CalculateCThreadOriginDataIndex;
1126 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
1127 using Base::GetCThreadBuffer;
1129 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
1131 using Base::a_block_desc_k0_m0_m1_m2_k1;
1132 using Base::b_block_desc_k0_n0_n1_n2_k1;
1134 using typename Base::Empty;
1149 constexpr
auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
1150 constexpr
auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
1151 constexpr
auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
1152 constexpr
auto wmma_interleave = 2;
1156 if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
1158 __builtin_amdgcn_sched_group_barrier(0x008, 2 * wmma_interleave, 0);
1162 __builtin_amdgcn_sched_group_barrier(0x008, wmma_interleave, 0);
1164 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
1170 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1171 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
1172 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1173 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
1179 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1180 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
1184 template <
bool HasMainLoop,
1187 typename ABlockDesc,
1188 typename ABlockTransfer,
1189 typename AGridBuffer,
1190 typename ABlockBuffer,
1191 typename ABlockTransferStep,
1193 typename BBlockDesc,
1194 typename BBlockTransfer,
1195 typename BGridBuffer,
1196 typename BBlockBuffer,
1197 typename BBlockTransferStep,
1198 typename CThreadBuffer,
1199 typename AScaleStruct,
1200 typename BScaleStruct,
1202 __device__
void Run(
const AGridDesc& a_grid_desc,
1203 const ABlockDesc& a_block_desc,
1204 ABlockTransfer& a_blockwise_copy,
1205 const AGridBuffer& a_grid_buf,
1206 ABlockBuffer& a_block_buf,
1207 const ABlockTransferStep& a_block_copy_step,
1208 const BGridDesc& b_grid_desc,
1210 BBlockTransfer& b_blockwise_copy,
1211 const BGridBuffer& b_grid_buf,
1213 const BBlockTransferStep& b_block_copy_step,
1214 CThreadBuffer& c_thread_buf,
1220 __builtin_amdgcn_sched_barrier(0);
1221 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
1223 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
1224 a_thread_desc_.GetElementSpaceSize());
1225 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
1226 b_thread_desc_.GetElementSpaceSize());
1229 constexpr
auto b_block_origin_idx =
make_tuple(I0, I0, I0, I0, I0, I0, I0);
1232 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
1233 b_blockwise_copy.Run(b_grid_desc,
1235 b_block_desc_k0_n0_n1_n2_k1,
1239 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1240 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1241 __builtin_amdgcn_sched_barrier(0);
1244 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1247 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
1248 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1254 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1264 c_thread_buf.Clear();
1266 __builtin_amdgcn_sched_barrier(0);
1269 if constexpr(HasMainLoop)
1274 auto LoopFunc = [&](
auto wmma_reg_buf,
auto local_read_buf) {
1275 b_blockwise_copy.Run(b_grid_desc,
1277 b_block_desc_k0_n0_n1_n2_k1,
1279 b_thread_bufs(local_read_buf));
1281 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1285 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
1287 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
1288 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1293 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1294 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1296 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
1297 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1298 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1299 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
1308 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
1309 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1310 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1311 b_thread_bufs[wmma_reg_buf]
1312 [
Number<b_thread_desc_.CalculateOffset(
1321 using wmma_input_type_a =
1322 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1323 using wmma_input_type_b =
1324 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1327 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
1330 a_thread_vec.template AsType<wmma_input_type_a>(),
1331 b_thread_vec.template AsType<wmma_input_type_b>(),
1343 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1353 __builtin_amdgcn_sched_barrier(0);
1360 }
while(i < (num_loop - 2));
1366 b_blockwise_copy.Run(b_grid_desc,
1368 b_block_desc_k0_n0_n1_n2_k1,
1374 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1379 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1380 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1382 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
1383 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1384 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1385 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
1394 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
1395 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1396 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1397 b_thread_bufs[I0][
Number<b_thread_desc_.CalculateOffset(
1407 using wmma_input_type_a =
1408 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1409 using wmma_input_type_b =
1410 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1413 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
1415 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1416 b_thread_vec.template AsType<wmma_input_type_b>(),
1428 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1437 __builtin_amdgcn_sched_barrier(0);
1442 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1443 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1445 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
1446 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1447 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1448 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
1457 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
1458 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1459 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1460 b_thread_bufs[I1][
Number<b_thread_desc_.CalculateOffset(
1469 using wmma_input_type_a =
1470 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1471 using wmma_input_type_b =
1472 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1475 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
1477 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1478 b_thread_vec.template AsType<wmma_input_type_b>(),
1493 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1494 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1496 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
1497 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1498 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1499 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
1508 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
1509 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1510 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1511 b_thread_bufs[I0][
Number<b_thread_desc_.CalculateOffset(
1520 using wmma_input_type_a =
1521 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1522 using wmma_input_type_b =
1523 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1526 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
1528 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1529 b_thread_vec.template AsType<wmma_input_type_b>(),
1538 template <
bool HasMainLoop,
1541 typename ABlockDesc,
1542 typename ABlockTransfer,
1543 typename AGridBuffer,
1544 typename ABlockBuffer,
1545 typename ABlockTransferStep,
1547 typename BBlockDesc,
1548 typename BBlockTransfer,
1549 typename BGridBuffer,
1550 typename BBlockBuffer,
1551 typename BBlockTransferStep,
1552 typename CThreadBuffer,
1553 typename AScaleStruct,
1554 typename BScaleStruct,
1556 !ck::is_same_v<BScaleStruct, Empty>,
1557 bool>::type =
false>
1558 __device__
void Run(
const AGridDesc& a_grid_desc,
1559 const ABlockDesc& a_block_desc,
1560 ABlockTransfer& a_blockwise_copy,
1561 const AGridBuffer& a_grid_buf,
1562 ABlockBuffer& a_block_buf,
1563 const ABlockTransferStep& a_block_copy_step,
1564 const BGridDesc& b_grid_desc,
1566 BBlockTransfer& b_blockwise_copy,
1567 const BGridBuffer& b_grid_buf,
1569 const BBlockTransferStep& b_block_copy_step,
1570 CThreadBuffer& c_thread_buf,
1571 AScaleStruct& a_scale_struct,
1572 BScaleStruct& b_scale_struct,
1574 index_t num_loop_per_scale)
const
1576 __builtin_amdgcn_sched_barrier(0);
1577 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
1578 static constexpr
auto NumScaleKBlock =
1581 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
1582 a_thread_desc_.GetElementSpaceSize());
1583 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
1584 b_thread_desc_.GetElementSpaceSize());
1587 constexpr
auto b_block_origin_idx =
make_tuple(I0, I0, I0, I0, I0, I0, I0);
1589 using CScaleStruct =
typename Base::template CScale<AScaleStruct, BScaleStruct>;
1590 auto c_scale_struct = CScaleStruct{};
1592 auto gemm_core_func = [&](
auto reg_buf) {
1596 c_scale_struct.Clear();
1597 static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](
auto k0) {
1598 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1599 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1601 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
1602 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1604 kscale0 * (KRepeat / NumScaleKBlock) + k0;
1605 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1606 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
1615 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
1616 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
1618 kscale0 * (KRepeat / NumScaleKBlock) + k0;
1619 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1620 b_thread_bufs[reg_buf]
1621 [
Number<b_thread_desc_.CalculateOffset(
1630 using wmma_input_type_a =
1631 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1632 using wmma_input_type_b =
1633 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1635 a_thread_vec.template AsType<wmma_input_type_a>(),
1636 b_thread_vec.template AsType<wmma_input_type_b>(),
1637 c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
1641 c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
1647 auto a_local_prefetch_func = [&]() {
1650 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1661 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
1662 b_blockwise_copy.Run(b_grid_desc,
1664 b_block_desc_k0_n0_n1_n2_k1,
1668 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1669 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1672 a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
1673 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
1675 __builtin_amdgcn_sched_barrier(0);
1677 c_scale_struct.Load(a_scale_struct, b_scale_struct);
1680 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1683 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
1684 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1688 a_local_prefetch_func();
1691 c_thread_buf.Clear();
1693 __builtin_amdgcn_sched_barrier(0);
1696 if constexpr(HasMainLoop)
1701 auto LoopFunc = [&](
auto wmma_reg_buf,
auto local_read_buf) {
1702 b_blockwise_copy.Run(b_grid_desc,
1704 b_block_desc_k0_n0_n1_n2_k1,
1706 b_thread_bufs(local_read_buf));
1708 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1712 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
1714 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
1715 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1717 a_scale_struct.template GlobalLoad<0>(
1718 (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
1719 b_scale_struct.template GlobalLoad<0>(
1720 (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
1722 gemm_core_func(wmma_reg_buf);
1727 a_local_prefetch_func();
1729 c_scale_struct.Load(a_scale_struct, b_scale_struct);
1732 __builtin_amdgcn_sched_barrier(0);
1739 }
while(i < (num_loop - 2));
1745 b_blockwise_copy.Run(b_grid_desc,
1747 b_block_desc_k0_n0_n1_n2_k1,
1753 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1755 a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
1756 b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
1763 a_local_prefetch_func();
1765 c_scale_struct.Load(a_scale_struct, b_scale_struct);
1767 __builtin_amdgcn_sched_barrier(0);
1781 static constexpr
auto b_thread_desc_ =
1790 using Base::a_thread_copy_;
1791 using Base::a_thread_desc_;
1792 using Base::c_thread_desc_;
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:211
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
@ Full
Full tail iterations.
@ Empty
No tail iterations.
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
@ Interwave
Schedule across multiple wavefronts.
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:212
int32_t index_t
Definition: ck.hpp:301
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:759
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::BlockHasHotloop __host__ static __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:713
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:718
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:159
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:183
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::BlockHasHotloop static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:154
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, false >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &a_scale_struct, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:396
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, true >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &, BScaleStruct &, index_t num_loop, index_t) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1202
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, true >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1147
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, true >::BlockHasHotloop static bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1140
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, true >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1142
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC, true >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, AScaleStruct &a_scale_struct, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:1558
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:35
Definition: sequence.hpp:43
ck::ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence< KPack/A_K1/A_KRow, 1, 1, 1, 1, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5, 6 >, 6, A_K1, A_K1 >
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:11