// Template parameter list of the BlockReduce2d struct; the `struct` line itself is not part
// of this listing. Policy_ defaults to void.
44 template <
typename Problem_,
typename Policy_ =
void>
// reduce_impl: sweeps over x_tensor (sweep_tile) and folds every element into y_tensor with
// reduce_func. When kProcessIndex is true, the index produced by index_calculator is tracked
// next to the value in y_index_tensor.
// NOTE(review): several lines of this definition (braces, one template parameter, the lambda
// header that introduces idx / idx_0) are missing from this listing.
55 template <
bool kProcessIndex,
56 typename XDistributedTensor_,
57 typename YDistributedTensor_,
58 typename YIndexDistributedTensor_,
60 typename IndexCalculatorFunc,
61 typename ReducePacksPerXDim>
62 CK_TILE_DEVICE void reduce_impl(
const XDistributedTensor_& x_tensor,
63 YDistributedTensor_& y_tensor,
64 YIndexDistributedTensor_& y_index_tensor,
65 const ReduceFunc& reduce_func,
66 const IndexCalculatorFunc& index_calculator,
69 sweep_tile<XDistributedTensor_>(
// Convert the input element to ComputeDataType before accumulating.
74 auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
// Index-tracking path.
76 if constexpr(kProcessIndex)
80 XDistributedTensor_::get_tile_distribution(), idx);
// Turn the element's x indices into the index value that may be recorded.
81 const auto new_idx = index_calculator(x_indices);
// Work on a local copy of the stored index; AccumulateWithIndex receives it together with
// the accumulator y_tensor(idx_0).
82 auto current_idx = y_index_tensor(idx_0);
84 AccumulateWithIndex{}(
85 reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
// Write the local index back, converted to the index tensor's DataType.
87 y_index_tensor(idx_0) =
88 type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
// Value-only path: fold val into y_tensor(idx_0).
92 Accumulate{}(reduce_func, y_tensor(idx_0), val);
96 ReducePacksPerXDim{});
// operator() overload without an index tensor. ReducePacksPerXDim defaults to
// uniform_sequence_gen_t<2, 1>.
// NOTE(review): the call these trailing arguments belong to is not visible here
// (presumably reduce_impl with index processing disabled) — confirm against the full file.
102 typename XDistributedTensor_,
103 typename YDistributedTensor_,
105 typename ReducePacksPerXDim =
106 uniform_sequence_gen_t<2, 1>>
108 YDistributedTensor_& y_tensor,
109 const ReduceFunc& reduce_func,
110 ReducePacksPerXDim = {})
// Placeholder callable: ignores its argument and always returns 0.
117 [](
auto) {
return 0; },
118 ReducePacksPerXDim{});
// operator() overload with index tracking: forwards to reduce_impl, with index processing
// enabled or disabled at compile time by Problem::kOutputIndex.
// ReducePacksPerXDim defaults to uniform_sequence_gen_t<2, 1>.
122 template <
typename XDistributedTensor_,
123 typename YDistributedTensor_,
124 typename YIndexDistributedTensor_,
126 typename IndexCalculatorFunc,
127 typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
129 YDistributedTensor_& y_tensor,
130 YIndexDistributedTensor_& y_index_tensor,
131 const ReduceFunc& reduce_func,
132 const IndexCalculatorFunc& index_calculator,
133 ReducePacksPerXDim = {})
135 reduce_impl<Problem::kOutputIndex>(x_tensor,
140 ReducePacksPerXDim{});
// Span-based accumulation fragment: for a y element addressed by dstr_idx_i0, read the current
// value, fold in x elements addressed by (dstr_idx_i0, dstr_idx_i1), and store the result back.
// NOTE(review): the enclosing function / preprocessor context and the loops that define
// dstr_idx_i0 and dstr_idx_i1 are not visible in this listing — confirm whether this code is
// active.
144 constexpr
auto I0 = number<0>{};
145 constexpr
auto I1 = number<1>{};
146 constexpr
auto spans = XDistributedTensor_::get_distributed_spans();
150 constexpr
auto y_dstr_idx =
make_tuple(dstr_idx_i0);
// Local copy of the accumulator for this y element.
152 auto y = y_tensor[y_dstr_idx];
155 constexpr
auto in_dstr_idx =
make_tuple(dstr_idx_i0, dstr_idx_i1);
// Input element converted to ComputeDataType, then folded into the local accumulator.
156 const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
158 y = reduce_func(y, x);
// Store the accumulated value back into y_tensor.
161 y_tensor(y_dstr_idx) = y;
// MakeYBlockTile: creates a static distributed tensor of ComputeDataType whose distribution
// `dstr` is derived from XDistributedTensor_'s tile-distribution encoding.
// NOTE(review): the calls that wrap the encoding (and the reduced dimensions they use) are on
// lines missing from this listing.
165 template <
typename XDistributedTensor_>
171 constexpr
auto dstr =
173 XDistributedTensor_::get_tile_distribution()
174 .get_static_tile_distribution_encoding(),
177 auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
// MakeYIndexBlockTile: like the value-tile factory, but the created tensor holds IndexDataType
// elements (index_t by default).
// NOTE(review): the calls that wrap the encoding are on lines missing from this listing.
182 template <
typename XDistributedTensor_,
typename IndexDataType = index_t>
// Compile-time check that the input tensor's element type is the problem's XDataType.
185 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>,
"wrong!");
190 constexpr
auto dstr =
192 XDistributedTensor_::get_tile_distribution()
193 .get_static_tile_distribution_encoding(),
196 auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
// operator() overload that allocates the output tile itself: builds y_tensor with
// MakeYBlockTile and then runs the in-place overload on it.
// NOTE(review): the remaining parameters, any initialization of y_tensor, and the return
// statement are on lines not shown here.
203 template <
typename XDistributedTensor_,
208 const ReduceFunc& reduce_func,
209 ReducePacksPerXDim = {})
211 auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
// Delegate the accumulation to the overload that takes an existing y_tensor.
213 (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});
// Template parameter list of the next struct in the file (its `struct` line is not part of
// this listing). Policy_ defaults to void.
220 template <
typename Problem_,
typename Policy_ =
void>
// reduce_impl: for every element of y_tensor's thread buffer, combines the local value with
// data from other lanes across the R dimensions owned by the last P dimension
// (idim_p_lane = NDimP - 1). When kProcessIndex is true, the matching entry of y_index_tensor
// is exchanged with warp_shuffle and updated together with the value.
// NOTE(review): loop headers, the src_lane / v_remote definitions and the value-only branch
// are on lines missing from this listing.
226 template <
bool kProcessIndex,
227 typename YDistributedTensor_,
228 typename YIndexDistributedTensor_,
231 YIndexDistributedTensor_& y_index_tensor,
232 const ReduceFunc& reduce_func)
234 using Dstr =
typename YDistributedTensor_::StaticTileDistribution;
235 using DstrEncode =
typename Dstr::DstrEncode;
236 using DstrEncodeDetail =
typename DstrEncode::detail;
238 constexpr
index_t NDimP = Dstr::get_num_of_dimension_p();
239 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
// The last P dimension is the one used for the ownership and derivative lookups below.
241 constexpr
index_t idim_p_lane = NDimP - 1;
247 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// Local copy of element i; it is written back after the reduction.
251 auto v_local = y_tensor.get_thread_buffer()[i];
253 using IndexDataType =
typename YIndexDistributedTensor_::DataType;
// Value-initialized so the variable exists even when index processing is disabled.
254 IndexDataType idx_local{};
256 if constexpr(kProcessIndex)
258 idx_local = y_index_tensor.get_thread_buffer()[i];
// Only R dimensions owned by the selected P dimension take part.
266 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
268 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
270 constexpr
index_t lid_over_rid_derivative =
271 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
274 "wrong! only support power of 2 reduction");
// Per-stage term: the derivative shifted left by the stage number (the rest of this
// expression is not shown).
283 (
number<lid_over_rid_derivative << istage.
value>{}.value);
288 if constexpr(kProcessIndex)
// Fetch the index held by the source lane.
290 const auto idx_remote =
warp_shuffle(idx_local, src_lane);
293 reduce_func, v_local, idx_local, v_remote, idx_remote);
// Write the reduced value (and, if tracked, its index) back to the thread buffers.
304 y_tensor.get_thread_buffer()(i) = v_local;
306 if constexpr(kProcessIndex)
308 y_index_tensor.get_thread_buffer()(i) = idx_local;
// Value-only operator(): runs reduce_impl with index processing disabled. y_tensor is passed
// a second time only to fill the index-tensor parameter slot.
314 template <
typename YDistributedTensor_,
typename ReduceFunc>
317 reduce_impl<false>(y_tensor, y_tensor, reduce_func);
// operator() with an index tensor: forwards to reduce_impl, with index processing enabled or
// disabled at compile time by Problem::kOutputIndex.
320 template <
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
322 YIndexDistributedTensor_& y_index_tensor,
323 const ReduceFunc& reduce_func)
325 reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
// Template parameter list of the next struct in the file (its `struct` line is not part of
// this listing). Policy_ defaults to void.
330 template <
typename Problem_,
typename Policy_ =
void>
// GetReduceWarps: compile-time value computed by an immediately-invoked lambda that inspects
// the R dimensions owned by P dimension 0 of YDistributedTensor_'s distribution.
// NOTE(review): the loop header and the statement that combines the r_length values are on
// lines missing from this listing.
336 template <
typename YDistributedTensor_>
339 constexpr
index_t num_reduce_warps = [&]() {
340 using Dstr =
typename YDistributedTensor_::StaticTileDistribution;
341 using DstrEncode =
typename Dstr::DstrEncode;
342 using DstrEncodeDetail =
typename DstrEncode::detail;
344 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
// P dimension 0 is the one used for the ownership lookup below.
346 constexpr
index_t idim_p_warp = 0;
350 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
352 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
358 return num_reduce_warps;
// GetSmemSize: size in bytes of the value scratch buffer — one DataType slot per
// thread-buffer element for each of num_warps.
// NOTE(review): num_warps is defined on lines missing from this listing.
362 template <
typename YDistributedTensor_>
365 using DataType =
typename YDistributedTensor_::DataType;
366 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
382 return num_warps * thread_buf_size *
sizeof(DataType);
// GetIndicesSmemSize: size in bytes of the index scratch buffer — one IndexDataType slot per
// thread-buffer element for each of num_warps.
// NOTE(review): num_warps is defined on lines missing from this listing.
386 template <
typename YIndexDistributedTensor_>
389 using IndexDataType =
typename YIndexDistributedTensor_::DataType;
390 constexpr
index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
392 return num_warps * thread_buf_size *
sizeof(IndexDataType);
// reduce_impl: cross-warp step. Each warp stores its thread-buffer values (and indices when
// kProcessIndex is true) into the scratch buffers, then reads back the num_reduce_warps
// entries of its group and combines them pairwise in stages with a doubling stride. The
// result ends up in slot 0 and is written back to the thread buffers.
// NOTE(review): the remaining parameters, braces, any synchronization between the store and
// the load, and the definitions of num_warps / nstage / i0 are on lines missing from this
// listing.
396 template <
bool kProcessIndex,
397 typename YDistributedTensor_,
398 typename YIndexDistributedTensor_,
401 YIndexDistributedTensor_& y_index_tensor,
403 void* smem_indices_ptr,
404 const ReduceFunc& reduce_func)
406 using DataType =
typename YDistributedTensor_::DataType;
407 using IndexDataType =
typename YIndexDistributedTensor_::DataType;
409 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// View the untyped scratch buffers as typed arrays; the index buffer is only touched when
// index processing is enabled.
411 DataType* smem_ptr =
reinterpret_cast<DataType*
>(smem);
412 IndexDataType* smem_indices =
nullptr;
413 if constexpr(kProcessIndex)
415 smem_indices =
reinterpret_cast<IndexDataType*
>(smem_indices_ptr);
418 const index_t lane_id = get_lane_id();
419 const index_t warp_id = get_warp_id();
422 constexpr
index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
// Special case for a single reducing warp (the branch body is not shown).
424 if constexpr(num_reduce_warps == 1)
428 const
index_t smem_offset = warp_id;
// Store: element i written by warp w lands at index i * num_warps + w.
431 static_for<0, thread_buf_size, 1>{}([&](
auto i) {
433 smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
434 if constexpr(kProcessIndex)
436 smem_indices[smem_offset + i * num_warps] =
437 y_index_tensor.get_thread_buffer()[i];
// First slot of the group of num_reduce_warps consecutive warps this warp belongs to.
444 const index_t local_warp_id = warp_id / num_reduce_warps;
445 const index_t local_smem_os = local_warp_id * num_reduce_warps;
447 static_for<0, thread_buf_size, 1>{}([&](
auto i) {
448 DataType v[num_reduce_warps];
// idx_v is an array only when indices are processed; otherwise it is a single scalar.
449 [[maybe_unused]] std::
450 conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
// Load the group's entries for element i.
452 static_for<0, num_reduce_warps, 1>{}([&](
auto idx) {
453 v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
454 if constexpr(kProcessIndex)
456 idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
461 "wrong! only support power of 2 reduction");
// Staged pairwise combine: at each stage, slot i0 absorbs slot i0 + stride, with the stride
// doubling every stage.
465 static_for<0, nstage, 1>{}([&](
auto istage) {
466 constexpr
index_t stride = 1 << istage.value;
467 static_for<0, num_reduce_warps, stride * 2>{}([&](
auto idx_) {
469 constexpr
index_t i1 = idx_ + stride;
470 if constexpr(i1 < num_reduce_warps)
472 if constexpr(kProcessIndex)
474 AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
478 Accumulate{}(reduce_func, v[i0], v[i1]);
// Slot 0 holds the combined result.
484 y_tensor.get_thread_buffer()(i) = v[0];
485 if constexpr(kProcessIndex)
487 y_index_tensor.get_thread_buffer()(i) = idx_v[0];
// Value-only operator(): runs reduce_impl with index processing disabled. y_tensor fills the
// index-tensor parameter slot and nullptr is passed for the index scratch buffer.
493 template <
typename YDistributedTensor_,
typename ReduceFunc>
495 operator()(YDistributedTensor_& y_tensor,
void* smem,
const ReduceFunc& reduce_func)
497 reduce_impl<false>(y_tensor, y_tensor, smem,
nullptr, reduce_func);
// operator() with an index tensor and an index scratch buffer: forwards to reduce_impl, with
// index processing enabled or disabled at compile time by Problem::kOutputIndex.
500 template <
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
502 YIndexDistributedTensor_& y_index_tensor,
505 const ReduceFunc& reduce_func)
507 reduce_impl<Problem::kOutputIndex>(
508 y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
// Template parameter list of the next struct in the file (its `struct` line is not part of
// this listing). Policy_ defaults to void.
512 template <
typename Problem_,
typename Policy_ =
void>
// GetReduceWarps: compile-time value computed by an immediately-invoked lambda that inspects
// the R dimensions owned by P dimension 0 of YDistributedTensor_'s distribution.
// NOTE(review): the loop header and the statement that combines the r_length values are on
// lines missing from this listing.
518 template <
typename YDistributedTensor_>
521 constexpr
index_t num_reduce_warps = [&]() {
522 using Dstr =
typename YDistributedTensor_::StaticTileDistribution;
523 using DstrEncode =
typename Dstr::DstrEncode;
524 using DstrEncodeDetail =
typename DstrEncode::detail;
526 constexpr
index_t NDimR = Dstr::get_num_of_dimension_r();
// P dimension 0 is the one used for the ownership lookup below.
528 constexpr
index_t idim_p_warp = 0;
532 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
534 constexpr
index_t r_length = DstrEncode::rs_lengths_[idim_r];
540 return num_reduce_warps;
// GetSmemSize: size in bytes of the value scratch buffer — one DataType slot per
// thread-buffer element for each of num_warps.
// NOTE(review): num_warps is defined on lines missing from this listing.
544 template <
typename YDistributedTensor_>
547 using DataType =
typename YDistributedTensor_::DataType;
548 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
564 return num_warps * thread_buf_size *
sizeof(DataType);
// GetIndicesSmemSize: size in bytes of the index scratch buffer — one IndexDataType slot per
// thread-buffer element for each of num_warps.
// NOTE(review): num_warps is defined on lines missing from this listing.
568 template <
typename YIndexDistributedTensor_>
571 using IndexDataType =
typename YIndexDistributedTensor_::DataType;
572 constexpr
index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
574 return num_warps * thread_buf_size *
sizeof(IndexDataType);
// reduce_impl: cross-warp step. Each warp stores its thread-buffer values (and indices when
// kProcessIndex is true) into the scratch buffers, copies the entries of its warp group into
// local arrays, and folds them in order starting from the group's first entry. The result is
// written back to the thread buffers.
// NOTE(review): the remaining parameters, braces, any synchronization between the store and
// the load, and the definition of num_warps are on lines missing from this listing.
578 template <
bool kProcessIndex,
579 typename YDistributedTensor_,
580 typename YIndexDistributedTensor_,
583 YIndexDistributedTensor_& y_index_tensor,
585 void* smem_indices_ptr,
586 const ReduceFunc& reduce_func)
588 using DataType =
typename YDistributedTensor_::DataType;
589 using IndexDataType =
typename YIndexDistributedTensor_::DataType;
591 constexpr
index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// View the untyped scratch buffers as typed arrays; the index buffer is only touched when
// index processing is enabled.
593 DataType* smem_ptr =
reinterpret_cast<DataType*
>(smem);
594 IndexDataType* smem_indices =
nullptr;
595 if constexpr(kProcessIndex)
597 smem_indices =
reinterpret_cast<IndexDataType*
>(smem_indices_ptr);
600 const index_t lane_id = get_lane_id();
601 const index_t warp_id = get_warp_id();
602 constexpr
auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
604 const index_t smem_offset = warp_id;
// Special case for a single reducing warp (the branch body is not shown).
607 if constexpr(num_reduce_warps == 1)
// Store: element i written by warp w lands at index i * num_warps + w.
613 static_for<0, thread_buf_size, 1>{}([&](
auto i) {
614 smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
615 if constexpr(kProcessIndex)
617 smem_indices[smem_offset + i * num_warps] =
618 y_index_tensor.get_thread_buffer()[i];
// First slot of the group of num_reduce_warps consecutive warps this warp belongs to.
625 index_t local_warp_id = warp_id / num_reduce_warps;
626 index_t local_smem_os = local_warp_id * num_reduce_warps;
// Local copy of every (element, group warp) entry, laid out element-major.
628 DataType all_scratch[thread_buf_size * num_reduce_warps];
// all_indices is declared as either the array type or a single IndexDataType (the start of
// this declaration is not shown).
630 IndexDataType[thread_buf_size * num_reduce_warps],
631 IndexDataType> all_indices;
// Load phase: copy the group's entries out of the scratch buffers.
634 static_for<0, thread_buf_size, 1>{}([&](
auto i_0) {
635 static_for<0, num_reduce_warps, 1>{}([&](
auto i_1) {
636 all_scratch[i_0 * num_reduce_warps + i_1] =
637 smem_ptr[i_0 * num_warps + local_smem_os + i_1];
639 if constexpr(kProcessIndex)
641 all_indices[i_0 * num_reduce_warps + i_1] =
642 smem_indices[i_0 * num_warps + local_smem_os + i_1];
// Fold phase: start from the group's first entry and absorb the remaining ones in order.
649 static_for<0, thread_buf_size, 1>{}([&](
auto i_0) {
651 auto v_local = all_scratch[i_0 * num_reduce_warps];
653 IndexDataType idx_local{};
654 if constexpr(kProcessIndex)
656 idx_local = all_indices[i_0 * num_reduce_warps];
// Visits entries 1 .. num_reduce_warps - 1 (i_1 = i_1_n1 + 1).
660 static_for<0, num_reduce_warps - 1, 1>{}([&](
auto i_1_n1) {
661 constexpr
auto i_1 = number<i_1_n1 + 1>{};
662 const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
664 if constexpr(kProcessIndex)
666 const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
// reduce_func receives `changed` as a third argument here.
// NOTE(review): the condition guarding the index update below is not shown — presumably it
// tests `changed`; confirm against the full file.
668 bool changed =
false;
669 v_local = reduce_func(v_local, v_remote, changed);
672 idx_local = idx_remote;
// Value-only path.
677 v_local = reduce_func(v_local, v_remote);
// Write the folded value (and, if tracked, its index) back to the thread buffers.
681 y_tensor.get_thread_buffer()(i_0) = v_local;
682 if constexpr(kProcessIndex)
684 y_index_tensor.get_thread_buffer()(i_0) = idx_local;
// Value-only operator(): runs reduce_impl with index processing disabled. y_tensor fills the
// index-tensor parameter slot and nullptr is passed for the index scratch buffer.
690 template <
typename YDistributedTensor_,
typename ReduceFunc>
692 operator()(YDistributedTensor_& y_tensor,
void* smem,
const ReduceFunc& reduce_func)
694 reduce_impl<false>(y_tensor, y_tensor, smem,
nullptr, reduce_func);
// operator() with an index tensor and an index scratch buffer: forwards to reduce_impl, with
// index processing enabled or disabled at compile time by Problem::kOutputIndex.
697 template <
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
699 YIndexDistributedTensor_& y_index_tensor,
702 const ReduceFunc& reduce_func)
704 reduce_impl<Problem::kOutputIndex>(
705 y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:762
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:450
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition: utility.hpp:78
constexpr CK_TILE_HOST_DEVICE auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices, decltype(get_partition_index(tile_distribution)) partition_index)
Definition: static_distributed_tensor.hpp:158
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE int32_t integer_log2_floor(int32_t x)
Definition: math.hpp:443
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1037
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:301
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Definition: reduce_operator_accumulate.hpp:41
Accumulate with index tracking for reductions; provides a deterministic first-occurring index.
Definition: reduce_operator_accumulate.hpp:12
Definition: block_reduce2d.hpp:332
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_reduce2d.hpp:363
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, void *smem, void *smem_indices, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:501
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:333
static constexpr CK_TILE_HOST_DEVICE index_t GetIndicesSmemSize()
Definition: block_reduce2d.hpp:387
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:495
typename Problem::BlockShape BlockShape
Definition: block_reduce2d.hpp:334
static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
Definition: block_reduce2d.hpp:337
Definition: block_reduce2d.hpp:46
constexpr CK_TILE_DEVICE BlockReduce2d()
Definition: block_reduce2d.hpp:52
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, const ReduceFunc &reduce_func, const IndexCalculatorFunc &index_calculator, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:128
typename Problem::ComputeDataType ComputeDataType
Definition: block_reduce2d.hpp:50
static CK_TILE_DEVICE auto MakeYBlockTile()
Definition: block_reduce2d.hpp:166
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:107
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:48
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, const ComputeDataType &reduce_init, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:206
typename Problem::XDataType XDataType
Definition: block_reduce2d.hpp:49
static CK_TILE_DEVICE auto MakeYIndexBlockTile()
Definition: block_reduce2d.hpp:183
Definition: block_reduce2d.hpp:514
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, void *smem, void *smem_indices, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:698
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:515
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:692
typename Problem::BlockShape BlockShape
Definition: block_reduce2d.hpp:516
static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
Definition: block_reduce2d.hpp:519
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_reduce2d.hpp:545
static constexpr CK_TILE_HOST_DEVICE index_t GetIndicesSmemSize()
Definition: block_reduce2d.hpp:569
Definition: block_reduce2d.hpp:222
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:315
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:321
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:223
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:49
Definition: functional.hpp:43