Addressing bug in quantized definition of custom.BatchNorm
Signed-off-by: Priyanka Dangi <quic_pdangi@quicinc.com>
quic-pdangi committed Oct 28, 2024
commit 88fb681 (1 parent: a687ba1)
Showing 2 changed files with 63 additions and 3 deletions.
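
For orientation before the diff: the new __quant_init__ below gives the quantized BatchNorm one input-quantizer slot per tensor argument of torch.nn.functional.batch_norm (input, running_mean, running_var, weight, bias), and each wrapper applies the matching quantizer before calling the builtin. A minimal sketch of that slot-to-argument mapping, using a hypothetical pass-through helper _qdq_if_applicable in place of the library's _quantize_dequantize_if_applicable:

from typing import Optional

import torch
from torch import Tensor


def _qdq_if_applicable(x: Optional[Tensor], quantizer) -> Optional[Tensor]:
    # Hypothetical stand-in for _quantize_dequantize_if_applicable: pass the
    # argument through untouched when it is None or has no quantizer assigned.
    if x is None or quantizer is None:
        return x
    return quantizer(x)


# One quantizer slot per tensor argument of torch.nn.functional.batch_norm:
# [input, running_mean, running_var, weight, bias]; None means "not quantized".
input_quantizers = [None, None, None, None, None]
args = (torch.randn(5, 10), torch.zeros(10), torch.ones(10), None, None)
qdq_args = [_qdq_if_applicable(a, q) for a, q in zip(args, input_quantizers)]
out = torch.nn.functional.batch_norm(*qdq_args)
print(out.shape)  # torch.Size([5, 10])
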
@@ -542,6 +542,63 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
    """ Quantized BatchNorm """
    _builtin_torch_fn = torch.nn.functional.batch_norm

    def __quant_init__(self):
        super().__quant_init__()
        # pylint: disable=attribute-defined-outside-init
        self.input_quantizers = nn.ModuleList([None, None, None, None, None])

    def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]):
        # pylint: disable=redefined-builtin
        def batch_norm_wrapper(
            input: Tensor,
            running_mean: Optional[Tensor],
            running_var: Optional[Tensor],
            weight: Optional[Tensor] = None,
            bias: Optional[Tensor] = None,
            training: bool = False,
            momentum: float = 0.1,
            eps: float = 1e-5,
        ) -> Tensor:

            input = _quantize_dequantize_if_applicable(input, self.input_quantizers[0])
            running_mean = _quantize_dequantize_if_applicable(running_mean, self.input_quantizers[1])
            running_var = _quantize_dequantize_if_applicable(running_var, self.input_quantizers[2])
            weight = _quantize_dequantize_if_applicable(weight, self.input_quantizers[3])
            bias = _quantize_dequantize_if_applicable(bias, self.input_quantizers[4])

            # PyTorch doesn't support gradient calculation of running_mean/var
            output = fn(input, running_mean.detach(), running_var.detach(),
                        weight, bias, training, momentum, eps)

            return _quantize_dequantize_if_applicable(output, self.output_quantizers[0])

        return batch_norm_wrapper

    def _custom_kernel_helper(self, fn: Callable[..., Tensor]):
        # pylint: disable=redefined-builtin
        def batch_norm_wrapper(
            input: Tensor,
            running_mean: Optional[Tensor],
            running_var: Optional[Tensor],
            weight: Optional[Tensor] = None,
            bias: Optional[Tensor] = None,
            training: bool = False,
            momentum: float = 0.1,
            eps: float = 1e-5,
        ) -> Tensor:
            input = _quantize_if_applicable(input, self.input_quantizers[0])
            running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])
            running_var = _quantize_if_applicable(running_var, self.input_quantizers[2])
            weight = _quantize_if_applicable(weight, self.input_quantizers[3])
            bias = _quantize_if_applicable(bias, self.input_quantizers[4])

            # PyTorch doesn't support gradient calculation of running_mean/var
            output = fn(input, running_mean.detach(), running_var.detach(),
                        weight, bias, training, momentum, eps)
            return _quantize_if_applicable(output, self.output_quantizers[0])

        return batch_norm_wrapper
#
#
@QuantizationMixin.implements(GroupNorm)
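The .detach() calls in the wrappers above follow the in-code note that PyTorch does not implement gradients for running_mean/running_var. A small standalone sketch of that pattern in plain PyTorch (not AIMET code), assuming running stats that require grad as in the updated test inputs below:

import torch
import torch.nn.functional as F

x = torch.randn(5, 10)
# Running stats created with requires_grad, mirroring the updated test inputs.
running_mean = torch.zeros(10).requires_grad_()
running_var = torch.ones(10).requires_grad_()

# Detach the running stats before the functional call, as the wrappers do;
# the commit's own comment notes PyTorch does not support their gradients.
out = F.batch_norm(x, running_mean.detach(), running_var.detach())
print(out.shape)  # torch.Size([5, 10])
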
9 changes: 6 additions & 3 deletions TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py
@@ -971,9 +971,12 @@ def _create_quantized_module(module):
    # (lambda: custom.Interpolate(), lambda: ...),
    # (lambda: custom.MaxPool2d(), lambda: ...),
    # (lambda: custom.AdaptiveAvgPool2d(), lambda: ...),
-   (lambda: custom.BatchNorm(), lambda: (randn(5, 10),torch.zeros(10), torch.ones(10))),
-   (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2),torch.zeros(10), torch.ones(10))),
-   (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5),torch.zeros(10), torch.ones(10))),
+   (lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+   (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+   (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10).requires_grad_(), ones(10).requires_grad_())),
+   (lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10), ones(10))),
+   (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10), ones(10))),
+   (lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10), ones(10))),
    (lambda: custom.GroupNorm(), lambda: (randn(20, 6, 10, 10), tensor(6))),
    (lambda: custom.Normalize(), lambda: randn(100, 100)),
    # (lambda: custom.Pad(), lambda: ...),
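The updated test entries above cover 2D, 4D and 5D inputs with 10 channels, with and without requires_grad on the running stats. A quick standalone check in plain PyTorch (independent of the test harness) that the functional accepts all three layouts:

import torch
import torch.nn.functional as F

# 2D (N, C), 4D (N, C, H, W) and 5D (N, C, D, H, W) inputs, all with C=10.
for shape in [(5, 10), (5, 10, 3, 2), (5, 10, 3, 2, 5)]:
    x = torch.randn(*shape)
    out = F.batch_norm(x, torch.zeros(10), torch.ones(10))
    assert out.shape == x.shape  # batch_norm normalizes over dim 1 in each case
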
