diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs index 2fdd4a9ae..97c4a54bd 100644 --- a/candle-nn/src/bit_linear.rs +++ b/candle-nn/src/bit_linear.rs @@ -42,17 +42,10 @@ fn weight_quant(x: &Tensor) -> Result { } fn activation_quant(x: &Tensor) -> Result { - let scale = (127.0 - / x.abs()? - .max(D::Minus1)? - .max(D::Minus1)? - .clamp(1e-5, f32::INFINITY)?)? - .to_dtype(x.dtype())?; + let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?; + let scale = (127.0 / scale)?; - let y = x - .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? - .clamp(-128.0, 127.0)? - .broadcast_div(&scale)?; + let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?.broadcast_div(&scale)?; Ok(y) }