/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp Source File
universal_gemm_kernel.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1691
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Scheduler for persistent GEMM kernels with asynchronous input streaming.
Definition: persistent_async_input_scheduler.hpp:27
uint32_t tiles_per_chunk_m
Number of M-dimension tiles grouped into each chunk. Grouping tiles balances synchronization overhead...
Definition: persistent_async_input_scheduler.hpp:31
int32_t tile_idx_pivot_m
Pivot offset for rotating the chunk assignment. Allows shifting which tiles map to which chunks,...
Definition: persistent_async_input_scheduler.hpp:41
uint32_t * chunk_signals
Device pointer to array of signal values (uint32_t), one per chunk. Producer sets signals to coordina...
Definition: persistent_async_input_scheduler.hpp:36
uint32_t num_chunks
Number of signal chunks allocated. Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)....
Definition: persistent_async_input_scheduler.hpp:46
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:34
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:78
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:77
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:67
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:65
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:76
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_, PersistentAsyncInputScheduler async_input_scheduler_=PersistentAsyncInputScheduler{})
Definition: universal_gemm_kernel.hpp:35
PersistentAsyncInputScheduler async_input_scheduler
Definition: universal_gemm_kernel.hpp:86
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:66
Definition: universal_gemm_kernel.hpp:336
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:401
index_t splitted_k
Definition: universal_gemm_kernel.hpp:403
__device__ SplitKBatchOffset(const KernelArgs &kargs, const index_t k_id=blockIdx.z)
Definition: universal_gemm_kernel.hpp:339
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:402
Definition: universal_gemm_kernel.hpp:214
static constexpr bool value
Definition: universal_gemm_kernel.hpp:218
decltype(T::UsePersistentKernel) has_persistent_type
Definition: universal_gemm_kernel.hpp:216
Definition: universal_gemm_kernel.hpp:229
decltype(T::GetOutputOffset(std::declval< KernelArgs >(), std::declval< index_t >())) has_get_output_offset_t
Definition: universal_gemm_kernel.hpp:232
static constexpr bool value
Definition: universal_gemm_kernel.hpp:234
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:92
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:100
std::array< index_t, NumBTensor > stride_Bs
The distance in memory between consecutive elements along the non-contiguous dimension of the Bs tensors.
Definition: universal_gemm_kernel.hpp:112
const std::array< const void *, NumDTensor > ds_ptr
Pointers to the Ds input tensors in device memory.
Definition: universal_gemm_kernel.hpp:98
std::array< index_t, NumATensor > stride_As
The distance in memory between consecutive elements along the non-contiguous dimension of the As tensors.
Definition: universal_gemm_kernel.hpp:109
const std::array< const void *, NumATensor > as_ptr
Pointers to the As input tensors in device memory.
Definition: universal_gemm_kernel.hpp:94
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:104
index_t stride_E
The distance in memory between consecutive elements along the non-contiguous dimension of the E tensor.
Definition: universal_gemm_kernel.hpp:118
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:106
PersistentAsyncInputScheduler async_input_scheduler
Persistent async input scheduler for chunk-based tile scheduling.
Definition: universal_gemm_kernel.hpp:121
const std::array< const void *, NumBTensor > bs_ptr
Pointers to the Bs input tensors in device memory.
Definition: universal_gemm_kernel.hpp:96
std::array< index_t, NumDTensor > stride_Ds
The distance in memory between consecutive elements along the non-contiguous dimension of the Ds tensors.
Definition: universal_gemm_kernel.hpp:115
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:102
The Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:162
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1158
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: universal_gemm_kernel.hpp:164
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:270
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: universal_gemm_kernel.hpp:163
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1197
static CK_TILE_DEVICE auto GetGridSize() -> index_t
Definition: universal_gemm_kernel.hpp:1130
static constexpr bool BDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:169
static CK_TILE_DEVICE auto MakeBBlockWindows(const std::array< const BDataType *, NumBTensor > &bs_ptr, const KernelArgs &kargs, const index_t k_size, const index_t i_n)
Definition: universal_gemm_kernel.hpp:770
static constexpr bool BLayoutIsTuple
Definition: universal_gemm_kernel.hpp:175
remove_cvref_t< typename GemmPipeline::BElementWise > BElementWise
Definition: universal_gemm_kernel.hpp:208
static constexpr index_t NumATensor
Definition: universal_gemm_kernel.hpp:249
static constexpr bool ALayoutIsTuple
Definition: universal_gemm_kernel.hpp:173
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition: universal_gemm_kernel.hpp:253
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::AsLayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > >> AsLayout
Definition: universal_gemm_kernel.hpp:182
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > >> DsDataType
Definition: universal_gemm_kernel.hpp:202
static constexpr bool ADataTypeIsTuple
Definition: universal_gemm_kernel.hpp:167
static constexpr bool has_tile_partitioner_output_offset
Definition: universal_gemm_kernel.hpp:241
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::AsDataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > >> AsDataType
Definition: universal_gemm_kernel.hpp:193
static constexpr index_t NumDTensor
Definition: universal_gemm_kernel.hpp:251
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: universal_gemm_kernel.hpp:1066
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition: universal_gemm_kernel.hpp:268
static constexpr bool DDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:171
static CK_TILE_DEVICE auto GetTileCoordinates(const KernelArgs &kargs) -> tuple< index_t, index_t >
Definition: universal_gemm_kernel.hpp:1107
static CK_TILE_DEVICE auto MakeCBlockWindows(EDataType *e_ptr, const KernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:998
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:225
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: universal_gemm_kernel.hpp:204
static CK_TILE_DEVICE auto MakeDBlockWindows(const std::array< const void *, NumDTensor > &ds_ptr, const KernelArgs &kargs, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:921
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:277
static CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:300
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition: universal_gemm_kernel.hpp:254
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Calculate grid size that maximizes hardware utilization for persistent kernels.
Definition: universal_gemm_kernel.hpp:288
static constexpr index_t NumBTensor
Definition: universal_gemm_kernel.hpp:250
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:406
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > >> DsLayout
Definition: universal_gemm_kernel.hpp:189
static CK_TILE_HOST_DEVICE auto GetNumTiles(Args &&... args) -> index_t
Definition: universal_gemm_kernel.hpp:1138
static constexpr bool DLayoutIsTuple
Definition: universal_gemm_kernel.hpp:177
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: universal_gemm_kernel.hpp:165
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BsDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > >> BsDataType
Definition: universal_gemm_kernel.hpp:197
static CK_TILE_DEVICE auto MakeABlockWindows(const std::array< const ADataType *, NumATensor > &as_ptr, const KernelArgs &kargs, const index_t k_size, const index_t i_m)
Definition: universal_gemm_kernel.hpp:693
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:330
remove_cvref_t< typename GemmPipeline::AElementWise > AElementWise
Definition: universal_gemm_kernel.hpp:207
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BsLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > >> BsLayout
Definition: universal_gemm_kernel.hpp:185
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:313
static CK_TILE_DEVICE auto GetBlockId() -> index_t
Definition: universal_gemm_kernel.hpp:1124
static constexpr index_t kBlockSize
Definition: universal_gemm_kernel.hpp:210
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: universal_gemm_kernel.hpp:205
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: stream_config.hpp:30
Definition: tuple.hpp:192
Definition: workgroup_barrier.hpp:12
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset=0)
Definition: workgroup_barrier.hpp:30