From ebbff38db0e58efcdc112f492d8ed78d524123a5 Mon Sep 17 00:00:00 2001
From: Kyunggeun Lee
Date: Thu, 25 Jan 2024 16:23:54 -0800
Subject: [PATCH] Registered FakeQuantizedAdd module definition

Signed-off-by: Kyunggeun Lee
---
 .../experimental/v2/nn/fake_quant.py | 35 ++++++++++++++-----
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py
index 9fe6efe4e67..872f99e284e 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py
@@ -37,12 +37,13 @@
 """Fake-quantized modules"""
 
 from collections import OrderedDict
-from typing import Type
+from typing import Type, Any
 
+from torch import Tensor
 import torch.nn as nn
 
 from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
-import aimet_torch.elementwise_ops as elementwise_ops
+import aimet_torch.elementwise_ops as ops
 
 
 class FakeQuantizationMixin(BaseQuantizationMixin):
@@ -237,13 +238,31 @@ def wrapper(quantized_cls):
 FakeQuantizedZeroPad2d = FakeQuantizationMixin.wrap(nn.ZeroPad2d)
 
 
+###########################
+### AIMET V1 custom ops ###
+###########################
 
-### Custom ops
-
-# pylint: disable=missing-docstring, abstract-method
-
-@FakeQuantizationMixin.implements(elementwise_ops.Subtract)
-class FakeQuantizedSubtract(FakeQuantizationMixin, elementwise_ops.Subtract):
+@FakeQuantizationMixin.implements(ops.Add)
+class FakeQuantizedAdd(FakeQuantizationMixin, ops.Add): # pylint: disable=missing-docstring
     def __quant_init__(self):
         super().__quant_init__()
         self.input_quantizers = nn.ModuleList([None, None])
+
+    def quantized_forward(self, x: Any, y: Any) -> Any:
+        """
+        Quantized forward impl for elementwise Add
+        """
+        # pylint: disable=arguments-differ
+
+        if isinstance(x, Tensor) and self.input_quantizers[0]:
+            x = self.input_quantizers[0](x)
+
+        if isinstance(y, Tensor) and self.input_quantizers[1]:
+            y = self.input_quantizers[1](y)
+
+        out = super().forward(x, y)
+
+        if isinstance(out, Tensor) and self.output_quantizers[0]:
+            out = self.output_quantizers[0](out)
+
+        return out
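
Usage sketch (illustrative only, not part of the patch): the snippet below assumes
FakeQuantizedAdd can be constructed directly, with BaseQuantizationMixin invoking
__quant_init__ during construction, and that output_quantizers defaults to a single
None slot. The commented-out quantizer attachment names a hypothetical
QuantizeDequantize class; its import path and constructor are assumptions and are
not defined by this change.

import torch
from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizedAdd

# Directly instantiate the fake-quantized elementwise Add registered by this patch.
qadd = FakeQuantizedAdd()

# With all quantizer slots left as None (the defaults from __quant_init__),
# quantized_forward falls through to the underlying ops.Add, i.e. plain x + y.
x = torch.randn(4)
y = torch.randn(4)
out = qadd.quantized_forward(x, y)
assert torch.equal(out, x + y)

# Hypothetical follow-up (names and import path are assumptions, not part of this
# patch): attach real quantizers so the inputs and output are fake-quantized.
# from aimet_torch.experimental.v2.quantization import QuantizeDequantize
# qadd.input_quantizers[0] = QuantizeDequantize(...)
# qadd.input_quantizers[1] = QuantizeDequantize(...)
# qadd.output_quantizers[0] = QuantizeDequantize(...)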