From a687ba11007355a8d2d88bc4a71deb38f2577a50 Mon Sep 17 00:00:00 2001
From: Priyanka Dangi
Date: Tue, 8 Oct 2024 14:27:45 -0700
Subject: [PATCH 1/4] Adding definitions for custom normalization modules

Signed-off-by: Priyanka Dangi
---
 .../v2/nn/fake_quant/_legacy_impl.py      |  2 +-
 .../aimet_torch/v2/nn/modules/custom.py   | 32 +++++++++----------
 .../test/python/v2/nn/test_true_quant.py  | 10 +++---
 3 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py
index a1f73b8caa2..7cefbf1441d 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/fake_quant/_legacy_impl.py
@@ -820,7 +820,7 @@ def forward(self, # pylint: disable=too-many-arguments, arguments-differ
         if bias is not None and self.input_quantizers[4]:
             bias = self.input_quantizers[4](bias)

-        output = super().forward(input, running_mean, running_var,
+        output = super().forward(input, running_mean.detach(), running_var.detach(),
                                  weight, bias, training, momentum, eps)

         if self.output_quantizers[0]:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index b97c8d72331..184e694e0f9 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -191,10 +191,10 @@ def cat(tensors, dim=0, *, out=None):
 #     _builtin_torch_fn = torch.floor_divide
 #
 #
-# @QuantizationMixin.implements(Norm)
-# class QuantizedNorm(_DispatchMixin, QuantizationMixin, Norm):
-#     """ Quantized Norm """
-#     _builtin_torch_fn = torch.norm
+@QuantizationMixin.implements(Norm)
+class QuantizedNorm(_DispatchMixin, QuantizationMixin, Norm):
+    """ Quantized Norm """
+    _builtin_torch_fn = torch.norm
 #
 #
 # @QuantizationMixin.implements(Exponential)
@@ -538,22 +538,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 #     _builtin_torch_fn = torch.nn.functional.adaptive_avg_pool2d
 #
 #
-# @QuantizationMixin.implements(BatchNorm)
-# class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
-#     """ Quantized BatchNorm """
-#     _builtin_torch_fn = torch.nn.functional.batch_norm
+@QuantizationMixin.implements(BatchNorm)
+class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
+    """ Quantized BatchNorm """
+    _builtin_torch_fn = torch.nn.functional.batch_norm
 #
 #
-# @QuantizationMixin.implements(GroupNorm)
-# class QuantizedGroupNorm(_DispatchMixin, QuantizationMixin, GroupNorm):
-#     """ Quantized GroupNorm """
-#     _builtin_torch_fn = torch.nn.functional.group_norm
+@QuantizationMixin.implements(GroupNorm)
+class QuantizedGroupNorm(_DispatchMixin, QuantizationMixin, GroupNorm):
+    """ Quantized GroupNorm """
+    _builtin_torch_fn = torch.nn.functional.group_norm
 #
 #
-# @QuantizationMixin.implements(Normalize)
-# class QuantizedNormalize(_DispatchMixin, QuantizationMixin, Normalize):
-#     """ Quantized Normalize """
-#     _builtin_torch_fn = torch.nn.functional.normalize
+@QuantizationMixin.implements(Normalize)
+class QuantizedNormalize(_DispatchMixin, QuantizationMixin, Normalize):
+    """ Quantized Normalize """
+    _builtin_torch_fn = torch.nn.functional.normalize
 #
 #
 # @QuantizationMixin.implements(Pad)
diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
index 3b0ddf6aef6..0b9c29e27d6 100644
--- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
+++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
@@ -917,7 +917,7 @@ def _create_quantized_module(module):
     (lambda: custom.Divide(), lambda: (randn(100), randn(100))),
     (lambda: custom.Concat(), lambda: (randn(1, 100), randn(3, 100))),
     # (lambda: custom.FloorDivide(), lambda: ...),
-    # (lambda: custom.Norm(), lambda: ...),
+    (lambda: custom.Norm(), lambda: randn(100)),
     # (lambda: custom.Exponential(), lambda: ...),
     # (lambda: custom.Erf(), lambda: ...),
     # (lambda: custom.Sqrt(), lambda: ...),
@@ -971,9 +971,11 @@ def _create_quantized_module(module):
     # (lambda: custom.Interpolate(), lambda: ...),
     # (lambda: custom.MaxPool2d(), lambda: ...),
     # (lambda: custom.AdaptiveAvgPool2d(), lambda: ...),
-    # (lambda: custom.BatchNorm(), lambda: ...),
-    # (lambda: custom.GroupNorm(), lambda: ...),
-    # (lambda: custom.Normalize(), lambda: ...),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10),torch.zeros(10), torch.ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2),torch.zeros(10), torch.ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5),torch.zeros(10), torch.ones(10))),
+    (lambda: custom.GroupNorm(), lambda: (randn(20, 6, 10, 10), tensor(6))),
+    (lambda: custom.Normalize(), lambda: randn(100, 100)),
     # (lambda: custom.Pad(), lambda: ...),
     # (lambda: custom.GridSample(), lambda: ...),
     (lambda: custom.RmsNorm([5, 2, 3], [2], 1e-5), lambda: (randn(5, 2, 3))),
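A note on the pattern used above: each uncommented class is registered against its float counterpart
via @QuantizationMixin.implements, and _builtin_torch_fn names the torch functional that the dispatch
mixin intercepts at runtime. A minimal, self-contained sketch of that registration idea (the names
"registry", "implements" and "quantized_type_of" are illustrative only, not aimet APIs):

    registry = {}

    def implements(float_cls):
        def decorator(quantized_cls):
            registry[float_cls] = quantized_cls   # float module type -> quantized counterpart
            return quantized_cls
        return decorator

    class Norm:                                   # stands in for custom.Norm
        pass

    @implements(Norm)
    class QuantizedNorm(Norm):
        pass

    def quantized_type_of(module):
        return registry[type(module)]             # lookup performed when a model is quantized

    assert quantized_type_of(Norm()) is QuantizedNorm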
From 88fb6812728b6fa98830c89f21f15c09fdcdfc40 Mon Sep 17 00:00:00 2001
From: Priyanka Dangi
Date: Tue, 15 Oct 2024 16:41:43 -0700
Subject: [PATCH 2/4] Addressing bug in quantized definition of custom.BatchNorm

Signed-off-by: Priyanka Dangi
---
 .../aimet_torch/v2/nn/modules/custom.py   | 57 +++++++++++++++++++
 .../test/python/v2/nn/test_true_quant.py  |  9 ++-
 2 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index 184e694e0f9..816130c42a5 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -542,6 +542,63 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
     """ Quantized BatchNorm """
     _builtin_torch_fn = torch.nn.functional.batch_norm
+
+    def __quant_init__(self):
+        super().__quant_init__()
+        # pylint: disable=attribute-defined-outside-init
+        self.input_quantizers = nn.ModuleList([None, None, None, None, None])
+
+    def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]):
+        # pylint: disable=redefined-builtin
+        def batch_norm_wrapper(
+            input: Tensor,
+            running_mean: Optional[Tensor],
+            running_var: Optional[Tensor],
+            weight: Optional[Tensor] = None,
+            bias: Optional[Tensor] = None,
+            training: bool = False,
+            momentum: float = 0.1,
+            eps: float = 1e-5,
+        ) -> Tensor:
+
+            input = _quantize_dequantize_if_applicable(input, self.input_quantizers[0])
+            running_mean = _quantize_dequantize_if_applicable(running_mean, self.input_quantizers[1])
+            running_var = _quantize_dequantize_if_applicable(running_var, self.input_quantizers[2])
+            weight = _quantize_dequantize_if_applicable(weight, self.input_quantizers[3])
+            bias = _quantize_dequantize_if_applicable(bias, self.input_quantizers[4])
+
+            # PyTorch doesn't support gradient calculation of running_mean/var
+            output = fn(input, running_mean.detach(), running_var.detach(),
+                        weight, bias, training, momentum, eps)
+
+            return _quantize_dequantize_if_applicable(output, self.output_quantizers[0])
+
+        return batch_norm_wrapper
+
+    def _custom_kernel_helper(self, fn: Callable[..., Tensor]):
+        # pylint: disable=redefined-builtin
+        def batch_norm_wrapper(
+            input: Tensor,
+            running_mean: Optional[Tensor],
+            running_var: Optional[Tensor],
+            weight: Optional[Tensor] = None,
+            bias: Optional[Tensor] = None,
+            training: bool = False,
+            momentum: float = 0.1,
+            eps: float = 1e-5,
+        ) -> Tensor:
+            input = _quantize_if_applicable(input, self.input_quantizers[0])
+            running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])
+            running_var = _quantize_if_applicable(running_var, self.input_quantizers[2])
+            weight = _quantize_if_applicable(weight, self.input_quantizers[3])
+            bias = _quantize_if_applicable(bias, self.input_quantizers[4])
+
+            # PyTorch doesn't support gradient calculation of running_mean/var
+            output = fn(input, running_mean.detach(), running_var.detach(),
+                        weight, bias, training, momentum, eps)
+            return _quantize_if_applicable(output, self.output_quantizers[0])
+
+        return batch_norm_wrapper
 #
 #
 @QuantizationMixin.implements(GroupNorm)
diff --git a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
index 0b9c29e27d6..3cb14b8641d 100644
--- a/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
+++ b/TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
@@ -971,9 +971,12 @@ def _create_quantized_module(module):
     # (lambda: custom.Interpolate(), lambda: ...),
     # (lambda: custom.MaxPool2d(), lambda: ...),
     # (lambda: custom.AdaptiveAvgPool2d(), lambda: ...),
-    (lambda: custom.BatchNorm(), lambda: (randn(5, 10),torch.zeros(10), torch.ones(10))),
-    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2),torch.zeros(10), torch.ones(10))),
-    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5),torch.zeros(10), torch.ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10), ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10), ones(10))),
+    (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10), ones(10))),
     (lambda: custom.GroupNorm(), lambda: (randn(20, 6, 10, 10), tensor(6))),
     (lambda: custom.Normalize(), lambda: randn(100, 100)),
     # (lambda: custom.Pad(), lambda: ...),
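The detach() calls in both wrappers mirror the _legacy_impl.py change from patch 1: the running
statistics are handed to the kernel as detached copies so the backward pass never asks for their
gradients. A standalone plain-PyTorch sketch of that behaviour, using the same shapes as the new
requires_grad_ test cases (illustrative code, not the test itself):

    import torch
    import torch.nn.functional as F

    x = torch.randn(5, 10, requires_grad=True)
    running_mean = torch.zeros(10).requires_grad_()
    running_var = torch.ones(10).requires_grad_()

    # Pass detached copies, as the wrapper does, so the backward pass never
    # tries to produce gradients for the running statistics.
    out = F.batch_norm(x, running_mean.detach(), running_var.detach(),
                       weight=None, bias=None, training=False)
    out.sum().backward()

    assert x.grad is not None          # gradient flows to the activation
    assert running_mean.grad is None   # but not to the detached running stats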
From 09f233364b635fd011f6a1bf5079fc987ecbb3c2 Mon Sep 17 00:00:00 2001
From: Priyanka Dangi
Date: Wed, 16 Oct 2024 14:09:30 -0700
Subject: [PATCH 3/4] Adding sanity check

Signed-off-by: Priyanka Dangi
---
 .../torch/src/python/aimet_torch/v2/nn/modules/custom.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index 816130c42a5..7a6c558d9f0 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -560,6 +560,9 @@ def batch_norm_wrapper(
             momentum: float = 0.1,
             eps: float = 1e-5,
         ) -> Tensor:
+            if training:
+                if self.input_quantizers[1] is not None or self.input_quantizers[2] is not None:
+                    raise RuntimeError(f"{self.__class__} doesn't support quantizing running_mean or running_var in training mode")

             input = _quantize_dequantize_if_applicable(input, self.input_quantizers[0])
             running_mean = _quantize_dequantize_if_applicable(running_mean, self.input_quantizers[1])
@@ -587,6 +590,7 @@ def batch_norm_wrapper(
             momentum: float = 0.1,
             eps: float = 1e-5,
         ) -> Tensor:
+
             input = _quantize_if_applicable(input, self.input_quantizers[0])
             running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])
             running_var = _quantize_if_applicable(running_var, self.input_quantizers[2])
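The guard added to the fake-quant wrapper rejects an unsupported configuration: in training mode
batch_norm updates the running statistics in place, and, as the wrapper's comment notes, PyTorch does
not compute gradients for them, so quantizing running_mean/running_var there is refused outright. A
toy version of the condition, with the quantizer list laid out as [input, running_mean, running_var,
weight, bias] (hypothetical helper name, not aimet code):

    def check_running_stat_quantizers(input_quantizers, training):
        if training:
            if input_quantizers[1] is not None or input_quantizers[2] is not None:
                raise RuntimeError("quantizing running_mean/running_var is not supported in training mode")

    check_running_stat_quantizers([None, None, None, None, None], training=True)         # OK: stats not quantized
    check_running_stat_quantizers(["q_in", "q_mean", None, None, None], training=False)  # OK: eval mode
    try:
        check_running_stat_quantizers(["q_in", "q_mean", None, None, None], training=True)
    except RuntimeError as err:
        print(err)                                                                        # rejected configuration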
From 285f5a9f86c73ad51fc64db31c97779fb6594664 Mon Sep 17 00:00:00 2001
From: Priyanka Dangi
Date: Wed, 23 Oct 2024 17:26:37 -0700
Subject: [PATCH 4/4] Adding additional sanity check

Signed-off-by: Priyanka Dangi
---
 .../torch/src/python/aimet_torch/v2/nn/modules/custom.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
index 7a6c558d9f0..e343a8e13ad 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/nn/modules/custom.py
@@ -590,6 +590,9 @@ def batch_norm_wrapper(
             momentum: float = 0.1,
             eps: float = 1e-5,
         ) -> Tensor:
+            if training:
+                if self.input_quantizers[1] is not None or self.input_quantizers[2] is not None:
+                    raise RuntimeError(f"{self.__class__} doesn't support quantizing running_mean or running_var in training mode")

             input = _quantize_if_applicable(input, self.input_quantizers[0])
             running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])
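Taken together, the series enables QuantizedNorm, QuantizedBatchNorm, QuantizedGroupNorm and
QuantizedNormalize and exercises them through the parametrized true-quant test. For reference, these
are the plain-PyTorch calls the new test inputs correspond to in eval mode (a float sketch mirroring
the test shapes, not the test code itself):

    import torch
    import torch.nn.functional as F

    torch.norm(torch.randn(100))                                                # custom.Norm
    F.batch_norm(torch.randn(5, 10), torch.zeros(10), torch.ones(10))           # 2-D input
    F.batch_norm(torch.randn(5, 10, 3, 2), torch.zeros(10), torch.ones(10))     # 4-D input
    F.batch_norm(torch.randn(5, 10, 3, 2, 5), torch.zeros(10), torch.ones(10))  # 5-D input
    F.group_norm(torch.randn(20, 6, 10, 10), 6)                                 # 6 groups
    F.normalize(torch.randn(100, 100))                                          # L2 normalization along dim=1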