diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index 7a6c558d9f0..e343a8e13ad 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -590,6 +590,9 @@ def batch_norm_wrapper(
     momentum: float = 0.1,
     eps: float = 1e-5,
 ) -> Tensor:
+    if training:
+        if self.input_quantizers[1] is not None or self.input_quantizers[2] is not None:
+            raise RuntimeError(f"{self.__class__} doesn't support quantizing running_mean or running_var in training mode")
     input = _quantize_if_applicable(input, self.input_quantizers[0])
     running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])