
Commit e4de03d
Ad-hoc fix in irregular module definitions for compatibility with V1 quantsim

Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored and quic-akhobare committed Feb 7, 2024
1 parent 1163cca commit e4de03d
Showing 2 changed files with 17 additions and 17 deletions.
File 1 of 2:
@@ -794,7 +794,7 @@ class FakeQuantizedAimetGroupNorm(FakeQuantizationMixin, aimet_ops.GroupNorm): #
     """
     def __quant_init__(self):
         super().__quant_init__()
-        self.input_quantizers = nn.ModuleList([None, None, None])
+        self.input_quantizers = nn.ModuleList([None, None, None, None])
 
     def quantized_forward(self, # pylint: disable=arguments-differ
                           input: Tensor,
@@ -810,11 +810,11 @@ def quantized_forward(self, # pylint: disable=arguments-differ
         if self.input_quantizers[0]:
             input = self.input_quantizers[0](input)
 
-        if weight is not None and self.input_quantizers[1]:
-            weight = self.input_quantizers[1](weight)
+        if weight is not None and self.input_quantizers[2]:
+            weight = self.input_quantizers[2](weight)
 
-        if bias is not None and self.input_quantizers[2]:
-            bias = self.input_quantizers[2](bias)
+        if bias is not None and self.input_quantizers[3]:
+            bias = self.input_quantizers[3](bias)
 
         output = super().forward(input, num_groups, weight, bias, eps)

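The extra None keeps FakeQuantizedAimetGroupNorm's quantizer list aligned one slot per positional input of aimet_ops.GroupNorm (input, num_groups, weight, bias), matching how V1 quantsim indexes inputs: num_groups is an integer and is never quantized, so slot 1 simply stays None while weight and bias shift to slots 2 and 3. As a minimal sketch, assuming super().forward ultimately dispatches to torch.nn.functional.group_norm with the same arguments:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)                      # slot 0: quantizable activation
num_groups = 2                                # slot 1: int, never quantized -> None
weight, bias = torch.ones(4), torch.zeros(4)  # slots 2 and 3
y = F.group_norm(x, num_groups, weight=weight, bias=bias, eps=1e-5)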
@@ -905,7 +905,7 @@ class FakeQuantizedWhere(FakeQuantizationMixin, aimet_ops.Where): # pylint: disa
     """
     def __quant_init__(self):
         super().__quant_init__()
-        self.input_quantizers = nn.ModuleList([None, None])
+        self.input_quantizers = nn.ModuleList([None, None, None])
         self.output_quantizers = nn.ModuleList([None])
 
     def quantized_forward(self, condition: Tensor, input, other, **kwargs) -> Tensor: # pylint: disable=arguments-differ
@@ -914,11 +914,11 @@ def quantized_forward(self, condition: Tensor, input, other, **kwargs) -> Tensor
         """
         # pylint: disable=redefined-builtin
 
-        if isinstance(input, Tensor) and input.is_floating_point() and self.input_quantizers[0]:
-            input = self.input_quantizers[0](input)
+        if isinstance(input, Tensor) and input.is_floating_point() and self.input_quantizers[1]:
+            input = self.input_quantizers[1](input)
 
-        if isinstance(other, Tensor) and other.is_floating_point() and self.input_quantizers[1]:
-            other = self.input_quantizers[1](other)
+        if isinstance(other, Tensor) and other.is_floating_point() and self.input_quantizers[2]:
+            other = self.input_quantizers[2](other)
 
         output = super().forward(condition, input, other, **kwargs)

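For FakeQuantizedWhere, slot 0 now corresponds to condition, which is a boolean tensor and can never pass the is_floating_point() guard above, so only input (slot 1) and other (slot 2) ever receive quantizers. A quick runnable check of that guard in plain PyTorch:

import torch

cond = torch.tensor([True, False])
assert not cond.is_floating_point()  # bool tensors always fail the guard,
                                     # so a quantizer in slot 0 would go unused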
@@ -935,15 +935,15 @@ class FakeQuantizedMaskedFill(FakeQuantizationMixin, aimet_ops.MaskedFill): # py
     """
     def __quant_init__(self):
         super().__quant_init__()
-        self.input_quantizers = nn.ModuleList([None])
+        self.input_quantizers = nn.ModuleList([None, None])
         self.output_quantizers = nn.ModuleList([None])
 
     def quantized_forward(self, mask: Tensor, value) -> Tensor: # pylint: disable=arguments-differ
         """
         Quantized forward impl for aimet_ops.MaskedFill.
         """
-        if isinstance(value, Tensor) and value.is_floating_point() and self.input_quantizers[0]:
-            value = self.input_quantizers[0](value)
+        if isinstance(value, Tensor) and value.is_floating_point() and self.input_quantizers[1]:
+            value = self.input_quantizers[1](value)
 
         output = super().forward(mask, value)

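FakeQuantizedMaskedFill follows the same convention: slot 0 maps to the boolean mask and slot 1 to value, which may be a plain Python scalar, in which case the isinstance check above skips quantization entirely. A small sketch against the reference PyTorch op:

import torch

mask = torch.tensor([True, False])
value = 3.0  # plain float: isinstance(value, torch.Tensor) is False
out = torch.zeros(2).masked_fill(mask, value)  # tensor([3., 0.])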
File 2 of 2:
@@ -221,12 +221,12 @@ def set_recursive(module_list, i, quantizer):
             else:
                 raise RuntimeError
 
-        for i, quant_builder in enumerate(self.input_quantizers):
-            quantizer = quant_builder.realize()
+        for i, _ in list(enumerate(quantized_module.input_quantizers)):
+            quantizer = self.input_quantizers[i].realize()
             set_recursive(quantized_module.input_quantizers, i, quantizer)
 
-        for i, quant_builder in enumerate(self.output_quantizers):
-            quantizer = quant_builder.realize()
+        for i, _ in list(enumerate(quantized_module.output_quantizers)):
+            quantizer = self.output_quantizers[i].realize()
             set_recursive(quantized_module.output_quantizers, i, quantizer)
 
         for param_name, quant_builder in self.param_quantizers.items():
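The realization loops above are now driven by the slot count of the already-instantiated quantized module rather than by the builder's own lists, so builder i always lands in slot i, which keeps the irregular, padded-with-None definitions from the first file consistent. A toy sketch of the pattern, with the hypothetical _Builder class standing in for the real quantizer builders:

class _Builder:
    """Hypothetical stand-in for a quantizer builder."""
    def __init__(self, tag):
        self.tag = tag

    def realize(self):
        return f"quantizer<{self.tag}>"

target_slots = [None, None, None]           # slots defined by the quantized module
builders = [_Builder(i) for i in range(3)]  # builder list held by `self`

# Iterate over the target module's slots; realize the builder at the same index.
for i, _ in enumerate(target_slots):
    target_slots[i] = builders[i].realize()

print(target_slots)  # ['quantizer<0>', 'quantizer<1>', 'quantizer<2>']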
