/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File#
fmha_fwd_kernel.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:146
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:36
constexpr CK_TILE_HOST_DEVICE auto make_composes(Ts &&... ts)
Definition: unary_element_function.hpp:51
@ ELEMENTWISE_BIAS
@ SYSTEM_NT1
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1634
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1698
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1685
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_xor_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1738
@ MASK_FROM_TOP_LEFT
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:486
@ FROM_BOTTOM_RIGHT
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_dropout.hpp:53
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_kernel.hpp:321
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_kernel.hpp:324
ck_tile::index_t batch_idx
Definition: fmha_fwd_kernel.hpp:322
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_kernel.hpp:323
Definition: fmha_fwd_kernel.hpp:151
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_kernel.hpp:154
const void * alibi_slope_ptr
Definition: fmha_fwd_kernel.hpp:153
Definition: fmha_fwd_kernel.hpp:182
ck_tile::index_t batch_stride_q_descale
Definition: fmha_fwd_kernel.hpp:183
ck_tile::index_t batch_stride_v_descale
Definition: fmha_fwd_kernel.hpp:185
ck_tile::index_t batch_stride_k_descale
Definition: fmha_fwd_kernel.hpp:184
Definition: fmha_fwd_kernel.hpp:146
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_kernel.hpp:147
Definition: fmha_fwd_kernel.hpp:251
ck_tile::index_t batch_stride_randval
Definition: fmha_fwd_kernel.hpp:252
Definition: fmha_fwd_kernel.hpp:277
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_kernel.hpp:281
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:286
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_kernel.hpp:278
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_kernel.hpp:279
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:285
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_kernel.hpp:280
Definition: fmha_fwd_kernel.hpp:139
const void * bias_ptr
Definition: fmha_fwd_kernel.hpp:140
ck_tile::index_t stride_bias
Definition: fmha_fwd_kernel.hpp:141
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_kernel.hpp:142
Definition: fmha_fwd_kernel.hpp:172
ck_tile::index_t nhead_stride_v_descale
Definition: fmha_fwd_kernel.hpp:175
ck_tile::index_t block_scale_size_q
Definition: fmha_fwd_kernel.hpp:177
ck_tile::index_t nhead_stride_k_descale
Definition: fmha_fwd_kernel.hpp:174
ck_tile::index_t block_scale_size_kv
Definition: fmha_fwd_kernel.hpp:178
ck_tile::index_t nhead_stride_q_descale
Definition: fmha_fwd_kernel.hpp:173
Definition: fmha_fwd_kernel.hpp:216
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_fwd_kernel.hpp:229
float rp_undrop
Definition: fmha_fwd_kernel.hpp:241
ck_tile::index_t stride_randval
Definition: fmha_fwd_kernel.hpp:246
ck_tile::index_t nhead_stride_randval
Definition: fmha_fwd_kernel.hpp:247
void * rand_val_ptr
Definition: fmha_fwd_kernel.hpp:244
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_fwd_kernel.hpp:217
bool is_store_randval
Definition: fmha_fwd_kernel.hpp:243
uint8_t p_undrop_in_uint8_t
Definition: fmha_fwd_kernel.hpp:242
Definition: fmha_fwd_kernel.hpp:87
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_kernel.hpp:111
ck_tile::index_t seqlen_k
Definition: fmha_fwd_kernel.hpp:95
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_kernel.hpp:113
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_kernel.hpp:102
ck_tile::index_t num_head_q
Definition: fmha_fwd_kernel.hpp:99
ck_tile::index_t hdim_q
Definition: fmha_fwd_kernel.hpp:96
const void * v_ptr
Definition: fmha_fwd_kernel.hpp:90
const void * k_ptr
Definition: fmha_fwd_kernel.hpp:89
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_kernel.hpp:110
ck_tile::index_t stride_k
Definition: fmha_fwd_kernel.hpp:106
ck_tile::index_t stride_o
Definition: fmha_fwd_kernel.hpp:108
const void * sink_ptr
Definition: fmha_fwd_kernel.hpp:92
ck_tile::index_t stride_v
Definition: fmha_fwd_kernel.hpp:107
ck_tile::index_t hdim_v
Definition: fmha_fwd_kernel.hpp:97
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_kernel.hpp:112
const void * q_ptr
Definition: fmha_fwd_kernel.hpp:88
ck_tile::index_t seqlen_q
Definition: fmha_fwd_kernel.hpp:94
ck_tile::index_t stride_q
Definition: fmha_fwd_kernel.hpp:105
Definition: fmha_fwd_kernel.hpp:195
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_kernel.hpp:198
void * lse_ptr
Definition: fmha_fwd_kernel.hpp:196
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_kernel.hpp:197
Definition: fmha_fwd_kernel.hpp:165
const void * v_descale_ptr
Definition: fmha_fwd_kernel.hpp:168
const void * k_descale_ptr
Definition: fmha_fwd_kernel.hpp:167
const void * q_descale_ptr
Definition: fmha_fwd_kernel.hpp:166
Definition: fmha_fwd_kernel.hpp:202
bool is_drop_seed_offset_from_host
Definition: fmha_fwd_kernel.hpp:212
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_fwd_kernel.hpp:210
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_fwd_kernel.hpp:211
Definition: fmha_fwd_kernel.hpp:80
Definition: fmha_fwd_kernel.hpp:189
const int32_t * block_scale_seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:190
const int32_t * block_scale_seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:191
Definition: fmha_fwd_kernel.hpp:307
const int32_t * seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:310
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:308
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:311
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:315
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:314
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:309
Definition: fmha_fwd_kernel.hpp:117
float logits_soft_cap
Definition: fmha_fwd_kernel.hpp:134
FmhaFwdLogitsSoftCapKargs()=default
float logits_soft_cap_rcp
Definition: fmha_fwd_kernel.hpp:135
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_kernel.hpp:120
Definition: fmha_fwd_kernel.hpp:158
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_kernel.hpp:161
ck_tile::index_t window_size_right
Definition: fmha_fwd_kernel.hpp:160
ck_tile::index_t sink_size
Definition: fmha_fwd_kernel.hpp:160
ck_tile::index_t window_size_left
Definition: fmha_fwd_kernel.hpp:160
Definition: fmha_fwd_kernel.hpp:256
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_kernel.hpp:257
Definition: fmha_fwd_kernel.hpp:28
static constexpr bool kHasDropout
Definition: fmha_fwd_kernel.hpp:58
static constexpr bool kIsAvailable
Definition: fmha_fwd_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_kernel.hpp:38
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *block_scale_seqstart_q_ptr, const void *block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:1023
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_kernel.hpp:318
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_kernel.hpp:33
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_kernel.hpp:48
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_kernel.hpp:41
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition: fmha_fwd_kernel.hpp:1129
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_kernel.hpp:39
static constexpr auto QScaleEnum
Definition: fmha_fwd_kernel.hpp:59
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_kernel.hpp:35
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *block_scale_seqstart_q_ptr, const void *block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:914
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_q_descale, ck_tile::index_t batch_stride_k_descale, ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:329
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_kernel.hpp:54
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_q_descale, ck_tile::index_t batch_stride_k_descale, ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:623
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_kernel.hpp:1154
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_kernel.hpp:60
static constexpr std::string_view kPipelineName
Definition: fmha_fwd_kernel.hpp:75
ck_tile::remove_cvref_t< typename FmhaPipeline::PDataType > PDataType
Definition: fmha_fwd_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_kernel.hpp:37
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_q_descale, ck_tile::index_t batch_stride_k_descale, ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:504
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_kernel.hpp:1219
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1242
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_kernel.hpp:63
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_kernel.hpp:1231
static constexpr bool kUseTrLoad
Definition: fmha_fwd_kernel.hpp:69
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_kernel.hpp:67
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_kernel.hpp:29
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *block_scale_seqstart_q_ptr, const void *block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_q_descale, ck_tile::index_t nhead_stride_k_descale, ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr, const void *sink_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:741
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_kernel.hpp:64
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_fwd_kernel.hpp:43
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_kernel.hpp:30
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_kernel.hpp:52
static constexpr bool kIsGroupMode
Definition: fmha_fwd_kernel.hpp:50
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1236
Definition: variants.hpp:63
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:114
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1393
Definition: unary_element_function.hpp:58
Definition: math.hpp:28
Definition: sequence.hpp:49
T val
Definition: fmha_batch_prefill_kernel.hpp:200
const T * ptr
Definition: fmha_batch_prefill_kernel.hpp:201
Definition: fmha_fwd_kernel.hpp:205
T val
Definition: fmha_fwd_kernel.hpp:206
const T * ptr
Definition: fmha_fwd_kernel.hpp:207