/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp Source File
device_moe_gemm_blockscale.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 #include <hip/hip_runtime.h>
9 
20 
21 namespace ck {
22 namespace tensor_operation {
23 namespace device {
24 
25 template <typename ALayout,
26  typename BLayout,
27  typename DsLayout,
28  typename CLayout,
29  typename ADataType,
30  typename AScaleDataType,
31  typename BDataType,
32  typename BScaleDataType,
33  typename DsDataType,
34  typename CDataType,
35  typename GemmAccDataType,
36  typename CShuffleDataType,
37  typename AElementwiseOperation,
38  typename BElementwiseOperation,
39  typename CElementwiseOperation,
40  GemmSpecialization GemmSpec,
41  index_t BlockSize,
42  index_t ScaleBlockM,
43  index_t ScaleBlockN,
44  index_t ScaleBlockK,
45  index_t MPerBlock,
46  index_t NPerBlock,
47  index_t KPerBlock,
48  index_t AK1,
49  index_t BK1,
50  index_t MPerXDL,
51  index_t NPerXDL,
52  index_t MXdlPerWave,
53  index_t NXdlPerWave,
54  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55  typename ABlockTransferThreadClusterArrangeOrder,
56  typename ABlockTransferSrcAccessOrder,
57  index_t ABlockTransferSrcVectorDim,
58  index_t ABlockTransferSrcScalarPerVector,
59  index_t ABlockTransferDstScalarPerVector_AK1,
60  bool ABlockLdsExtraM,
61  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62  typename BBlockTransferThreadClusterArrangeOrder,
63  typename BBlockTransferSrcAccessOrder,
64  index_t BBlockTransferSrcVectorDim,
65  index_t BBlockTransferSrcScalarPerVector,
66  index_t BBlockTransferDstScalarPerVector_BK1,
67  bool BBlockLdsExtraN,
68  index_t CShuffleMXdlPerWavePerShuffle,
69  index_t CShuffleNXdlPerWavePerShuffle,
70  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
71  typename CDEShuffleBlockTransferScalarPerVectors,
74  index_t ActivationOP = 0,
75  bool NSwizzle = false,
76  bool IsInputGemm = true,
77  bool IsSplitK = false,
78  bool MulRoutedWeight = false,
79  typename IndexType = index_t,
80  typename ComputeTypeA = CDataType,
81  typename ComputeTypeB = ComputeTypeA,
82  typename LDSTypeA = ComputeTypeA,
83  typename LDSTypeB = ComputeTypeB,
84  bool NonTemporalLoadB = false>
87  BLayout,
88  DsLayout,
89  CLayout,
90  ADataType,
91  AScaleDataType,
92  BDataType,
93  BScaleDataType,
94  DsDataType,
95  CDataType,
96  ScaleBlockM,
97  ScaleBlockN,
98  ScaleBlockK,
99  AElementwiseOperation,
100  BElementwiseOperation,
101  CElementwiseOperation>
102 {
// NXdlPerWave resolved per warp-size variant (wave64 vs wave32).
// NOTE(review): GetNXdlPerWave appears to come from the GET_NXDL_PER_WAVE_IMPL
// macro in device_base.hpp; the macro invocation line is missing from this
// scraped view — confirm against the original header.
104  static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
105  static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
// Number of auxiliary D tensors fused into the epilogue.
106  static constexpr index_t NumDTensor = DsDataType::Size();
107  template <index_t NXdlPerWave_>
109  ALayout,
110  BLayout,
111  DsLayout,
112  CLayout,
113  ADataType,
114  BDataType,
115  GemmAccDataType,
116  CShuffleDataType,
117  DsDataType,
118  CDataType,
119  AElementwiseOperation,
120  BElementwiseOperation,
121  CElementwiseOperation,
122  GemmSpec,
123  BlockSize,
124  ScaleBlockM,
125  ScaleBlockN,
126  ScaleBlockK,
127  MPerBlock,
128  NPerBlock,
129  KPerBlock,
130  AK1,
131  BK1,
132  MPerXDL,
133  NPerXDL,
134  MXdlPerWave,
135  NXdlPerWave_,
136  ABlockTransferThreadClusterLengths_AK0_M_AK1,
137  ABlockTransferThreadClusterArrangeOrder,
138  ABlockTransferSrcAccessOrder,
139  ABlockTransferSrcVectorDim,
140  ABlockTransferSrcScalarPerVector,
141  ABlockTransferDstScalarPerVector_AK1,
142  false,
143  ABlockLdsExtraM,
144  BBlockTransferThreadClusterLengths_BK0_N_BK1,
145  BBlockTransferThreadClusterArrangeOrder,
146  BBlockTransferSrcAccessOrder,
147  BBlockTransferSrcVectorDim,
148  BBlockTransferSrcScalarPerVector,
149  BBlockTransferDstScalarPerVector_BK1,
150  false,
151  BBlockLdsExtraN,
152  CShuffleMXdlPerWavePerShuffle,
153  math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
154  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
155  CDEShuffleBlockTransferScalarPerVectors,
156  BlkGemmPipeSched,
157  BlkGemmPipelineVer,
158  ActivationOP,
159  NSwizzle,
160  IsInputGemm,
161  IsSplitK,
162  MulRoutedWeight,
163  IndexType,
164  ComputeTypeA,
165  ComputeTypeB,
166  LDSTypeA,
167  LDSTypeB,
168  NonTemporalLoadB>;
171 
173 
// Elements packed per storage byte-group for the A operand (used to convert
// descriptor element counts into buffer byte sizes in the Invoker).
// NOTE(review): the `if constexpr(...)` condition selecting 2 vs 1 is missing
// from this scraped view — presumably 2 for sub-byte packed types (e.g.
// pk_i4/f4) and 1 otherwise; confirm against the original header.
174  static constexpr index_t APackedSize = []() {
176  return 2;
177  else
178  return 1;
179  }();
180 
// Same packing factor for the B operand; condition line likewise missing here.
181  static constexpr index_t BPackedSize = []() {
183  return 2;
184  else
185  return 1;
186  }();
187 
188  int GetPreShuffleParameters() override { return NPerXDL; }
189 
190  // Invoker
// Invoker: validates the kernel argument, computes the launch grid, selects a
// kernel instantiation by pipeline version / K-loop tail, and launches and
// times it (optionally with rotating buffers + icache flush for benchmarking).
// NOTE(review): this is a documentation scrape — several kernel
// template-argument lines (TailNumber / pipeline enum tails) and the
// INVOKER_RUN3_IMPL macro line are missing and the code is kept verbatim.
191  struct Invoker : public BaseInvoker
192  {
// Shared implementation parameterized on the GridwiseGemm (wave64 or wave32
// variant). Returns the averaged kernel time in milliseconds.
193  template <typename GridwiseGemm>
194  float RunImp(const typename GridwiseGemm::Argument& arg,
195  const StreamConfig& stream_config = StreamConfig{})
196  {
197  if(stream_config.log_level_ > 0)
198  {
199  arg.Print();
200  }
201 
202  if(!GridwiseGemm::CheckValidity(arg))
203  {
204  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
205  }
206 
// Grid size: N is doubled for the gate+up input GEMM under split-K.
207  index_t gdx, gdy, gdz;
208  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
209  arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
210 
211  float ave_time = 0;
212 
// Per-batch K extent used to decide main-loop / tail handling.
213  index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
214 
215  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
// Launch helper: when flush_cache is requested, rotate A/B/D buffers between
// timed iterations so each run starts cold; otherwise launch directly.
216  const auto RunKernel = [&](const auto& kernel) {
217  if(stream_config.flush_cache)
218  {
219 
220  std::array<std::size_t, NumDTensor> DsSize;
221 
222  auto arg_ = arg;
223 
224  const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
225  arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
226  const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
227  arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
228 
// Byte sizes account for sub-byte element packing (APackedSize/BPackedSize).
229  auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
230  sizeof(ADataType) / APackedSize;
231  auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
232  sizeof(BDataType) / BPackedSize;
233 
234  const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
235  arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
236 
237  static_for<0, NumDTensor, 1>{}([&](auto i) {
238  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
239  DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
240  });
241  ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
242  DsDataType>
243  rotating_mem(arg_,
244  stream_config.rotating_count,
245  size_a_buffer,
246  size_b_buffer,
247  DsSize);
248  rotating_mem.Print();
249 
250  auto run_flush_cache = [&]() {
251  // flush icache
253  // rotating mem
254  rotating_mem.Next();
255  // clear c mem
256  // if(arg_.KBatch > 1)
257  // hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
258  // 0,
259  // arg_.M * arg_.N * sizeof(CDataType)
260  // * (IsInputGemm && IsSplitK ? 2 : 1),
261  // stream_config.stream_id_));
262  };
263 
264  ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
265  stream_config,
266  run_flush_cache,
267  kernel,
268  dim3(gdx, gdy, gdz),
269  dim3(BlockSize),
270  0,
271  arg_);
272  }
273  else
274  {
275  // if(arg.KBatch > 1)
276  // hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
277  // 0,
278  // arg.M * arg.N * sizeof(CDataType) *
279  // (IsInputGemm && IsSplitK ? 2 : 1),
280  // stream_config.stream_id_));
281 
282  ave_time = launch_and_time_kernel(
283  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
284  }
285  };
286 
// Heuristic per-thread register estimate (in 4-byte units) from the A/B/C
// tile footprints; drives the minimum-occupancy launch hint below.
287  constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
288  4 * (1 + GridwiseGemm::NWave);
289  constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
290  4 * (2) * (IsInputGemm ? 2 : 1);
291  constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
292  BlockSize / 4 * (IsInputGemm ? 2 : 1);
293  constexpr auto estimated_reg_total =
294  estimated_reg_a + estimated_reg_b + estimated_reg_c;
295 
296  constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
297 
// NOTE(review): the rest of this expression (the two MemoryDataOp branches)
// is missing from this scraped view.
298  constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
301 
// Dispatch on pipeline version and K-loop tail parity; only v1/v2/v3 are
// supported. Kernel template-argument tails are missing in this view.
302  if(has_main_k_block_loop)
303  {
304  // Tail number always full
305  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
306  {
307  {
308  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
309  {
310  const auto kernel = kernel_moe_gemm<GridwiseGemm,
311  true,
312  MemoryDataOp,
313  minimum_occupancy,
315  RunKernel(kernel);
316  }
317  else
318  {
319  const auto kernel = kernel_moe_gemm<GridwiseGemm,
320  true,
321  MemoryDataOp,
322  minimum_occupancy,
324  RunKernel(kernel);
325  }
326  }
327  }
328  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
329  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
330  {
331  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
332  {
333  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
334  true,
335  MemoryDataOp,
336  minimum_occupancy,
338  RunKernel(kernel);
339  }
340  else
341  {
342  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
343  true,
344  MemoryDataOp,
345  minimum_occupancy,
347  RunKernel(kernel);
348  }
349  }
350  else
351  {
352  throw std::runtime_error("todo: only v1 & v2 support now");
353  }
354  }
355 #if 1
356  else
357  {
358  // Tail number always 1
359  if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
360  {
361  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
362  {
363  const auto kernel = kernel_moe_gemm<GridwiseGemm,
364  false,
365  MemoryDataOp,
366  minimum_occupancy,
368  RunKernel(kernel);
369  }
370  else
371  {
372  const auto kernel = kernel_moe_gemm<GridwiseGemm,
373  false,
374  MemoryDataOp,
375  minimum_occupancy,
377  RunKernel(kernel);
378  }
379  }
380  else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
381  BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
382  {
383  if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
384  {
385  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
386  false,
387  MemoryDataOp,
388  minimum_occupancy,
390  RunKernel(kernel);
391  }
392  else
393  {
394  const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
395  false,
396  MemoryDataOp,
397  minimum_occupancy,
399  RunKernel(kernel);
400  }
401  }
402  }
403 #endif
404 
405  return ave_time;
406  }
407 
// NOTE(review): the INVOKER_RUN3_IMPL macro line (device_base.hpp:187), which
// presumably generates the Run(const Argument&) overloads called below, is
// missing from this scraped view.
409 
410  // polymorphic
411  float Run(const BaseArgument* p_arg,
412  const StreamConfig& stream_config = StreamConfig{}) override
413  {
414  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
415  }
416  };
417 
/// Compile-time sanity check of the template configuration.
/// TODO: properly implement this check — currently every configuration is
/// accepted unconditionally.
static constexpr bool IsValidCompilationParameter() { return true; }
423 
// Runtime support check: rejects unsupported KBatch/data-type combinations,
// hardware without the required XDL/WMMA or bf16-atomic support, and problem
// sizes incompatible with the tile/vector configuration; then defers to the
// warp-size-specific gridwise GEMM's CheckValidity.
424  static bool IsSupportedArgument(const Argument& arg)
425  {
426  // only impl kbatch 1 for fp32
427  if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
428  {
429  return false;
430  }
431  if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
432  {
433  return false;
434  }
435  if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
436  {
437  return false;
438  }
439 
// K must be vector-aligned unless a K-padding specialization is selected.
440  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
441  GemmSpec == GemmSpecialization::NKPadding ||
442  GemmSpec == GemmSpecialization::MNKPadding ||
443  GemmSpec == GemmSpecialization::KPadding))
444  {
445  return false;
446  }
447  if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
448  {
449  return false;
450  }
451  if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
452  {
453  // Not support Kpadding with KBatch > 1
454  return false;
455  }
456 
// Pick the gridwise GEMM variant matching the device's warp size.
457  if(get_warp_size() == 64)
458  {
459  if constexpr(NXdlPerWave64 > 0)
460  {
461  return GridwiseGemm64::CheckValidity(arg);
462  }
463  }
464  else
465  {
466  if constexpr(NXdlPerWave32 > 0)
467  {
// NOTE(review): the line invoking GridwiseGemm32::CheckValidity( on the
// reinterpreted argument is missing from this scraped view; kept verbatim.
469  reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
470  }
471  }
// Unreachable configurations (NXdlPerWave == 0 for this warp size) are
// reported as unsupported.
472  return false;
473  }
474 
475  // polymorphic
476  bool IsSupportedArgument(const BaseArgument* p_arg) override
477  {
478  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
479  }
480 
// Builds a kernel Argument from type-erased pointers for the MoE block-scale
// GEMM: token-routing inputs (sorted token ids, sorted expert ids, max token
// id), A/B operands with their per-block scales, fused D tensors, output C,
// problem sizes/strides, KBatch and the elementwise operators. All pointer
// parameters are cast to the concrete element types of this instantiation.
481  static auto MakeArgument(const void* p_sorted_token_ids,
482  const void* p_sorted_expert_ids,
483  const void* p_max_token_id,
484  const void* p_a,
485  const void* p_b,
486  std::array<const void*, NumDTensor> p_ds,
487  void* p_c,
488  index_t NumTokens,
489  index_t TopK,
490  index_t M,
491  index_t N,
492  index_t K,
493  index_t StrideA,
494  index_t StrideB,
495  std::array<index_t, NumDTensor> StrideDs,
496  index_t StrideC,
497  const void* p_a_scale,
498  const void* p_b_scale,
499  index_t KBatch,
500  AElementwiseOperation a_element_op,
501  BElementwiseOperation b_element_op,
502  CElementwiseOperation c_element_op)
503  {
504  return Argument{static_cast<const index_t*>(p_sorted_token_ids),
505  static_cast<const index_t*>(p_sorted_expert_ids),
506  static_cast<const index_t*>(p_max_token_id),
507  static_cast<const ADataType*>(p_a),
508  static_cast<const BDataType*>(p_b),
509  p_ds,
510  static_cast<CDataType*>(p_c),
511  NumTokens,
512  TopK,
513  M,
514  N,
515  K,
516  StrideA,
517  StrideB,
518  StrideDs,
519  StrideC,
520  static_cast<const AScaleDataType*>(p_a_scale),
521  static_cast<const BScaleDataType*>(p_b_scale),
522  KBatch,
523  a_element_op,
524  b_element_op,
525  c_element_op};
526  }
527 
528  static auto MakeInvoker() { return Invoker{}; }
529 
530  // polymorphic
// Generic (non-MoE) DeviceOp entry point: no token-routing information is
// available through this interface, so the routing pointers are null,
// NumTokens is set to M and TopK to 0 as placeholders, and KBatch is fixed
// to 1 (the KBatch parameter is commented out of the signature).
531  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
532  const void* p_b,
533  std::array<const void*, NumDTensor> p_ds,
534  void* p_c,
535  index_t M,
536  index_t N,
537  index_t K,
538  index_t StrideA,
539  index_t StrideB,
540  std::array<ck::index_t, NumDTensor> StrideDs,
541  index_t StrideC,
542  const void* p_a_scale,
543  const void* p_b_scale,
544  // index_t KBatch,
545  AElementwiseOperation a_element_op,
546  BElementwiseOperation b_element_op,
547  CElementwiseOperation c_element_op) override
548  {
549  return std::make_unique<Argument>(nullptr,
550  nullptr,
551  nullptr,
552  static_cast<const ADataType*>(p_a),
553  static_cast<const BDataType*>(p_b),
554  p_ds,
555  static_cast<CDataType*>(p_c),
556  M, // randoms set, no use
557  0,
558  M,
559  N,
560  K,
561  StrideA,
562  StrideB,
563  StrideDs,
564  StrideC,
565  static_cast<const AScaleDataType*>(p_a_scale),
566  static_cast<const BScaleDataType*>(p_b_scale),
567  1, // KBatch,
568  a_element_op,
569  b_element_op,
570  c_element_op);
571  }
572 
573  // polymorphic
574  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
575  {
576  return std::make_unique<Invoker>(Invoker{});
577  }
578 
579  // polymorphic
// Human-readable description of this instance's tuning configuration
// (tile sizes, wave mapping, vector widths, pipeline scheduler/version).
// NOTE(review): the initializer lists of the two enum-to-string maps below are
// missing from this scraped view; code is kept verbatim.
580  std::string GetTypeString() const override
581  {
582  auto str = std::stringstream();
583 
584  std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
587 
588  std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
592 
593  // clang-format off
594  str << "DeviceMoeGEmm"
595  << "<"
596  << getGemmSpecializationString(GemmSpec) << ", "
597  << std::string(ALayout::name)[0]
598  << std::string(BLayout::name)[0]
599  << std::string(CLayout::name)[0]
600  << ">"
601  << " BlkSize: "
602  << BlockSize << ", "
603  << "BlkTile: "
604  << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
605  << "WaveTile: "
606  << MPerXDL<<"x"<<NPerXDL << ", "
607  << "WaveMap: "
608  << MXdlPerWave<<"x" << NXdlPerWave<<", "
609  << "VmemReadVec: "
610  << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
611  << "BlkGemmPipelineScheduler: "
612  << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
613  << "BlkGemmPipelineVersion: "
614  << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
615  << "BlkGemmPipelinePrefetchStages: "
616  << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
617  // clang-format on
618 
619  return str.str();
620  }
621 };
622 
623 } // namespace device
624 } // namespace tensor_operation
625 } // namespace ck
#define INVOKER_RUN3_IMPL
Definition: device_base.hpp:187
#define GET_NXDL_PER_WAVE_IMPL
Definition: device_base.hpp:87
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:32
GemmSpecialization
Definition: gemm_specialization.hpp:11
void flush_icache()
Definition: flush_cache.hpp:383
Definition: ck.hpp:270
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:209
BlockGemmPipelineVersion
Block GEMM pipeline version enumeration.
Definition: scheduler_enum.hpp:17
@ v2
Memory-optimized pipeline.
@ v3
Compute-optimized pipeline.
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
@ Even
Even number of iterations.
@ Odd
Odd number of iterations.
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:219
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Block GEMM pipeline scheduler enumeration.
Definition: scheduler_enum.hpp:33
@ Intrawave
Schedule within a single wavefront.
@ Interwave
Schedule across multiple wavefronts.
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:301
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:16
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:84
bool is_bf16_atomic_supported()
Definition: device_prop.hpp:113
Definition: stream_config.hpp:9
Definition: gridwise_moe_gemm_blockscale.hpp:674
Definition: gridwise_moe_gemm_blockscale.hpp:179
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:982
Definition: data_type.hpp:187
Definition: functional2.hpp:33
Definition: device_base.hpp:270
Definition: device_base.hpp:281
Definition: device_gemm_multiple_d_ab_scale.hpp:82
Definition: device_moe_gemm_blockscale.hpp:192
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_moe_gemm_blockscale.hpp:194
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_moe_gemm_blockscale.hpp:411
Definition: device_moe_gemm_blockscale.hpp:102
static constexpr index_t BPackedSize
Definition: device_moe_gemm_blockscale.hpp:181
static constexpr index_t APackedSize
Definition: device_moe_gemm_blockscale.hpp:174
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_moe_gemm_blockscale.hpp:476
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_moe_gemm_blockscale.hpp:531
typename GridwiseGemm64::Argument Argument
Definition: device_moe_gemm_blockscale.hpp:172
static constexpr GET_NXDL_PER_WAVE_IMPL auto NXdlPerWave64
Definition: device_moe_gemm_blockscale.hpp:104
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_moe_gemm_blockscale.hpp:574
static bool IsSupportedArgument(const Argument &arg)
Definition: device_moe_gemm_blockscale.hpp:424
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_moe_gemm_blockscale.hpp:481
static auto MakeInvoker()
Definition: device_moe_gemm_blockscale.hpp:528
static constexpr index_t NumDTensor
Definition: device_moe_gemm_blockscale.hpp:106
static constexpr bool IsValidCompilationParameter()
Definition: device_moe_gemm_blockscale.hpp:418
std::string GetTypeString() const override
Definition: device_moe_gemm_blockscale.hpp:580
int GetPreShuffleParameters() override
Definition: device_moe_gemm_blockscale.hpp:188
static constexpr auto NXdlPerWave32
Definition: device_moe_gemm_blockscale.hpp:105
Definition: flush_cache.hpp:174