/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File#
fmha_fwd_splitkv_kernel.hpp
Go to the documentation of this file.
97 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
98 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
99 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
100 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
101 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
102 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
103 (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
105 (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
@ ELEMENTWISE_BIAS
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1691
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
@ MASK_FROM_TOP_LEFT
@ FROM_BOTTOM_RIGHT
scales(Scale) -> scales< Scale >
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_attention_bias_enum.hpp:19
Definition: fmha_fwd_splitkv_kernel.hpp:194
const void * alibi_slope_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:196
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_splitkv_kernel.hpp:197
Definition: fmha_fwd_splitkv_kernel.hpp:189
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:190
Definition: fmha_fwd_splitkv_kernel.hpp:240
ck_tile::index_t batch_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:248
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:246
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:243
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:244
ck_tile::index_t batch_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:249
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:241
Definition: fmha_fwd_splitkv_kernel.hpp:277
ck_tile::index_t batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:278
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:280
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:279
Definition: fmha_fwd_splitkv_kernel.hpp:225
const int32_t * cache_batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:226
Definition: fmha_fwd_splitkv_kernel.hpp:182
const void * bias_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:183
ck_tile::index_t stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:184
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:185
Definition: fmha_fwd_splitkv_kernel.hpp:121
ck_tile::index_t split_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:156
ck_tile::index_t nhead_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:153
const void * k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:123
ck_tile::index_t num_splits
Definition: fmha_fwd_splitkv_kernel.hpp:140
void * lse_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:125
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:149
void * o_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:126
ck_tile::index_t nhead_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:152
const void * sink_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:127
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:151
ck_tile::index_t hdim_q
Definition: fmha_fwd_splitkv_kernel.hpp:133
const void * v_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:124
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:150
ck_tile::index_t split_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:155
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_splitkv_kernel.hpp:139
ck_tile::index_t stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:144
ck_tile::index_t seqlen_k
Definition: fmha_fwd_splitkv_kernel.hpp:132
ck_tile::index_t batch
Definition: fmha_fwd_splitkv_kernel.hpp:129
const void * q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:122
ck_tile::index_t stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:146
ck_tile::index_t stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:145
ck_tile::index_t num_head_q
Definition: fmha_fwd_splitkv_kernel.hpp:136
ck_tile::index_t seqlen_q
Definition: fmha_fwd_splitkv_kernel.hpp:131
ck_tile::index_t stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:147
float scale_s
Definition: fmha_fwd_splitkv_kernel.hpp:142
ck_tile::index_t hdim_v
Definition: fmha_fwd_splitkv_kernel.hpp:134
Definition: fmha_fwd_splitkv_kernel.hpp:213
ck_tile::index_t page_block_size
Definition: fmha_fwd_splitkv_kernel.hpp:216
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_splitkv_kernel.hpp:215
const int32_t * block_table_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:214
Definition: fmha_fwd_splitkv_kernel.hpp:114
Definition: fmha_fwd_splitkv_kernel.hpp:208
float scale_p
Definition: fmha_fwd_splitkv_kernel.hpp:209
Definition: fmha_fwd_splitkv_kernel.hpp:263
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:266
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:268
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:264
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:265
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:270
Definition: fmha_fwd_splitkv_kernel.hpp:220
bool is_gappy
Definition: fmha_fwd_splitkv_kernel.hpp:221
Definition: fmha_fwd_splitkv_kernel.hpp:160
float logits_soft_cap_rcp
Definition: fmha_fwd_splitkv_kernel.hpp:178
float logits_soft_cap
Definition: fmha_fwd_splitkv_kernel.hpp:177
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_splitkv_kernel.hpp:163
LogitsSoftCapKargs()=default
Definition: fmha_fwd_splitkv_kernel.hpp:201
ck_tile::index_t sink_size
Definition: fmha_fwd_splitkv_kernel.hpp:203
ck_tile::index_t window_size_right
Definition: fmha_fwd_splitkv_kernel.hpp:203
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_splitkv_kernel.hpp:204
ck_tile::index_t window_size_left
Definition: fmha_fwd_splitkv_kernel.hpp:203
Definition: fmha_fwd_splitkv_kernel.hpp:66
Definition: fmha_fwd_splitkv_kernel.hpp:24
static constexpr auto BiasEnum
Definition: fmha_fwd_splitkv_kernel.hpp:50
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_splitkv_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_splitkv_kernel.hpp:36
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_splitkv_kernel.hpp:598
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_splitkv_kernel.hpp:26
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_splitkv_kernel.hpp:42
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_splitkv_kernel.hpp:47
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition: fmha_fwd_splitkv_kernel.hpp:274
static constexpr bool kHasSink
Definition: fmha_fwd_splitkv_kernel.hpp:54
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
Definition: fmha_fwd_splitkv_kernel.hpp:541
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_splitkv_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_splitkv_kernel.hpp:34
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_splitkv_kernel.hpp:40
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_splitkv_kernel.hpp:58
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_splitkv_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_splitkv_kernel.hpp:74
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:38
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_splitkv_kernel.hpp:586
static constexpr bool kHasMask
Definition: fmha_fwd_splitkv_kernel.hpp:59
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_splitkv_kernel.hpp:559
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_splitkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_splitkv_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_splitkv_kernel.hpp:31
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, const void *sink_ptr=nullptr)
Definition: fmha_fwd_splitkv_kernel.hpp:419
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_splitkv_kernel.hpp:37
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:39
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, const void *sink_ptr=nullptr)
Definition: fmha_fwd_splitkv_kernel.hpp:285
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_splitkv_kernel.hpp:57
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_splitkv_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_splitkv_kernel.hpp:33
static constexpr bool kStoreLSE
Definition: fmha_fwd_splitkv_kernel.hpp:51
static constexpr bool kIsPagedKV
Definition: fmha_fwd_splitkv_kernel.hpp:53
static constexpr bool kIsGroupMode
Definition: fmha_fwd_splitkv_kernel.hpp:44
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_splitkv_kernel.hpp:603
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_splitkv_kernel.hpp:49
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_splitkv_kernel.hpp:25
Definition: variants.hpp:63
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: numeric.hpp:18
Definition: sequence.hpp:49