Composable Kernel — source listing of include/ck_tile/ops/reduce/block/block_reduce2d.hpp (checkout 3643)
block_reduce2d.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second
// dimension using a user-specified reduction function.
//
// The reduction is performed in a three-stage hierarchical approach:
//
// STAGE 1: Thread-level reduction (BlockReduce2d)
// ===============================================
// - Each thread processes multiple elements from the input tensor within its assigned data
//   partition
// - Reduction is performed locally within each thread by iterating over assigned elements
// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per
//   dimension (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0,
//   4 from dim1)
// - Results are accumulated into a thread-local output tensor stored in registers
// - The output tensor distribution is derived from the input tensor's distribution using
//   make_reduce_tile_distribution_encoding() to handle dimension reduction
//
// STAGE 2: Warp-level reduction (BlockReduce2dSync)
// =================================================
// - Performs inter-thread reduction within each warp
// - Uses warp shuffle operations to exchange data between threads in the same warp
// - Implements a tree-reduction pattern with power-of-2 stages
// - Only reduces along dimensions that map to lane IDs within the warp
//
// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
// ==========================================================
// - Performs reduction across multiple warps within the same thread block
// - Uses shared memory (LDS) to facilitate data exchange between warps
// - Each warp's lane-0 thread stores its partial results to shared memory
// - All threads participate in loading and reducing data from shared memory
// - Implements block-level synchronization to ensure memory consistency

