From 7de6a5978463c2ddf9a79305bbea3a701cb5057f Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Wed, 8 Jan 2025 22:01:07 +0800 Subject: [PATCH] Add half->int8 saturate conversion to promise valid range (#1983) * Add half->int8 saturate conversion to promise valid range * add gpu only macro --------- Co-authored-by: Haicheng Wu --- include/cutlass/numeric_conversion.h | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index b62a90cca..298163d80 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -267,6 +267,44 @@ struct NumericConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for cutlass::half_t => int8_t +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct NumericConverter { + + using result_type = int8_t; + using source_type = cutlass::half_t; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + #if defined(__CUDA_ARCH__) + union { int8_t int8[2]; int16_t int16; }; + union { cutlass::half_t fp16; int16_t int16_in; }; + fp16 = s; + asm volatile ("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); + return int8[0]; + #elif !defined(__CUDACC_RTC__) + std::fesetround(FE_TONEAREST); + int32_t intermediate = (int32_t)std::nearbyint(static_cast(s)); + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + return static_cast(intermediate); + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for float => integer_subbyte