Adding support for custom normalization modules #3450

Merged: 4 commits, Oct 29, 2024
@@ -820,7 +820,7 @@ def forward(self, # pylint: disable=too-many-arguments, arguments-differ
if bias is not None and self.input_quantizers[4]:
bias = self.input_quantizers[4](bias)

output = super().forward(input, running_mean, running_var,
output = super().forward(input, running_mean.detach(), running_var.detach(),
weight, bias, training, momentum, eps)

if self.output_quantizers[0]:
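Note on the one-line change above: once running_mean and running_var pass through their quantize-dequantize ops they become part of the autograd graph, and torch.nn.functional.batch_norm does not support gradients with respect to the running statistics, so they are detached before the call. A minimal standalone sketch of the working pattern (illustrative only, not tied to this file):

```python
import torch
import torch.nn.functional as F

x = torch.randn(5, 10)
# Stand-ins for running stats that ended up attached to the autograd graph,
# e.g. after a quantize-dequantize op.
running_mean = torch.zeros(10, requires_grad=True)
running_var = torch.ones(10, requires_grad=True)

# Detaching mirrors the change above: batch_norm cannot build a backward graph
# through the running statistics, so they are passed as plain buffers.
out = F.batch_norm(x, running_mean.detach(), running_var.detach(),
                   weight=None, bias=None, training=False)
print(out.shape)  # torch.Size([5, 10])
```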
@@ -191,10 +191,10 @@ def cat(tensors, dim=0, *, out=None):
# _builtin_torch_fn = torch.floor_divide
#
#
# @QuantizationMixin.implements(Norm)
# class QuantizedNorm(_DispatchMixin, QuantizationMixin, Norm):
# """ Quantized Norm """
# _builtin_torch_fn = torch.norm
@QuantizationMixin.implements(Norm)
class QuantizedNorm(_DispatchMixin, QuantizationMixin, Norm):
""" Quantized Norm """
_builtin_torch_fn = torch.norm
#
#
# @QuantizationMixin.implements(Exponential)
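QuantizedNorm is now enabled and keeps torch.norm as its builtin function. For reference, with the default arguments (and the 1-D test input added below) torch.norm reduces to a single L2-norm scalar, so only the wrapper's input and output quantizers come into play:

```python
import torch

x = torch.randn(100)
# Default torch.norm is the 2-norm (Frobenius norm) of the flattened tensor.
assert torch.allclose(torch.norm(x), x.pow(2).sum().sqrt())
```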
@@ -538,22 +538,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# _builtin_torch_fn = torch.nn.functional.adaptive_avg_pool2d
#
#
# @QuantizationMixin.implements(BatchNorm)
# class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
# """ Quantized BatchNorm """
# _builtin_torch_fn = torch.nn.functional.batch_norm
#
#
# @QuantizationMixin.implements(GroupNorm)
# class QuantizedGroupNorm(_DispatchMixin, QuantizationMixin, GroupNorm):
# """ Quantized GroupNorm """
# _builtin_torch_fn = torch.nn.functional.group_norm
#
#
# @QuantizationMixin.implements(Normalize)
# class QuantizedNormalize(_DispatchMixin, QuantizationMixin, Normalize):
# """ Quantized Normalize """
# _builtin_torch_fn = torch.nn.functional.normalize
@QuantizationMixin.implements(BatchNorm)
class QuantizedBatchNorm(_DispatchMixin, QuantizationMixin, BatchNorm):
""" Quantized BatchNorm """
_builtin_torch_fn = torch.nn.functional.batch_norm

def __quant_init__(self):
super().__quant_init__()
# pylint: disable=attribute-defined-outside-init
self.input_quantizers = nn.ModuleList([None, None, None, None, None])

def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]):
# pylint: disable=redefined-builtin
def batch_norm_wrapper(
input: Tensor,
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
) -> Tensor:
if training:
if self.input_quantizers[1] is not None or self.input_quantizers[2] is not None:
raise RuntimeError(f"{self.__class__} doesn't support quantizing running_mean or running_var in training mode")

input = _quantize_dequantize_if_applicable(input, self.input_quantizers[0])
running_mean = _quantize_dequantize_if_applicable(running_mean, self.input_quantizers[1])
running_var = _quantize_dequantize_if_applicable(running_var, self.input_quantizers[2])
weight = _quantize_dequantize_if_applicable(weight, self.input_quantizers[3])
bias = _quantize_dequantize_if_applicable(bias, self.input_quantizers[4])

# PyTorch doesn't support gradient calculation of running_mean/var
output = fn(input, running_mean.detach(), running_var.detach(),
weight, bias, training, momentum, eps)

return _quantize_dequantize_if_applicable(output, self.output_quantizers[0])

return batch_norm_wrapper

def _custom_kernel_helper(self, fn: Callable[..., Tensor]):
# pylint: disable=redefined-builtin
def batch_norm_wrapper(
input: Tensor,
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
) -> Tensor:
if training:
if self.input_quantizers[1] is not None or self.input_quantizers[2] is not None:
raise RuntimeError(f"{self.__class__} doesn't support quantizing running_mean or running_var in training mode")

input = _quantize_if_applicable(input, self.input_quantizers[0])
running_mean = _quantize_if_applicable(running_mean, self.input_quantizers[1])
running_var = _quantize_if_applicable(running_var, self.input_quantizers[2])
weight = _quantize_if_applicable(weight, self.input_quantizers[3])
bias = _quantize_if_applicable(bias, self.input_quantizers[4])

# PyTorch doesn't support gradient calculation of running_mean/var
output = fn(input, running_mean.detach(), running_var.detach(),
weight, bias, training, momentum, eps)
return _quantize_if_applicable(output, self.output_quantizers[0])

return batch_norm_wrapper
#
#
@QuantizationMixin.implements(GroupNorm)
class QuantizedGroupNorm(_DispatchMixin, QuantizationMixin, GroupNorm):
""" Quantized GroupNorm """
_builtin_torch_fn = torch.nn.functional.group_norm
#
#
@QuantizationMixin.implements(Normalize)
class QuantizedNormalize(_DispatchMixin, QuantizationMixin, Normalize):
""" Quantized Normalize """
_builtin_torch_fn = torch.nn.functional.normalize
#
#
# @QuantizationMixin.implements(Pad)
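For reference, a rough usage sketch of the newly enabled wrappers. The import path, QuantizationMixin.from_module, and compute_encodings are assumed to behave as for the other aimet_torch.v2 quantized modules and are not verified against this branch:

```python
import torch
from aimet_torch.v2.nn import custom, QuantizationMixin  # assumed import path

bn = custom.BatchNorm()                    # functional-style wrapper: forward(input, running_mean, running_var, ...)
qbn = QuantizationMixin.from_module(bn)    # resolves to QuantizedBatchNorm via @QuantizationMixin.implements

x = torch.randn(5, 10)
mean, var = torch.zeros(10), torch.ones(10)

# __quant_init__ above allocates five input-quantizer slots:
# [0]=input, [1]=running_mean, [2]=running_var, [3]=weight, [4]=bias (all None by default).
# Quantizing running_mean/var is rejected in training mode, per the RuntimeError above.
with qbn.compute_encodings():
    qbn(x, mean, var)

out = qbn(x, mean, var)
```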
TrainingExtensions/torch/test/python/v2/nn/test_true_quant.py (9 additions & 4 deletions)
@@ -917,7 +917,7 @@ def _create_quantized_module(module):
(lambda: custom.Divide(), lambda: (randn(100), randn(100))),
(lambda: custom.Concat(), lambda: (randn(1, 100), randn(3, 100))),
# (lambda: custom.FloorDivide(), lambda: ...),
# (lambda: custom.Norm(), lambda: ...),
(lambda: custom.Norm(), lambda: randn(100)),
# (lambda: custom.Exponential(), lambda: ...),
# (lambda: custom.Erf(), lambda: ...),
# (lambda: custom.Sqrt(), lambda: ...),
@@ -971,9 +971,14 @@ def _create_quantized_module(module):
# (lambda: custom.Interpolate(), lambda: ...),
# (lambda: custom.MaxPool2d(), lambda: ...),
# (lambda: custom.AdaptiveAvgPool2d(), lambda: ...),
# (lambda: custom.BatchNorm(), lambda: ...),
# (lambda: custom.GroupNorm(), lambda: ...),
# (lambda: custom.Normalize(), lambda: ...),
(lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10).requires_grad_(), ones(10).requires_grad_())),
(lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10).requires_grad_(), ones(10).requires_grad_())),
(lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10).requires_grad_(), ones(10).requires_grad_())),
(lambda: custom.BatchNorm(), lambda: (randn(5, 10), zeros(10), ones(10))),
(lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2), zeros(10), ones(10))),
(lambda: custom.BatchNorm(), lambda: (randn(5, 10, 3, 2, 5), zeros(10), ones(10))),
(lambda: custom.GroupNorm(), lambda: (randn(20, 6, 10, 10), tensor(6))),
(lambda: custom.Normalize(), lambda: randn(100, 100)),
# (lambda: custom.Pad(), lambda: ...),
# (lambda: custom.GridSample(), lambda: ...),
(lambda: custom.RmsNorm([5, 2, 3], [2], 1e-5), lambda: (randn(5, 2, 3))),
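The new BatchNorm parametrizations above cover 2-D, 4-D, and 5-D inputs, each with and without grad-tracking running statistics (the grad-tracking variants mirror running stats attached to an autograd graph). A quick standalone shape check of those input combinations, using plain torch only:

```python
import torch
import torch.nn.functional as F

for shape in [(5, 10), (5, 10, 3, 2), (5, 10, 3, 2, 5)]:
    x = torch.randn(shape)
    # Channel dimension is dim 1, so per-channel stats have 10 elements in every case.
    out = F.batch_norm(x, torch.zeros(10), torch.ones(10), training=False)
    assert out.shape == x.shape
```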