// BlockReduce2d: Thread-level reduction (Stage 1)
44 template <typename Problem_, typename Policy_ = void>
46 {
47  // Thread-level reduction implementation
49  using XDataType = typename Problem::XDataType;
50  using ComputeDataType = typename Problem::ComputeDataType;
51 
53 
54  private:
55  template <bool kProcessIndex,
56  typename XDistributedTensor_,
57  typename YDistributedTensor_,
58  typename YIndexDistributedTensor_,
59  typename ReduceFunc,
60  typename IndexCalculatorFunc,
61  typename ReducePacksPerXDim>
62  CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
63  YDistributedTensor_& y_tensor,
64  YIndexDistributedTensor_& y_index_tensor,
65  const ReduceFunc& reduce_func,
66  const IndexCalculatorFunc& index_calculator,
67  ReducePacksPerXDim)
68  {
69  sweep_tile<XDistributedTensor_>(
70  [&](auto... idx_) {
71  constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
72 
73  (..., [&](auto idx) {
74  auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
75 
76  if constexpr(kProcessIndex)
77  {
78 
79  const auto x_indices = get_x_indices_from_distributed_indices(
80  XDistributedTensor_::get_tile_distribution(), idx);
81  const auto new_idx = index_calculator(x_indices);
82  auto current_idx = y_index_tensor(idx_0);
83 
84  AccumulateWithIndex{}(
85  reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
86 
87  y_index_tensor(idx_0) =
88  type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
89  }
90  else
91  {
92  Accumulate{}(reduce_func, y_tensor(idx_0), val);
93  }
94  }(idx_));
95  },
96  ReducePacksPerXDim{});
97  }
98 
99  public:
100  // Overload for non-index tracking
101  template <
102  typename XDistributedTensor_,
103  typename YDistributedTensor_,
104  typename ReduceFunc,
105  typename ReducePacksPerXDim =
106  uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
107  CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
108  YDistributedTensor_& y_tensor,
109  const ReduceFunc& reduce_func,
110  ReducePacksPerXDim = {})
111  {
112  reduce_impl<false>(
113  x_tensor,
114  y_tensor,
115  y_tensor, // dummy
116  reduce_func,
117  [](auto) { return 0; }, // dummy
118  ReducePacksPerXDim{});
119  }
120 
121  // Overload for index tracking
122  template <typename XDistributedTensor_,
123  typename YDistributedTensor_,
124  typename YIndexDistributedTensor_,
125  typename ReduceFunc,
126  typename IndexCalculatorFunc,
127  typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
128  CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
129  YDistributedTensor_& y_tensor,
130  YIndexDistributedTensor_& y_index_tensor,
131  const ReduceFunc& reduce_func,
132  const IndexCalculatorFunc& index_calculator,
133  ReducePacksPerXDim = {})
134  {
135  reduce_impl<Problem::kOutputIndex>(x_tensor,
136  y_tensor,
137  y_index_tensor,
138  reduce_func,
139  index_calculator,
140  ReducePacksPerXDim{});
141  }
142 
143 #if 0
144  constexpr auto I0 = number<0>{};
145  constexpr auto I1 = number<1>{};
146  constexpr auto spans = XDistributedTensor_::get_distributed_spans();
147 
148  // FIXME: hard coded to reduce 2nd axis
149  sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
150  constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
151 
152  auto y = y_tensor[y_dstr_idx];
153 
154  sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
155  constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
156  const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
157 
158  y = reduce_func(y, x);
159  });
160 
161  y_tensor(y_dstr_idx) = y;
162  });
163 #endif
164 
165  template <typename XDistributedTensor_>
167  {
168  // FIXME: hard coded to reduce 2nd axis
169  constexpr auto reduce_dims = sequence<1>{};
170 
171  constexpr auto dstr =
173  XDistributedTensor_::get_tile_distribution()
174  .get_static_tile_distribution_encoding(),
175  reduce_dims));
176 
177  auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
178 
179  return tensor;
180  }
181 
182  template <typename XDistributedTensor_, typename IndexDataType = index_t>
184  {
185  static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
186 
187  // FIXME: hard coded to reduce 2nd axis
188  constexpr auto reduce_dims = sequence<1>{};
189 
190  constexpr auto dstr =
192  XDistributedTensor_::get_tile_distribution()
193  .get_static_tile_distribution_encoding(),
194  reduce_dims));
195 
196  auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
197 
198  return tensor;
199  }
200 
201  // uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
202  // e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
203  template <typename XDistributedTensor_,
204  typename ReduceFunc,
205  typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
206  CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
207  const ComputeDataType& reduce_init,
208  const ReduceFunc& reduce_func,
209  ReducePacksPerXDim = {})
210  {
211  auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
212  set_tile(y_tensor, reduce_init);
213  (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});
214 
215  return y_tensor;
216  }
217 };
218 
219 // BlockReduce2dSync: Warp-level reduction (Stage 2)
220 template <typename Problem_, typename Policy_ = void>
222 {
224 
225  private:
226  template <bool kProcessIndex,
227  typename YDistributedTensor_,
228  typename YIndexDistributedTensor_,
229  typename ReduceFunc>
230  CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
231  YIndexDistributedTensor_& y_index_tensor,
232  const ReduceFunc& reduce_func)
233  {
234  using Dstr = typename YDistributedTensor_::StaticTileDistribution;
235  using DstrEncode = typename Dstr::DstrEncode;
236  using DstrEncodeDetail = typename DstrEncode::detail;
237 
238  constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
239  constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
240 
241  constexpr index_t idim_p_lane = NDimP - 1;
242 
243  // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
244  // const auto rs_idx =
245  // y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
246 
247  constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
248 
249  // loop over thread data
250  static_for<0, thread_buf_size, 1>{}([&](auto i) {
251  auto v_local = y_tensor.get_thread_buffer()[i];
252 
253  using IndexDataType = typename YIndexDistributedTensor_::DataType;
254  IndexDataType idx_local{};
255 
256  if constexpr(kProcessIndex)
257  {
258  idx_local = y_index_tensor.get_thread_buffer()[i];
259  }
260 
261  // cross-lane reduce for replication
262  // only reduce on R dimension correspond to lane
263  // (lane id maps to this R dimension)
264  static_for<0, NDimR, 1>{}([&](auto idim_r) {
265  // FIXME: nasty to use does_p_own_r_
266  if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
267  {
268  constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
269 
270  constexpr index_t lid_over_rid_derivative =
271  DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
272 
273  static_assert(is_power_of_two_integer(r_length),
274  "wrong! only support power of 2 reduction");
275 
276  constexpr index_t nstage = integer_log2_floor(r_length);
277 
278  // reduction sweep forward
279  static_for<0, nstage, 1>{}([&](auto istage) {
280  // xor
281  index_t src_lane =
282  (__lane_id()) ^
283  (number<lid_over_rid_derivative << istage.value>{}.value);
284 
285  // pull data from remote lane
286  const auto v_remote = warp_shuffle(v_local, src_lane);
287 
288  if constexpr(kProcessIndex)
289  {
290  const auto idx_remote = warp_shuffle(idx_local, src_lane);
291 
293  reduce_func, v_local, idx_local, v_remote, idx_remote);
294  }
295  else
296  {
297  Accumulate{}(reduce_func, v_local, v_remote);
298  }
299  });
300  }
301  });
302 
303  // TODO - Do we need to broadcast to other lane?
304  y_tensor.get_thread_buffer()(i) = v_local;
305 
306  if constexpr(kProcessIndex)
307  {
308  y_index_tensor.get_thread_buffer()(i) = idx_local;
309  }
310  });
311  }
312 
313  public:
314  template <typename YDistributedTensor_, typename ReduceFunc>
315  CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
316  {
317  reduce_impl<false>(y_tensor, y_tensor, reduce_func);
318  }
319 
320  template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
321  CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
322  YIndexDistributedTensor_& y_index_tensor,
323  const ReduceFunc& reduce_func)
324  {
325  reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
326  }
327 };
328 
329 // BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
330 template <typename Problem_, typename Policy_ = void>
332 {
334  using BlockShape = typename Problem::BlockShape;
335 
336  template <typename YDistributedTensor_>
338  {
339  constexpr index_t num_reduce_warps = [&]() {
340  using Dstr = typename YDistributedTensor_::StaticTileDistribution;
341  using DstrEncode = typename Dstr::DstrEncode;
342  using DstrEncodeDetail = typename DstrEncode::detail;
343 
344  constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
345 
346  constexpr index_t idim_p_warp = 0;
347 
348  index_t len_ = 1;
349  static_for<0, NDimR, 1>{}([&](auto idim_r) {
350  if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
351  {
352  constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
353  len_ *= r_length;
354  }
355  });
356  return len_;
357  }();
358  return num_reduce_warps;
359  }
360 
361  // return in byte
362  template <typename YDistributedTensor_>
364  {
365  using DataType = typename YDistributedTensor_::DataType;
366  constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
367 
368  // we need to store all data from every wave into smem
369  // e.g. 2x2 reduce along N
370  // -------------> reduce N
371  // | w0 | w1 | ___> | w01 |
372  // | w2 | w3 | | w23 |
373  //
374  // -> store data from every wave into LDS
375  //
376  //
377  // -------------> reduce N
378  // | w0 | w1 | w2 | w3 | -----> | w0123 |
379  //
380  // -> also store data from every wave into LDS
381  constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
382  return num_warps * thread_buf_size * sizeof(DataType);
383  }
384 
385  // return in byte - separate shared memory size calculation for indices
386  template <typename YIndexDistributedTensor_>
388  {
389  using IndexDataType = typename YIndexDistributedTensor_::DataType;
390  constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
391  constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
392  return num_warps * thread_buf_size * sizeof(IndexDataType);
393  }
394 
395  private:
396  template <bool kProcessIndex,
397  typename YDistributedTensor_,
398  typename YIndexDistributedTensor_,
399  typename ReduceFunc>
400  CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
401  YIndexDistributedTensor_& y_index_tensor,
402  void* smem,
403  void* smem_indices_ptr,
404  const ReduceFunc& reduce_func)
405  {
406  using DataType = typename YDistributedTensor_::DataType;
407  using IndexDataType = typename YIndexDistributedTensor_::DataType;
408 
409  constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
410 
411  DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
412  IndexDataType* smem_indices = nullptr;
413  if constexpr(kProcessIndex)
414  {
415  smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
416  }
417 
418  const index_t lane_id = get_lane_id();
419  const index_t warp_id = get_warp_id();
420 
421  constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
422  constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
423 
424  if constexpr(num_reduce_warps == 1)
425  return;
426  block_sync_lds();
427  // Each warp's lane 0 writes its partial results to shared memory
428  const index_t smem_offset = warp_id;
429  if(lane_id == 0)
430  {
431  static_for<0, thread_buf_size, 1>{}([&](auto i) {
432  // Store the i-th element of this warp's thread_buffer into SMEM
433  smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
434  if constexpr(kProcessIndex)
435  {
436  smem_indices[smem_offset + i * num_warps] =
437  y_index_tensor.get_thread_buffer()[i];
438  }
439  });
440  }
441  block_sync_lds();
442 
443  // We let each warp holds a duplication to do reduction.
444  const index_t local_warp_id = warp_id / num_reduce_warps;
445  const index_t local_smem_os = local_warp_id * num_reduce_warps;
446 
447  static_for<0, thread_buf_size, 1>{}([&](auto i) {
448  DataType v[num_reduce_warps];
449  [[maybe_unused]] std::
450  conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
451 
452  static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
453  v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
454  if constexpr(kProcessIndex)
455  {
456  idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
457  }
458  });
459 
460  static_assert(is_power_of_two_integer(num_reduce_warps),
461  "wrong! only support power of 2 reduction");
462 
463  constexpr index_t nstage = integer_log2_floor(num_reduce_warps);
464 
465  static_for<0, nstage, 1>{}([&](auto istage) {
466  constexpr index_t stride = 1 << istage.value;
467  static_for<0, num_reduce_warps, stride * 2>{}([&](auto idx_) {
468  constexpr index_t i0 = idx_();
469  constexpr index_t i1 = idx_ + stride;
470  if constexpr(i1 < num_reduce_warps)
471  {
472  if constexpr(kProcessIndex)
473  {
474  AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
475  }
476  else
477  {
478  Accumulate{}(reduce_func, v[i0], v[i1]);
479  }
480  }
481  });
482  });
483 
484  y_tensor.get_thread_buffer()(i) = v[0];
485  if constexpr(kProcessIndex)
486  {
487  y_index_tensor.get_thread_buffer()(i) = idx_v[0];
488  }
489  });
490  }
491 
492  public:
493  template <typename YDistributedTensor_, typename ReduceFunc>
494  CK_TILE_DEVICE void
495  operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
496  {
497  reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
498  }
499 
500  template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
501  CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
502  YIndexDistributedTensor_& y_index_tensor,
503  void* smem,
504  void* smem_indices,
505  const ReduceFunc& reduce_func)
506  {
507  reduce_impl<Problem::kOutputIndex>(
508  y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
509  }
510 };
511 
512 template <typename Problem_, typename Policy_ = void>
514 {
516  using BlockShape = typename Problem::BlockShape;
517 
518  template <typename YDistributedTensor_>
520  {
521  constexpr index_t num_reduce_warps = [&]() {
522  using Dstr = typename YDistributedTensor_::StaticTileDistribution;
523  using DstrEncode = typename Dstr::DstrEncode;
524  using DstrEncodeDetail = typename DstrEncode::detail;
525 
526  constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
527 
528  constexpr index_t idim_p_warp = 0;
529 
530  index_t len_ = 1;
531  static_for<0, NDimR, 1>{}([&](auto idim_r) {
532  if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
533  {
534  constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
535  len_ *= r_length;
536  }
537  });
538  return len_;
539  }();
540  return num_reduce_warps;
541  }
542 
543  // return in byte
544  template <typename YDistributedTensor_>
546  {
547  using DataType = typename YDistributedTensor_::DataType;
548  constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
549 
550  // we need to store all data from every wave into smem
551  // e.g. 2x2 reduce along N
552  // -------------> reduce N
553  // | w0 | w1 | ___> | w01 |
554  // | w2 | w3 | | w23 |
555  //
556  // -> store data from every wave into LDS
557  //
558  //
559  // -------------> reduce N
560  // | w0 | w1 | w2 | w3 | -----> | w0123 |
561  //
562  // -> also store data from every wave into LDS
563  constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
564  return num_warps * thread_buf_size * sizeof(DataType);
565  }
566 
567  // return in byte - separate shared memory size calculation for indices
568  template <typename YIndexDistributedTensor_>
570  {
571  using IndexDataType = typename YIndexDistributedTensor_::DataType;
572  constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
573  constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
574  return num_warps * thread_buf_size * sizeof(IndexDataType);
575  }
576 
577  private:
578  template <bool kProcessIndex,
579  typename YDistributedTensor_,
580  typename YIndexDistributedTensor_,
581  typename ReduceFunc>
582  CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
583  YIndexDistributedTensor_& y_index_tensor,
584  void* smem,
585  void* smem_indices_ptr,
586  const ReduceFunc& reduce_func)
587  {
588  using DataType = typename YDistributedTensor_::DataType;
589  using IndexDataType = typename YIndexDistributedTensor_::DataType;
590 
591  constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
592 
593  DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
594  IndexDataType* smem_indices = nullptr;
595  if constexpr(kProcessIndex)
596  {
597  smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
598  }
599 
600  const index_t lane_id = get_lane_id();
601  const index_t warp_id = get_warp_id();
602  constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
603  constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
604  const index_t smem_offset = warp_id;
605 
606  // skip if nonthing to do
607  if constexpr(num_reduce_warps == 1)
608  return;
609 
610  // store into smem only for lane-0 within one warp
611  if(lane_id == 0)
612  {
613  static_for<0, thread_buf_size, 1>{}([&](auto i) {
614  smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
615  if constexpr(kProcessIndex)
616  {
617  smem_indices[smem_offset + i * num_warps] =
618  y_index_tensor.get_thread_buffer()[i];
619  }
620  });
621  }
622  block_sync_lds();
623 
624  // load from smem. here we let everythread to do compute :)
625  index_t local_warp_id = warp_id / num_reduce_warps;
626  index_t local_smem_os = local_warp_id * num_reduce_warps;
627 
628  DataType all_scratch[thread_buf_size * num_reduce_warps];
629  [[maybe_unused]] std::conditional_t<kProcessIndex,
630  IndexDataType[thread_buf_size * num_reduce_warps],
631  IndexDataType> all_indices;
632 
633  // Load data from shared memory
634  static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
635  static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
636  all_scratch[i_0 * num_reduce_warps + i_1] =
637  smem_ptr[i_0 * num_warps + local_smem_os + i_1];
638 
639  if constexpr(kProcessIndex)
640  {
641  all_indices[i_0 * num_reduce_warps + i_1] =
642  smem_indices[i_0 * num_warps + local_smem_os + i_1];
643  }
644  });
645  });
646  block_sync_lds(); // TODO: we don't need sync here
647 
648  // Perform reduction
649  static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
650  // TODO: use descriptor for this
651  auto v_local = all_scratch[i_0 * num_reduce_warps];
652 
653  IndexDataType idx_local{};
654  if constexpr(kProcessIndex)
655  {
656  idx_local = all_indices[i_0 * num_reduce_warps];
657  }
658 
659  // further reduce mean/var
660  static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
661  constexpr auto i_1 = number<i_1_n1 + 1>{};
662  const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
663 
664  if constexpr(kProcessIndex)
665  {
666  const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
667 
668  bool changed = false;
669  v_local = reduce_func(v_local, v_remote, changed);
670  if(changed)
671  {
672  idx_local = idx_remote;
673  }
674  }
675  else
676  {
677  v_local = reduce_func(v_local, v_remote);
678  }
679  });
680 
681  y_tensor.get_thread_buffer()(i_0) = v_local;
682  if constexpr(kProcessIndex)
683  {
684  y_index_tensor.get_thread_buffer()(i_0) = idx_local;
685  }
686  });
687  }
688 
689  public:
690  template <typename YDistributedTensor_, typename ReduceFunc>
691  CK_TILE_DEVICE void
692  operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
693  {
694  reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
695  }
696 
697  template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
698  CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
699  YIndexDistributedTensor_& y_index_tensor,
700  void* smem,
701  void* smem_indices,
702  const ReduceFunc& reduce_func)
703  {
704  reduce_impl<Problem::kOutputIndex>(
705  y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
706  }
707 };
708 
709 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:762
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:450
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition: utility.hpp:78
constexpr CK_TILE_HOST_DEVICE auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices, decltype(get_partition_index(tile_distribution)) partition_index)
Definition: static_distributed_tensor.hpp:158
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE int32_t integer_log2_floor(int32_t x)
Definition: math.hpp:443
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1037
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:301
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Definition: reduce_operator_accumulate.hpp:41
Accumulate with index tracking reductions, provides deterministic first occurring index.
Definition: reduce_operator_accumulate.hpp:12
Definition: block_reduce2d.hpp:332
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_reduce2d.hpp:363
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, void *smem, void *smem_indices, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:501
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:333
static constexpr CK_TILE_HOST_DEVICE index_t GetIndicesSmemSize()
Definition: block_reduce2d.hpp:387
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:495
typename Problem::BlockShape BlockShape
Definition: block_reduce2d.hpp:334
static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
Definition: block_reduce2d.hpp:337
Definition: block_reduce2d.hpp:46
constexpr CK_TILE_DEVICE BlockReduce2d()
Definition: block_reduce2d.hpp:52
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, const ReduceFunc &reduce_func, const IndexCalculatorFunc &index_calculator, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:128
typename Problem::ComputeDataType ComputeDataType
Definition: block_reduce2d.hpp:50
static CK_TILE_DEVICE auto MakeYBlockTile()
Definition: block_reduce2d.hpp:166
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:107
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:48
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, const ComputeDataType &reduce_init, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition: block_reduce2d.hpp:206
typename Problem::XDataType XDataType
Definition: block_reduce2d.hpp:49
static CK_TILE_DEVICE auto MakeYIndexBlockTile()
Definition: block_reduce2d.hpp:183
Definition: block_reduce2d.hpp:514
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, void *smem, void *smem_indices, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:698
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:515
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:692
typename Problem::BlockShape BlockShape
Definition: block_reduce2d.hpp:516
static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
Definition: block_reduce2d.hpp:519
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: block_reduce2d.hpp:545
static constexpr CK_TILE_HOST_DEVICE index_t GetIndicesSmemSize()
Definition: block_reduce2d.hpp:569
Definition: block_reduce2d.hpp:222
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:315
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce2d.hpp:321
remove_cvref_t< Problem_ > Problem
Definition: block_reduce2d.hpp:223
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:49
Definition: functional.hpp:43