14 namespace tensor_operation {
15 namespace element_wise {
37 static constexpr
const char*
name =
"AddReluAdd";
39 template <
typename Y,
typename X0,
typename X1,
typename X2>
40 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
52 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
55 const float& x2)
const
58 float b =
a > 0 ?
a : 0;
64 __host__ __device__ constexpr
void operator()<float, float,
half_t,
half_t>(
65 float& y,
const float& x0,
const half_t& x1,
const half_t& x2)
const
68 float b =
a > 0 ?
a : 0;
78 (*this)(y_float, x0, x1, x2);
87 float b =
a > 0 ?
a : 0;
102 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
117 static constexpr
const char*
name =
"AddHardswishAdd";
119 template <
typename Y,
typename X0,
typename X1,
typename X2>
120 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
123 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
126 const float& x2)
const
129 float b =
a +
float{3};
130 float c = (b > 0) * (b >
float{6} ?
float{6} : b) *
a *
float{0.166667};
140 float b =
a +
float{3};
141 float c = (b > 0) * (b >
float{6} ?
float{6} : b) *
a *
float{0.166667};
151 static constexpr
const char*
name =
"AddAdd";
153 template <
typename E,
typename C,
typename D0,
typename D1>
154 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const
159 "Data type is not supported by this operation!");
163 "Data type is not supported by this operation!");
167 "Data type is not supported by this operation!");
171 "Data type is not supported by this operation!");
173 const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
174 e = type_convert<E>(y);
182 static constexpr
const char*
name =
"AddMultiply";
184 template <
typename E,
typename C,
typename D0,
typename D1>
185 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
193 const half_t y = (c + d0) * d1;
202 const half_t y = (type_convert<half_t>(c) + d0) * d1;
206 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
211 const float y = (c + d0) * d1;
220 static constexpr
const char*
name =
"MultiplyAdd";
222 template <
typename E,
typename C,
typename D0,
typename D1>
223 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
231 const half_t y = (c * d0) + d1;
241 type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
251 type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
255 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
260 const float y = c * d0 + d1;
264 __host__ __device__
void operator()<
half_t, float, float,
float>(
half_t& e,
267 const float& d1)
const
269 const float y = c * d0 + d1;
276 static constexpr
const char*
name =
"MultiplyMultiply";
278 template <
typename E,
typename C,
typename D0,
typename D1>
279 __host__ __device__ constexpr
void
280 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
283 __host__ __device__ constexpr
void operator()<
ck::half_t, float, float,
float>(
284 ck::half_t& e,
const float& c,
const float& d0,
const float& d1)
const
286 const float x0_f = c * d0 * d1;
288 e = ck::type_convert<ck::half_t>(x0_f);
292 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, float, float,
float>(
293 ck::bhalf_t& e,
const float& c,
const float& d0,
const float& d1)
const
295 const float x0_f = c * d0 * d1;
297 e = ck::type_convert<ck::bhalf_t>(x0_f);
305 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
307 e = ck::type_convert<ck::half_t>(x0_f);
311 __host__ __device__ constexpr
void operator()<
ck::half_t, int, float,
float>(
312 ck::half_t& e,
const int& c,
const float& d0,
const float& d1)
const
315 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
317 e = ck::type_convert<ck::half_t>(x0_f);
321 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, int, float,
float>(
322 ck::bhalf_t& e,
const int& c,
const float& d0,
const float& d1)
const
325 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
327 e = ck::type_convert<ck::bhalf_t>(x0_f);
333 static constexpr
const char*
name =
"MultiplyAddFastGelu";
335 template <
typename E,
typename C,
typename D0,
typename D1>
336 __host__ __device__ constexpr
void
337 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
343 const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
347 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
349 e = ck::type_convert<ck::bhalf_t>(x1_f);
356 static constexpr
const char*
name =
"AddAddFastGelu";
358 template <
typename E,
typename C,
typename D0,
typename D1>
359 __host__ __device__ constexpr
void
360 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
363 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
366 const float& d1)
const
368 const float x = c + d0 + d1;
370 FastGelu{}.template operator()<float,
float>(e, x);
377 const half_t x = c + d0 + d1;
386 const float x0_f = c + d0 + d1;
393 e = type_convert<half_t>(x1_f);
400 const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
407 e = type_convert<bhalf_t>(x1_f);
415 type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
422 e = type_convert<int8_t>(x1_f);
429 static constexpr
const char*
name =
"ScaleAddScaleAddRelu";
436 template <
typename E,
typename C,
typename D0,
typename D1>
437 __host__ __device__ constexpr
void
438 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
441 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
444 const float& d1)
const
454 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
455 type_convert<float>(d1);
458 result = x > 0 ? x : 0;
460 e = type_convert<half_t>(result);
467 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
468 type_convert<float>(d1);
471 result = x > 0 ? x : 0;
473 e = type_convert<bhalf_t>(result);
477 __host__ __device__ constexpr
void operator()<
int8_t,
int8_t, float,
float>(
478 int8_t& e,
const int8_t& c,
const float& d0,
const float& d1)
const
480 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * d0 + d1;
483 result = x > 0 ? x : 0;
485 e = type_convert<int8_t>(result);
494 static constexpr
const char*
name =
"Normalize";
499 template <
typename T1,
typename T2,
typename T3>
503 const T2& mean_square,
505 const T3& beta)
const;
511 const float& mean_square,
515 using ck::math::sqrt;
517 float variance = mean_square - (mean * mean);
519 float tmp_x = type_convert<float>(x);
520 float tmp_gamma = type_convert<float>(gamma);
521 float tmp_beta = type_convert<float>(beta);
524 ((tmp_x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * tmp_gamma +
527 y = type_convert<half_t>(tmp_y);
531 __host__ __device__ constexpr
void operator()<float, float,
float>(
float& y,
534 const float& mean_square,
536 const float& beta)
const
538 using ck::math::sqrt;
540 float variance = mean_square - (mean * mean);
541 y = ((x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * gamma + beta;
545 __host__ __device__ constexpr
void operator()<double, double,
double>(
double& y,
548 const double& mean_square,
550 const double& beta)
const
552 using ck::math::sqrt;
554 double variance = mean_square - (mean * mean);
555 y = ((x - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
567 static constexpr
const char*
name =
"NormalizeInInfer";
571 template <
typename T1,
typename T2,
typename T3,
typename T4>
577 const T4& beta)
const
580 "Data type is not supported by this operation!");
583 using ck::math::sqrt;
587 tmp_x = type_convert<T2>(x);
589 tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(
epsilon_))) *
590 type_convert<T2>(gamma) +
591 type_convert<T2>(beta);
592 y = type_convert<T1>(tmp_y);
601 static constexpr
const char*
name =
"BiasNormalizeInInferClamp";
605 float epsilon = 1e-4)
610 template <
typename T>
620 using ck::math::sqrt;
622 float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
625 ((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) +
epsilon_)) *
626 type_convert<float>(gamma) +
627 type_convert<float>(beta);
629 y = type_convert<T>(tmp_y);
637 const float& variance,
639 const float& beta)
const
642 using ck::math::sqrt;
644 float tmp_y = (((x + bias) - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
652 template <
typename Y,
typename X>
658 static constexpr
const char* name =
"UnaryTypeConvert";
662 y = ck::type_convert<float, ck::bhalf_t>(x);
669 static constexpr
const char* name =
"UnaryTypeConvert";
673 y = ck::type_convert<ck::bhalf_t, float>(x);
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ T floor(T x)
Definition: math_v2.hpp:367
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
_BitInt(4) int4_t
Definition: data_type.hpp:32
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:310
Definition: element_wise_operation.hpp:355
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:356
Definition: element_wise_operation.hpp:150
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:154
static constexpr const char * name
Definition: element_wise_operation.hpp:151
Definition: element_wise_operation.hpp:116
static constexpr const char * name
Definition: element_wise_operation.hpp:117
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:181
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:182
Definition: element_wise_operation.hpp:36
static constexpr const char * name
Definition: element_wise_operation.hpp:37
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:600
BiasNormalizeInInferClamp(float floor=0.f, float ceil=NumericLimits< float >::Max(), float epsilon=1e-4)
Definition: element_wise_operation.hpp:603
__host__ constexpr __device__ void operator()(T &y, const T &x, const T &bias, const T &mean, const T &variance, const T &gamma, const T &beta) const
Definition: element_wise_operation.hpp:611
float epsilon_
Definition: element_wise_operation.hpp:649
Clamp clamp_
Definition: element_wise_operation.hpp:646
__host__ constexpr __device__ void operator()(float &y, const float &x, const float &bias, const float &mean, const float &variance, const float &gamma, const float &beta) const
Definition: element_wise_operation.hpp:633
static constexpr const char * name
Definition: element_wise_operation.hpp:601
Definition: unary_element_wise_operation.hpp:811
Definition: unary_element_wise_operation.hpp:924
Definition: element_wise_operation.hpp:332
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:333
Definition: element_wise_operation.hpp:219
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:220
Definition: element_wise_operation.hpp:275
static constexpr const char * name
Definition: element_wise_operation.hpp:276
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:493
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:497
double epsilon_
Definition: element_wise_operation.hpp:556
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
static constexpr const char * name
Definition: element_wise_operation.hpp:494
Definition: element_wise_operation.hpp:566
static constexpr const char * name
Definition: element_wise_operation.hpp:567
double epsilon_
Definition: element_wise_operation.hpp:593
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:572
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:569
Definition: element_wise_operation.hpp:428
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:431
static constexpr const char * name
Definition: element_wise_operation.hpp:429
const float alpha2_
Definition: element_wise_operation.hpp:489
const float alpha1_
Definition: element_wise_operation.hpp:488
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:671
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:660
Definition: element_wise_operation.hpp:653