Make second argument of compute_encodings optional (#3443)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Oct 28, 2024
1 parent 460a827 commit bb18c5d
Showing 2 changed files with 51 additions and 3 deletions.
@@ -37,7 +37,7 @@
""" Top level API for performing quantization simulation of a pytorch model """

import copy
from typing import Union, Tuple, Optional
from typing import Union, Tuple, Optional, TypeVar, Any, Callable, overload
import warnings
import itertools
import io
@@ -71,6 +71,10 @@
)


class _NOT_SPECIFIED:
pass


def _convert_to_qmodule(module: torch.nn.Module):
"""
Helper function to convert all modules to quantized aimet.nn modules.
@@ -202,7 +206,22 @@ def __init__(self, # pylint: disable=too-many-arguments, too-many-locals, too-ma
# Set quantization parameters to the device of the original module
module.to(device=device)

def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):

@overload
def compute_encodings(self, forward_pass_callback: Callable[[torch.nn.Module], Any]): # pylint: disable=arguments-differ
...

T = TypeVar('T')

@overload
def compute_encodings(self,
forward_pass_callback: Callable[[torch.nn.Module, T], Any],
forward_pass_callback_args: T):
...

del T

def compute_encodings(self, forward_pass_callback, forward_pass_callback_args=_NOT_SPECIFIED):
"""
Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for
Range Learning
@@ -218,10 +237,15 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
:return: None
"""
if forward_pass_callback_args is _NOT_SPECIFIED:
args = (self.model,)
else:
args = (self.model, forward_pass_callback_args)

# Run forward iterations so we can collect statistics to compute the appropriate encodings
with utils.in_eval_mode(self.model), torch.no_grad():
with aimet_nn.compute_encodings(self.model):
_ = forward_pass_callback(self.model, forward_pass_callback_args)
_ = forward_pass_callback(*args)

def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple],
*args, **kwargs):
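
For readers unfamiliar with the mechanism above, here is a minimal, self-contained sketch of the same pattern (illustrative only, not the AIMET source; the Simulator class and model names below are placeholders). A private sentinel class distinguishes "second argument omitted" from a legitimate value such as None, the typing.overload stubs advertise both call signatures to static type checkers, and the TypeVar is deleted once the stubs are defined so it does not linger in the class namespace.

from typing import Any, Callable, TypeVar, overload


class _NOT_SPECIFIED:
    """Sentinel type: 'argument was not passed', distinct from None."""


class Simulator:
    def __init__(self, model):
        self.model = model

    @overload
    def compute_encodings(self, forward_pass_callback: Callable[[Any], Any]): ...

    T = TypeVar('T')

    @overload
    def compute_encodings(self,
                          forward_pass_callback: Callable[[Any, T], Any],
                          forward_pass_callback_args: T): ...

    del T  # only needed while the overload stubs above are being defined

    def compute_encodings(self, forward_pass_callback, forward_pass_callback_args=_NOT_SPECIFIED):
        # Build the argument tuple based on whether the caller supplied the second argument
        if forward_pass_callback_args is _NOT_SPECIFIED:
            args = (self.model,)
        else:
            args = (self.model, forward_pass_callback_args)
        forward_pass_callback(*args)


sim = Simulator(model=lambda x=None: x)
sim.compute_encodings(lambda model: model())                      # one-argument callback
sim.compute_encodings(lambda model, x: model(x), "calib_data")    # original two-argument form

Using a dedicated sentinel rather than None as the default keeps None available as a legitimate forward_pass_callback_args value.
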
24 changes: 24 additions & 0 deletions TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
@@ -1127,6 +1127,30 @@ def forward(self, *inputs):
assert sim.model.module.input_quantizers[1] is not None
assert not sim.model.module.input_quantizers[1].is_initialized()

def test_compute_encodings_optional_arg(self):
"""
Given: Two quantsims created with identical model & config
"""
model = test_models.BasicConv2d(kernel_size=3)
dummy_input = torch.rand(1, 64, 16, 16)
sim_a = QuantizationSimModel(model, dummy_input)
sim_b = QuantizationSimModel(model, dummy_input)

"""
When: Run compute_encodings with second argument omitted in one quantsim and not in the other
Then: The quantizers in both quantsims should have the same encodings
"""
sim_a.compute_encodings(lambda model: model(dummy_input))
sim_b.compute_encodings(lambda model, x: model(x),
forward_pass_callback_args=dummy_input)

for qtzr_a, qtzr_b in zip(sim_a.model.modules(), sim_b.model.modules()):
if isinstance(qtzr_a, AffineQuantizerBase):
assert torch.equal(qtzr_a.get_scale(), qtzr_b.get_scale())
assert torch.equal(qtzr_a.get_offset(), qtzr_b.get_offset())
assert torch.equal(qtzr_a.get_min(), qtzr_b.get_min())
assert torch.equal(qtzr_a.get_max(), qtzr_b.get_max())


class TestQuantsimUtilities:

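
A hedged usage sketch of the resulting API, mirroring the test above (the toy model, dummy_input, and calibration_data are illustrative, and the import path is assumed from the aimet_torch v2 package layout): the calibration callback may now take only the model, while the original (model, args) form keeps working.

import torch
from aimet_torch.v2.quantsim import QuantizationSimModel  # assumed v2 import path

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dummy_input = torch.rand(1, 3, 16, 16)
calibration_data = [torch.rand(1, 3, 16, 16) for _ in range(4)]  # stand-in calibration set

sim = QuantizationSimModel(model, dummy_input)

# New form: second argument omitted, callback receives only the model
sim.compute_encodings(lambda m: m(dummy_input))

# Original form: extra data forwarded via forward_pass_callback_args
def calibrate(m, data):
    for batch in data:
        m(batch)

sim.compute_encodings(calibrate, forward_pass_callback_args=calibration_data)
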
