/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File
element_wise_operation.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck/utility/math_v2.hpp"
12 
13 namespace ck {
14 namespace tensor_operation {
15 namespace element_wise {
16 
17 // Need to ensure the compiler fails if there is no matching candidate, instead of the compiler
18 // silently doing an implicit type conversion.
19 //
20 // Example:
21 //
22 // struct ExampleElementwiseOp
23 // {
24 // template<typename Y, typename X>
25 // __host__ __device__ constexpr void
26 // operator()(Y&, const X&) const;
27 //
28 // template<>
29 // __host__ __device__ constexpr void
30 // operator()<half_t, half_t>(half_t& y, const half_t& x) const
31 // {
32 // }
33 // };
34 
35 struct AddReluAdd
36 {
37  static constexpr const char* name = "AddReluAdd";
38 
39  template <typename Y, typename X0, typename X1, typename X2>
40  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
41 
42  template <>
43  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
44  half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
45  {
46  half_t a = x0 + x1;
47  half_t b = a > 0 ? a : 0;
48  y = b + x2;
49  }
50 
51  template <>
52  __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
53  const float& x0,
54  const float& x1,
55  const float& x2) const
56  {
57  float a = x0 + x1;
58  float b = a > 0 ? a : 0;
59  float c = b + x2;
60  y = c;
61  }
62 
63  template <>
64  __host__ __device__ constexpr void operator()<float, float, half_t, half_t>(
65  float& y, const float& x0, const half_t& x1, const half_t& x2) const
66  {
67  float a = x0 + x1;
68  float b = a > 0 ? a : 0;
69  float c = b + x2;
70  y = c;
71  }
72 
73  template <>
74  __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
75  half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
76  {
77  float y_float = 0.0;
78  (*this)(y_float, x0, x1, x2);
79  y = y_float;
80  }
81 
82  template <>
83  __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
84  bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const
85  {
86  float a = x0 + x1;
87  float b = a > 0 ? a : 0;
88  float c = b + x2;
89  y = c;
90  }
91 
92  template <>
93  __host__ __device__ constexpr void operator()<int8_t, int8_t, int8_t, int8_t>(
94  int8_t& y, const int8_t& x0, const int8_t& x1, const int8_t& x2) const
95  {
96  int32_t a = x0 + x1;
97  int32_t b = a > 0 ? a : 0;
98  int32_t c = b + x2;
99  y = c;
100  }
101 
102 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
103  template <>
104  __host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
105  int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
106  {
107  int32_t a = x0 + x1;
108  int32_t b = a > 0 ? a : 0;
109  int32_t c = b + x2;
110  y = c;
111  }
112 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
113 };
114 
116 {
117  static constexpr const char* name = "AddHardswishAdd";
118 
119  template <typename Y, typename X0, typename X1, typename X2>
120  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
121 
122  template <>
123  __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
124  const float& x0,
125  const float& x1,
126  const float& x2) const
127  {
128  float a = x0 + x1;
129  float b = a + float{3};
130  float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
131  float d = c + x2;
132  y = d;
133  }
134 
135  template <>
136  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
137  half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
138  {
139  float a = x0 + x1;
140  float b = a + float{3};
141  float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
142  float d = c + x2;
143  y = d;
144  }
145 };
146 
147 // C = A * B
148 // E = C + D0 + D1
// AddAdd: e = E(c + C(d0) + C(d1)).
// d0 and d1 are converted to C with type_convert, the sum is formed in type C, and the
// result is converted to E with type_convert.
// NOTE(review): the static_assert conditions that restrict the accepted element types are
// missing from this listing (only their message lines remain); the surviving comment says
// only floating-point types are supported so far - confirm against the header.
149 struct AddAdd
150 {
151  static constexpr const char* name = "AddAdd";
152 
153  template <typename E, typename C, typename D0, typename D1>
154  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
155  {
156  // Only support floating so far
159  "Data type is not supported by this operation!");
160 
163  "Data type is not supported by this operation!");
164 
167  "Data type is not supported by this operation!");
168 
171  "Data type is not supported by this operation!");
172 
// Accumulate in the type of c, then convert once to the output type.
173  const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
174  e = type_convert<E>(y);
175  }
176 };
177 
178 // C = A * B
179 // E = (C + D0) x D1
181 {
182  static constexpr const char* name = "AddMultiply";
183 
184  template <typename E, typename C, typename D0, typename D1>
185  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
186 
187  template <>
188  __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
189  const half_t& c,
190  const half_t& d0,
191  const half_t& d1) const
192  {
193  const half_t y = (c + d0) * d1;
194  e = y;
195  }
196  template <>
197  __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
198  const float& c,
199  const half_t& d0,
200  const half_t& d1) const
201  {
202  const half_t y = (type_convert<half_t>(c) + d0) * d1;
203  e = y;
204  }
205  template <>
206  __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
207  const float& c,
208  const half_t& d0,
209  const half_t& d1) const
210  {
211  const float y = (c + d0) * d1;
212  e = y;
213  }
214 };
215 
216 // C = A * B
217 // E = C x D0 + D1
219 {
220  static constexpr const char* name = "MultiplyAdd";
221 
222  template <typename E, typename C, typename D0, typename D1>
223  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
224 
225  template <>
226  __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
227  const half_t& c,
228  const half_t& d0,
229  const half_t& d1) const
230  {
231  const half_t y = (c * d0) + d1;
232  e = y;
233  }
234  template <>
235  __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
236  const float& c,
237  const half_t& d0,
238  const half_t& d1) const
239  {
240  const half_t y =
241  type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
242  e = y;
243  }
244  template <>
245  __host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
246  const float& c,
247  const bhalf_t& d0,
248  const bhalf_t& d1) const
249  {
250  const bhalf_t y =
251  type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
252  e = y;
253  }
254  template <>
255  __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
256  const float& c,
257  const half_t& d0,
258  const half_t& d1) const
259  {
260  const float y = c * d0 + d1;
261  e = y;
262  }
263  template <>
264  __host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
265  const float& c,
266  const float& d0,
267  const float& d1) const
268  {
269  const float y = c * d0 + d1;
270  e = y;
271  }
272 };
273 
275 {
276  static constexpr const char* name = "MultiplyMultiply";
277 
278  template <typename E, typename C, typename D0, typename D1>
279  __host__ __device__ constexpr void
280  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
281 
282  template <>
283  __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
284  ck::half_t& e, const float& c, const float& d0, const float& d1) const
285  {
286  const float x0_f = c * d0 * d1;
287 
288  e = ck::type_convert<ck::half_t>(x0_f);
289  }
290 
291  template <>
292  __host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
293  ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const
294  {
295  const float x0_f = c * d0 * d1;
296 
297  e = ck::type_convert<ck::bhalf_t>(x0_f);
298  }
299 
300  template <>
301  __host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
302  ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
303  {
304  const float x0_f =
305  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
306 
307  e = ck::type_convert<ck::half_t>(x0_f);
308  }
309 
310  template <>
311  __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
312  ck::half_t& e, const int& c, const float& d0, const float& d1) const
313  {
314  const float x0_f =
315  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
316 
317  e = ck::type_convert<ck::half_t>(x0_f);
318  }
319 
320  template <>
321  __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
322  ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
323  {
324  const float x0_f =
325  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
326 
327  e = ck::type_convert<ck::bhalf_t>(x0_f);
328  }
329 };
330 
332 {
333  static constexpr const char* name = "MultiplyAddFastGelu";
334 
335  template <typename E, typename C, typename D0, typename D1>
336  __host__ __device__ constexpr void
337  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
338 
339  template <>
340  __host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
341  ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
342  {
343  const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
344 
345  float x1_f = 0;
346 
347  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
348 
349  e = ck::type_convert<ck::bhalf_t>(x1_f);
350  }
351 };
352 
353 // E = FastGelu(C + D0 + D1)
355 {
356  static constexpr const char* name = "AddAddFastGelu";
357 
358  template <typename E, typename C, typename D0, typename D1>
359  __host__ __device__ constexpr void
360  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
361 
362  template <>
363  __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
364  const float& c,
365  const float& d0,
366  const float& d1) const
367  {
368  const float x = c + d0 + d1;
369 
370  FastGelu{}.template operator()<float, float>(e, x);
371  }
372 
373  template <>
374  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
375  half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
376  {
377  const half_t x = c + d0 + d1;
378 
379  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
380  }
381 
382  template <>
383  __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
384  half_t& e, const float& c, const half_t& d0, const half_t& d1) const
385  {
386  const float x0_f = c + d0 + d1;
387 
388  float x1_f = 0;
389 
390  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
391  x0_f);
392 
393  e = type_convert<half_t>(x1_f);
394  }
395 
396  template <>
397  __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
398  bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
399  {
400  const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
401 
402  float x1_f = 0;
403 
404  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
405  x0_f);
406 
407  e = type_convert<bhalf_t>(x1_f);
408  }
409 
410  template <>
411  __host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
412  int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
413  {
414  const float x0_f =
415  type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
416 
417  float x1_f = 0;
418 
419  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
420  x0_f);
421 
422  e = type_convert<int8_t>(x1_f);
423  }
424 };
425 
426 // E = Relu(alpha1 * C + alpha2 * D0 + D1)
428 {
429  static constexpr const char* name = "ScaleAddScaleAddRelu";
430 
431  ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
432  : alpha1_(alpha1), alpha2_(alpha2)
433  {
434  }
435 
436  template <typename E, typename C, typename D0, typename D1>
437  __host__ __device__ constexpr void
438  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
439 
440  template <>
441  __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
442  const float& c,
443  const float& d0,
444  const float& d1) const
445  {
446  const float x = c * alpha1_ + alpha2_ * d0 + d1;
447  e = x > 0 ? x : 0;
448  }
449 
450  template <>
451  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
452  half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
453  {
454  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
455  type_convert<float>(d1);
456 
457  float result = 0;
458  result = x > 0 ? x : 0;
459 
460  e = type_convert<half_t>(result);
461  }
462 
463  template <>
464  __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
465  bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
466  {
467  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
468  type_convert<float>(d1);
469 
470  float result = 0;
471  result = x > 0 ? x : 0;
472 
473  e = type_convert<bhalf_t>(result);
474  }
475 
476  template <>
477  __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
478  int8_t& e, const int8_t& c, const float& d0, const float& d1) const
479  {
480  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
481 
482  float result = 0;
483  result = x > 0 ? x : 0;
484 
485  e = type_convert<int8_t>(result);
486  }
487 
488  const float alpha1_;
489  const float alpha2_;
490 };
491 
492 struct Normalize
493 {
494  static constexpr const char* name = "Normalize";
495 
496  // FIXME: is double absolutely necessary?
497  Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
498 
499  template <typename T1, typename T2, typename T3>
500  __host__ __device__ constexpr void operator()(T1& y,
501  const T1& x,
502  const T2& mean,
503  const T2& mean_square,
504  const T3& gamma,
505  const T3& beta) const;
506 
507  template <>
508  __host__ __device__ constexpr void operator()<half_t, float, half_t>(half_t& y,
509  const half_t& x,
510  const float& mean,
511  const float& mean_square,
512  const half_t& gamma,
513  const half_t& beta) const
514  {
515  using ck::math::sqrt;
516 
517  float variance = mean_square - (mean * mean);
518 
519  float tmp_x = type_convert<float>(x);
520  float tmp_gamma = type_convert<float>(gamma);
521  float tmp_beta = type_convert<float>(beta);
522 
523  float tmp_y =
524  ((tmp_x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * tmp_gamma +
525  tmp_beta;
526 
527  y = type_convert<half_t>(tmp_y);
528  };
529 
530  template <>
531  __host__ __device__ constexpr void operator()<float, float, float>(float& y,
532  const float& x,
533  const float& mean,
534  const float& mean_square,
535  const float& gamma,
536  const float& beta) const
537  {
538  using ck::math::sqrt;
539 
540  float variance = mean_square - (mean * mean);
541  y = ((x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * gamma + beta;
542  };
543 
544  template <>
545  __host__ __device__ constexpr void operator()<double, double, double>(double& y,
546  const double& x,
547  const double& mean,
548  const double& mean_square,
549  const double& gamma,
550  const double& beta) const
551  {
552  using ck::math::sqrt;
553 
554  double variance = mean_square - (mean * mean);
555  y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
556  };
557 
558  // FIXME: is double absolutely necessary?
559  double epsilon_;
560 };
561 
562 // used by BatchNorm inference
563 // y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
564 // The data type of mean and variance is used as AccDataType
// NOTE(review): the `struct NormalizeInInfer` line (source line 565) is missing from this
// listing; the name string and constructor below identify the struct.
566 {
567  static constexpr const char* name = "NormalizeInInfer";
568 
// epsilon is stored as double and converted to T2 on every call.
569  NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
570 
// Computes y = gamma * (x - mean) / sqrt(variance + epsilon_) + beta.
// x, gamma, beta and epsilon_ are converted to T2 (the mean/variance type), all arithmetic is
// done in T2, and the result is converted back to T1.
571  template <typename T1, typename T2, typename T3, typename T4>
572  __host__ __device__ constexpr void operator()(T1& y,
573  const T1& x,
574  const T2& mean,
575  const T2& variance,
576  const T3& gamma,
577  const T4& beta) const
578  {
// NOTE(review): the static_assert condition (source line 579) is missing from this listing;
// only its message line remains.
580  "Data type is not supported by this operation!");
581 
582  using ck::type_convert;
583  using ck::math::sqrt;
584 
585  T2 tmp_x, tmp_y;
586 
587  tmp_x = type_convert<T2>(x);
588 
589  tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
590  type_convert<T2>(gamma) +
591  type_convert<T2>(beta);
592  y = type_convert<T1>(tmp_y);
593  };
594 
595  double epsilon_;
596 };
597 
598 // used by Conv+Bias+BatchNorm+Clamp inference
// NOTE(review): the `struct BiasNormalizeInInferClamp` line (source line 599) is missing from
// this listing; the name string below identifies the struct.
600 {
601  static constexpr const char* name = "BiasNormalizeInInferClamp";
602 
// NOTE(review): the constructor's first lines (source lines 603-604) are missing here. The
// cross-reference shows the full signature as
// BiasNormalizeInInferClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max(),
//                           float epsilon = 1e-4).
// floor/ceil configure clamp_, and epsilon is added to the variance.
605  float epsilon = 1e-4)
606  : clamp_(floor, ceil), epsilon_(epsilon)
607  {
608  }
609 
// Generic path: y = clamp(gamma * ((x + bias) - mean) / sqrt(variance + epsilon_) + beta).
// All inputs are converted to float, the result is clamped in float, then converted to T.
610  template <typename T>
611  __host__ __device__ constexpr void operator()(T& y,
612  const T& x,
613  const T& bias,
614  const T& mean,
615  const T& variance,
616  const T& gamma,
617  const T& beta) const
618  {
619  using ck::type_convert;
620  using ck::math::sqrt;
621 
622  float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
623 
624  float tmp_y =
625  ((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) + epsilon_)) *
626  type_convert<float>(gamma) +
627  type_convert<float>(beta);
// Clamp in place, then convert to the output type.
628  clamp_(tmp_y, tmp_y);
629  y = type_convert<T>(tmp_y);
630  };
631 
// float specialization: same formula with no conversions; clamp_ writes directly into y.
632  template <>
633  __host__ __device__ constexpr void operator()(float& y,
634  const float& x,
635  const float& bias,
636  const float& mean,
637  const float& variance,
638  const float& gamma,
639  const float& beta) const
640  {
641  using ck::type_convert;
642  using ck::math::sqrt;
643 
644  float tmp_y = (((x + bias) - mean) / sqrt(variance + epsilon_)) * gamma + beta;
645  clamp_(y, tmp_y);
646  };
647 
// NOTE(review): the `Clamp clamp_;` member declaration (source line 648) is missing from this
// listing.
649  float epsilon_;
650 };
651 
// Primary template for a unary element-wise type-conversion functor (Y <- X). The two explicit
// specializations that follow (float <- bhalf_t and bhalf_t <- float) are the only
// definitions visible in this listing.
// NOTE(review): the primary declaration itself (source line 653) is missing from this
// listing - presumably `struct UnaryTypeConvert;`; confirm against the header.
652 template <typename Y, typename X>
654 
655 template <>
656 struct UnaryTypeConvert<float, ck::bhalf_t>
657 {
658  static constexpr const char* name = "UnaryTypeConvert";
659 
660  __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
661  {
662  y = ck::type_convert<float, ck::bhalf_t>(x);
663  }
664 };
665 
666 template <>
667 struct UnaryTypeConvert<ck::bhalf_t, float>
668 {
669  static constexpr const char* name = "UnaryTypeConvert";
670 
671  __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
672  {
673  y = ck::type_convert<ck::bhalf_t, float>(x);
674  }
675 };
676 
677 } // namespace element_wise
678 } // namespace tensor_operation
679 } // namespace ck
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: ck.hpp:270
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
_BitInt(4) int4_t
Definition: data_type.hpp:32
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:310
Definition: type.hpp:177
Definition: element_wise_operation.hpp:355
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:356
Definition: element_wise_operation.hpp:150
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:154
static constexpr const char * name
Definition: element_wise_operation.hpp:151
Definition: element_wise_operation.hpp:116
static constexpr const char * name
Definition: element_wise_operation.hpp:117
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:181
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:182
Definition: element_wise_operation.hpp:36
static constexpr const char * name
Definition: element_wise_operation.hpp:37
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:600
BiasNormalizeInInferClamp(float floor=0.f, float ceil=NumericLimits< float >::Max(), float epsilon=1e-4)
Definition: element_wise_operation.hpp:603
__host__ constexpr __device__ void operator()(T &y, const T &x, const T &bias, const T &mean, const T &variance, const T &gamma, const T &beta) const
Definition: element_wise_operation.hpp:611
float epsilon_
Definition: element_wise_operation.hpp:649
Clamp clamp_
Definition: element_wise_operation.hpp:646
__host__ constexpr __device__ void operator()(float &y, const float &x, const float &bias, const float &mean, const float &variance, const float &gamma, const float &beta) const
Definition: element_wise_operation.hpp:633
static constexpr const char * name
Definition: element_wise_operation.hpp:601
Definition: unary_element_wise_operation.hpp:811
Definition: unary_element_wise_operation.hpp:924
Definition: element_wise_operation.hpp:332
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:333
Definition: element_wise_operation.hpp:219
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:220
Definition: element_wise_operation.hpp:275
static constexpr const char * name
Definition: element_wise_operation.hpp:276
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:493
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:497
double epsilon_
Definition: element_wise_operation.hpp:556
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
static constexpr const char * name
Definition: element_wise_operation.hpp:494
Definition: element_wise_operation.hpp:566
static constexpr const char * name
Definition: element_wise_operation.hpp:567
double epsilon_
Definition: element_wise_operation.hpp:593
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:572
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:569
Definition: element_wise_operation.hpp:428
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:431
static constexpr const char * name
Definition: element_wise_operation.hpp:429
const float alpha2_
Definition: element_wise_operation.hpp:489
const float alpha1_
Definition: element_wise_operation.hpp:488
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:671
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:660
Definition: element_wise_operation.hpp:653