diff --git a/TrainingExtensions/torch/src/python/aimet_torch/tensor_quantizer.py b/TrainingExtensions/torch/src/python/aimet_torch/tensor_quantizer.py index 274949c0d78..46c640078d1 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/tensor_quantizer.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/tensor_quantizer.py @@ -814,6 +814,8 @@ def _compute_updated_encoding(self) -> Union[libpymo.TfEncoding, List[libpymo.Tf scale, offset = self.compute_scaling_offset(encoding_min.float(), encoding_max.float()) assert scale is not None assert offset is not None + scale = scale.expand_as(encoding_min) + offset = offset.expand_as(encoding_min) if not self.use_symmetric_encodings or self.is_unsigned_symmetric: # Calculate 'min' and 'max' based on 'delta' and 'offset'