From bb18c5d07e47420c903cb961ee1fc81be86b8edc Mon Sep 17 00:00:00 2001
From: Kyunggeun Lee
Date: Mon, 28 Oct 2024 13:52:12 -0700
Subject: [PATCH] Make second argument of compute_encodings optional (#3443)

Signed-off-by: Kyunggeun Lee
---
 .../aimet_torch/v2/quantsim/quantsim.py       | 30 +++++++++++++++++--
 .../test/python/v2/quantsim/test_quantsim.py  | 24 +++++++++++++++
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py
index 217094b5479..cff62a6ef38 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py
@@ -37,7 +37,7 @@
 """ Top level API for performing quantization simulation of a pytorch model """
 
 import copy
-from typing import Union, Tuple, Optional
+from typing import Union, Tuple, Optional, TypeVar, Any, Callable, overload
 import warnings
 import itertools
 import io
@@ -71,6 +71,10 @@
 )
 
 
+class _NOT_SPECIFIED:
+    pass
+
+
 def _convert_to_qmodule(module: torch.nn.Module):
     """
     Helper function to convert all modules to quantized aimet.nn modules.
@@ -202,7 +206,22 @@ def __init__(self, # pylint: disable=too-many-arguments, too-many-locals, too-ma
             # Set quantization parameters to the device of the original module
             module.to(device=device)
 
-    def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
+
+    @overload
+    def compute_encodings(self, forward_pass_callback: Callable[[torch.nn.Module], Any]): # pylint: disable=arguments-differ
+        ...
+
+    T = TypeVar('T')
+
+    @overload
+    def compute_encodings(self,
+                          forward_pass_callback: Callable[[torch.nn.Module, T], Any],
+                          forward_pass_callback_args: T):
+        ...
+
+    del T
+
+    def compute_encodings(self, forward_pass_callback, forward_pass_callback_args=_NOT_SPECIFIED):
         """
         Computes encodings for all quantization sim nodes in the model.
         It is also used to find initial encodings for Range Learning
@@ -218,10 +237,15 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
 
         :return: None
         """
+        if forward_pass_callback_args is _NOT_SPECIFIED:
+            args = (self.model,)
+        else:
+            args = (self.model, forward_pass_callback_args)
+
         # Run forward iterations so we can collect statistics to compute the appropriate encodings
         with utils.in_eval_mode(self.model), torch.no_grad():
             with aimet_nn.compute_encodings(self.model):
-                _ = forward_pass_callback(self.model, forward_pass_callback_args)
+                _ = forward_pass_callback(*args)
 
     def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple],
                *args, **kwargs):
diff --git a/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py b/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
index dbbd3851c39..47bd4b55ed4 100644
--- a/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
+++ b/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
@@ -1127,6 +1127,30 @@ def forward(self, *inputs):
         assert sim.model.module.input_quantizers[1] is not None
         assert not sim.model.module.input_quantizers[1].is_initialized()
 
+    def test_compute_encodings_optional_arg(self):
+        """
+        Given: Two quantsims created with identical model & config
+        """
+        model = test_models.BasicConv2d(kernel_size=3)
+        dummy_input = torch.rand(1, 64, 16, 16)
+        sim_a = QuantizationSimModel(model, dummy_input)
+        sim_b = QuantizationSimModel(model, dummy_input)
+
+        """
+        When: Run compute_encodings with second argument omitted in one quantsim and not in the other
+        Then: The quantizers in both quantsims should have the same encodings
+        """
+        sim_a.compute_encodings(lambda model: model(dummy_input))
+        sim_b.compute_encodings(lambda model, x: model(x),
+                                forward_pass_callback_args=dummy_input)
+
+        for qtzr_a, qtzr_b in zip(sim_a.model.modules(), sim_b.model.modules()):
+            if isinstance(qtzr_a, AffineQuantizerBase):
+                assert torch.equal(qtzr_a.get_scale(), qtzr_b.get_scale())
+                assert torch.equal(qtzr_a.get_offset(), qtzr_b.get_offset())
+                assert torch.equal(qtzr_a.get_min(), qtzr_b.get_min())
+                assert torch.equal(qtzr_a.get_max(), qtzr_b.get_max())
+
 
 class TestQuantsimUtilities:
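Usage sketch of the two call styles this patch allows. The toy model, input shape, and import path below are assumptions for illustration only; the compute_encodings method and its forward_pass_callback_args keyword are taken from the patch itself.

    import torch
    from aimet_torch.v2.quantsim import QuantizationSimModel  # import path assumed

    # Hypothetical model and calibration input, used only for illustration
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3)).eval()
    dummy_input = torch.rand(1, 3, 16, 16)
    sim = QuantizationSimModel(model, dummy_input)

    # New call style: the callback takes only the model, and the second
    # argument of compute_encodings is omitted entirely.
    sim.compute_encodings(lambda m: m(dummy_input))

    # Existing call style keeps working: forward_pass_callback_args is
    # forwarded to the callback as its second argument.
    sim.compute_encodings(lambda m, x: m(x), forward_pass_callback_args=dummy_input)

Both calls run the callback under eval mode and torch.no_grad() while encoding statistics are collected, as in the patched compute_encodings body.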