/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
BlockGemmPipelineVersion
Block GEMM pipeline version enumeration.
Definition: scheduler_enum.hpp:17
TailNumber
Tail number enumeration for pipeline buffering.
Definition: scheduler_enum.hpp:49
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
Definition: block_to_ctile_map.hpp:271
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: block_to_ctile_map.hpp:298
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition: block_to_ctile_map.hpp:384
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:415
BsGridPointer p_bs_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:479
__host__ Argument()=default
BElementwiseOperation b_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:484
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:488
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:468
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:473
AsGridPointer p_as_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:478
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:481
AElementwiseOperation a_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:483
CDEElementwiseOperation cde_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:485
DsGridPointer p_ds_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:480
__host__ __device__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:417
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:338
std::array< index_t, NumBTensor > StrideBs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:398
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:401
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:407
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:396
index_t Kt
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:410
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:394
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:368
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:395
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:399
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:402
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:408
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:406
__host__ __device__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:340
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:404
index_t StrideE
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:400
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:409
__host__ Problem()=default
std::array< index_t, NumATensor > StrideAs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:397
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:405
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:403
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:492
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:561
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:494
std::array< index_t, NumATensor > a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:559
std::array< index_t, NumBTensor > b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:560
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:297
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:592
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:300
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:848
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:327
__host__ static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:971
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:310
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:591
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:773
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
__host__ static __device__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:559
__host__ static __device__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:540
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:326
__host__ static __device__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:552
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:354
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:302
__host__ static __device__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:529
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K &base_descs, const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:620
__host__ static __device__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:524
__host__ static __device__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:519
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, false, IsBPreShuffled >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:896
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:309
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:303
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:308
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:347
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:846
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:305
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:328
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:301
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:325
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:304
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:306
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K &base_descs, const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:689
__host__ static __device__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:564
__host__ static __device__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:534
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:861
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:237
decltype(MakeBsGridPointer()) BsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:592
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:743
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:729
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:300
BsDataType BsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:335
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:848
__device__ static __host__ auto DefaultBlock2CTileMap(const Problem &problem)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:837
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:564
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, const Block2CTileMap &block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:580
__host__ static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:971
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:310
decltype(MakeAsGridPointer()) AsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:591
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:326
BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:568
AsDataType AsDataType_
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:334
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
static __device__ void Run(void *p_shared, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMapExt &block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfN compute_ptr_offset_of_n, const index_t num_k_per_block, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:860
static __device__ void Run(void *p_shared, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch &compute_ptr_offset_of_batch, const ComputePtrOffsetOfN &compute_ptr_offset_of_n, [[maybe_unused]] const index_t num_k_per_block, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1119
static __device__ void Run(AsGridPointer &p_as_grid, BsGridPointer &p_bs_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const AsGridDescriptor_AK0_M_AK1 as_grid_desc_ak0_m_ak1, const BsGridDescriptor_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap &block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:651
__host__ static __device__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:559
__host__ static __device__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:540
__host__ static __device__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:552
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:354
__host__ static __device__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:529
__host__ static __device__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:524
__host__ static __device__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:519
__device__ static __host__ auto DefaultBlock2CTileMap(const index_t M, const index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:841
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, const Block2CTileMap &block_2_ctile_map, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:766
static constexpr index_t NumBTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:309
static constexpr index_t NumATensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:308
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:347
static __device__ void Run(void *p_shared, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block, Argument &karg, EpilogueArgument &epilogue_args)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:998
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:846
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:571
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:301
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg, EpilogueArgument &epilogue_args, const index_t A_k_id=0, const index_t B_k_id=0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:817
__host__ static __device__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:564
__host__ static __device__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:534
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:674
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:861
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: device_base.hpp:270