Skip to content

Commit

Permalink
Add half->int8 saturate conversion to promise valid range (#1983)
Browse files Browse the repository at this point in the history
* Add half->int8 saturate conversion to promise valid range

* add gpu only macro

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
  • Loading branch information
MARD1NO and hwu36 authored Jan 8, 2025
1 parent c506e16 commit 7de6a59
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions include/cutlass/numeric_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,44 @@ struct NumericConverter<uint8_t, float, FloatRoundStyle::round_toward_zero> {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for cutlass::half_t => int8_t
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Converts cutlass::half_t to int8_t, rounding to the nearest integer and
/// saturating to the representable range of int8_t.
///
/// Device builds (`__CUDA_ARCH__` defined) issue a single PTX
/// `cvt.rni.sat.s8.f16` instruction. Host builds round with `std::nearbyint`
/// and saturate explicitly; NaN maps to 0 and +/-infinity saturate, so the
/// float-to-integer cast is never applied to a value it cannot represent.
template <>
struct NumericConverter<int8_t, cutlass::half_t, FloatRoundStyle::round_to_nearest> {

  using result_type = int8_t;
  using source_type = cutlass::half_t;
  static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

  /// Returns `s` rounded to the nearest integer and clamped to the int8_t range.
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const & s) {
    #if defined(__CUDA_ARCH__)
    // The "h" asm constraint wants 16-bit operands, so both the half input and
    // the converted result are viewed through 16-bit integers via unions.
    union { int8_t int8[2]; int16_t int16; };
    union { cutlass::half_t fp16; int16_t int16_in; };
    fp16 = s;
    asm volatile ("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
    // The converted value is read back from the first byte of the 16-bit result.
    return int8[0];
    #elif !defined(__CUDACC_RTC__)
    // std::nearbyint honours the current rounding mode, so select
    // round-to-nearest first. NOTE: the mode is left set afterwards.
    std::fesetround(FE_TONEAREST);
    float const rounded = std::nearbyint(static_cast<float>(s));

    // NaN is the only value that compares unequal to itself. Casting NaN to an
    // integer is undefined behaviour, so return 0 instead, which is the result
    // PTX `cvt` documents for NaN in float-to-integer conversions.
    if (rounded != rounded) {
      return static_cast<result_type>(0);
    }

    // Saturate while still in floating point so that +/-infinity (and any
    // other out-of-range value) never reaches the integer cast.
    // Low-end saturation
    if (rounded <= static_cast<float>(std::numeric_limits<int8_t>::lowest())) {
      return std::numeric_limits<int8_t>::lowest();
    }
    // High-end saturation
    if (rounded >= static_cast<float>(std::numeric_limits<int8_t>::max())) {
      return std::numeric_limits<int8_t>::max();
    }

    // In range: the value is an exact integer strictly inside (-128, 127).
    return static_cast<result_type>(static_cast<int32_t>(rounded));
    #endif
  }

  /// Function-call form of convert().
  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for float => integer_subbyte
Expand Down

0 comments on commit 7de6a59

Please sign in to comment.