diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index 184e694e0f9..816130c42a5 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -542,6 +542,63 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
     """ Quantized BatchNorm """
     _builtin_torch_fn = torch.nn.functional.batch_norm
+
+    def __quant_init__(self):
+        super().__quant_init__()
+        # pylint: disable=attribute-defined-outside-init
+        self.input_quantizers = nn.ModuleList([None, None, None, None, None])
+
+    def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]):
+        # pylint: disable=redefined-builtin
+        def batch_norm_wrapper(
+                input: Tensor,
+                running_mean: Optional[Tensor],
+                running_var: Optional[Tensor],
+                weight: Optional[Tensor] = None,
+                bias: Optional[Tensor] = None,
+                training: bool = False,
+                momentum: float = 0.1,
+                eps: float = 1e-5,
+        ) -> Tensor:
+
+            input = _quantize_dequantize_if_applicable(input, self.input_quantizers[0])
+            running_mean = _quantize_dequantize_if_applicable(running_mean, self.input_quantizers[1])
+            running_var = _quantize_dequantize_if_applicable(running_var, self.input_quantizers[2])
+            weight = _quantize_dequantize_if_applicable(weight, self.input_quantizers[3])
+            bias = _quantize_dequantize_if_applicable(bias, self.input_quantizers[4])
+
+            # PyTorch doesn't support gradient calculation of running_mean/var
+            output = fn(input, running_mean.detach(), running_var.detach(),
+                        weight, bias, training, momentum, eps)
+
+            return _quantize_dequantize_if_applicable(output, self.output_quantizers[0])
+
+        return batch_norm_wrapper
+
+    def _custom_kernel_helper(self, fn: Callable[..., Tensor]):
+        # pylint: disable=redefined-builtin
+        def batch_norm_wrapper(
+                input: Tensor,
+                running_mean: Optional[Tensor],
+                running_var: Optional[Tensor],
+                weight: Optional[Tensor] = None,
+                bias: Optional[Tensor] = None,
+                training: bool = False,
+                momentum: float = 0.1,
+                eps: float = 1e-5,
+        ) -> Tensor:
+            input = _quantize_if_applicable(input, self.input_quantizers[0])
+            running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])
+            running_var = _quantize_if_applicable(running_var, self.input_quantizers[2])
+            weight = _quantize_if_applicable(weight, self.input_quantizers[3])
+            bias = _quantize_if_applicable(bias, self.input_quantizers[4])
+
+            # PyTorch doesn't support gradient calculation of running_mean/var
+            output = fn(input, running_mean.detach(), running_var.detach(),
+                        weight, bias, training, momentum, eps)
+            return _quantize_if_applicable(output, self.output_quantizers[0])
+
+        return batch_norm_wrapper
 
 #
 # @QuantizationMixin.implements(GroupNorm)
diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
index 0b9c29e27d6..3cb14b8641d 100644
--- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
+++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
@@ -971,9 +971,12 @@ def _create_quantized_module(module):
     # (lambda: custom.Interpolate(), lambda: ...),
     # (lambda: custom.MaxPool2d(), lambda: ...),
     # (lambda: custom.AdaptiveAvgPool2d(), lambda: ...),
-    (lambda: custom.BatchNorm(), lambda: (randn(5, 10),torch.zeros(10), torch.ones(10))),
-    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2),torch.zeros(10), torch.ones(10))),
-    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5),torch.zeros(10), torch.ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10), ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10), ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10), ones(10))),
     (lambda: custom.GroupNorm(), lambda: (randn(20, 6, 10, 10), tensor(6))),
     (lambda: custom.Normalize(), lambda: randn(100, 100)),
     # (lambda: custom.Pad(), lambda: ...),
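
For reference, both new helpers follow the same shape as the other quantized ops in custom.py: quantize (or quantize-dequantize) each tensor argument with its own input quantizer slot, call the builtin functional kernel with detached running stats, and quantize the result with the output quantizer. Below is a minimal standalone sketch of that pattern in plain PyTorch, assuming a hypothetical fake_quantize() helper that stands in for _quantize_dequantize_if_applicable(); it is an illustration, not the AIMET implementation.

from typing import Optional

import torch.nn.functional as F
from torch import Tensor


def fake_quantize(x: Optional[Tensor], quantizer) -> Optional[Tensor]:
    # Hypothetical stand-in for _quantize_dequantize_if_applicable:
    # apply the quantizer only when both the tensor and the quantizer exist.
    if x is None or quantizer is None:
        return x
    return quantizer(x)


def quantized_batch_norm(input: Tensor,
                         running_mean: Optional[Tensor],
                         running_var: Optional[Tensor],
                         weight: Optional[Tensor] = None,
                         bias: Optional[Tensor] = None,
                         training: bool = False,
                         momentum: float = 0.1,
                         eps: float = 1e-5,
                         input_quantizers=(None,) * 5,
                         output_quantizer=None) -> Tensor:
    # One quantizer slot per tensor argument, mirroring the five-element
    # input_quantizers ModuleList added in __quant_init__.
    input = fake_quantize(input, input_quantizers[0])
    running_mean = fake_quantize(running_mean, input_quantizers[1])
    running_var = fake_quantize(running_var, input_quantizers[2])
    weight = fake_quantize(weight, input_quantizers[3])
    bias = fake_quantize(bias, input_quantizers[4])

    # PyTorch does not compute gradients for running_mean/running_var,
    # hence the detach() before calling the builtin kernel.
    output = F.batch_norm(input,
                          running_mean.detach() if running_mean is not None else None,
                          running_var.detach() if running_var is not None else None,
                          weight, bias, training, momentum, eps)
    return fake_quantize(output, output_quantizer)

With every quantizer slot left at None the sketch reduces to a plain F.batch_norm call, which corresponds to the un-quantized BatchNorm test cases added in the second hunk; the detach() mirrors the diff's comment that PyTorch does not support gradient calculation for the running statistics, which is why the new test cases also exercise running stats with requires_grad_() set.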