From c00f3c8bf30bb4a72858c476dcc9247cf7d8f0d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?=
Date: Tue, 31 Dec 2024 10:53:53 +0100
Subject: [PATCH] Improve activation quant function

---
 candle-nn/src/bit_linear.rs | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs
index 2fdd4a9ae..97c4a54bd 100644
--- a/candle-nn/src/bit_linear.rs
+++ b/candle-nn/src/bit_linear.rs
@@ -42,17 +42,10 @@ fn weight_quant(x: &Tensor) -> Result<Tensor> {
 }
 
 fn activation_quant(x: &Tensor) -> Result<Tensor> {
-    let scale = (127.0
-        / x.abs()?
-            .max(D::Minus1)?
-            .max(D::Minus1)?
-            .clamp(1e-5, f32::INFINITY)?)?
-        .to_dtype(x.dtype())?;
+    let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
+    let scale = (127.0 / scale)?;
 
-    let y = x
-        .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)?
-        .clamp(-128.0, 127.0)?
-        .broadcast_div(&scale)?;
+    let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?.broadcast_div(&scale)?;
 
     Ok(y)
 }
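
Reviewer note: a minimal, self-contained sketch of the patched function in isolation, assuming only candle-core as a dependency (abs, max_keepdim, clamp, broadcast_mul, round, broadcast_div are all public Tensor methods there); the tensor values in main are made up for illustration. The functional change in the patch is the added round(): without it the scale/unscale round-trip was a near no-op, so nothing was actually snapped to int8 steps. max_keepdim also replaces the two max + two unsqueeze calls, keeping the reduced last dimension so the per-row scale broadcasts directly.

use candle_core::{Device, Result, Tensor, D};

// Fake-quantize activations: scale each row into the int8 range with a
// per-row absmax scale, round to integer steps, then scale back.
fn activation_quant(x: &Tensor) -> Result<Tensor> {
    // Per-row absmax, floored at 1e-5 so the division below cannot blow up.
    let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
    let scale = (127.0 / scale)?;
    // round() snaps the scaled values to integer steps; clamp keeps them
    // inside the signed 8-bit range before dequantizing.
    let y = x
        .broadcast_mul(&scale)?
        .round()?
        .clamp(-128., 127.)?
        .broadcast_div(&scale)?;
    Ok(y)
}

fn main() -> Result<()> {
    // Illustrative input; any 2D f32 tensor works.
    let x = Tensor::new(&[[0.42f32, -1.3, 0.07], [2.0, -0.5, 0.9]], &Device::Cpu)?;
    println!("{}", activation_quant(&x)?);
    Ok(())
}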