Skip to content

Commit

Permalink
Registered FakeQuantizedAdd module definition
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored and quic-akhobare committed Jan 26, 2024
1 parent 85ebb9b commit ebbff38
Showing 1 changed file with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@
"""Fake-quantized modules"""

from collections import OrderedDict
from typing import Type
from typing import Type, Any

from torch import Tensor
import torch.nn as nn

from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
import aimet_torch.elementwise_ops as elementwise_ops
import aimet_torch.elementwise_ops as ops


class FakeQuantizationMixin(BaseQuantizationMixin):
Expand Down Expand Up @@ -237,13 +238,31 @@ def wrapper(quantized_cls):
FakeQuantizedZeroPad2d = FakeQuantizationMixin.wrap(nn.ZeroPad2d)


###########################
### AIMET V1 custom ops ###
###########################

### Custom ops

# pylint: disable=missing-docstring, abstract-method

@FakeQuantizationMixin.implements(elementwise_ops.Subtract)
class FakeQuantizedSubtract(FakeQuantizationMixin, elementwise_ops.Subtract):
@FakeQuantizationMixin.implements(ops.Add)
class FakeQuantizedAdd(FakeQuantizationMixin, ops.Add): # pylint: disable=missing-docstring
def __quant_init__(self):
super().__quant_init__()
self.input_quantizers = nn.ModuleList([None, None])

def quantized_forward(self, x: Any, y: Any) -> Any:
"""
Quantized forward impl for elementwise Add
"""
# pylint: disable=arguments-differ

if isinstance(x, Tensor) and self.input_quantizers[0]:
x = self.input_quantizers[0](x)

if isinstance(y, Tensor) and self.input_quantizers[1]:
y = self.input_quantizers[1](y)

out = super().forward(x, y)

if isinstance(out, Tensor) and self.output_quantizers[0]:
out = self.output_quantizers[0](out)

return out

0 comments on commit ebbff38

Please sign in to comment.