From ab429b4aae2ebe2452e3f02cf01ae796990b2baa Mon Sep 17 00:00:00 2001 From: Kyunggeun Lee Date: Mon, 14 Oct 2024 12:11:31 -0700 Subject: [PATCH] Add error messages Signed-off-by: Kyunggeun Lee --- .../aimet_torch/v2/quantsim/quantsim.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py index 14b2f010155..217094b5479 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py @@ -95,7 +95,7 @@ class QuantizationSimModel(V1QuantizationSimModel): """ Overriden QuantizationSimModel that does off-target quantization simulation using v2 quantsim blocks. """ - def __init__(self, # pylint: disable=too-many-arguments, too-many-locals + def __init__(self, # pylint: disable=too-many-arguments, too-many-locals, too-many-branches model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple], quant_scheme: Union[str, QuantScheme] = None, # NOTE: Planned to be deprecated @@ -123,12 +123,31 @@ def __init__(self, # pylint: disable=too-many-arguments, too-many-locals else: raise TypeError("'rounding_mode' parameter is no longer supported.") - for module in model.modules(): - if isinstance(module, BaseQuantizationMixin): - raise RuntimeError - - if isinstance(module, QuantizerBase): - raise RuntimeError + qmodules = { + name: module for name, module in model.named_modules() + if isinstance(module, BaseQuantizationMixin) + } + quantizers = { + name: module for name, module in model.named_modules() + if isinstance(module, QuantizerBase) + } + + if isinstance(model, BaseQuantizationMixin): + problem = f"the model itself is already a quantized module of type {type(model)}." + elif isinstance(model, QuantizerBase): + problem = f"the model itself is already a quantizer object of type {type(model)}." + elif qmodules: + problem = f"the model already contains quantized modules: {', '.join(qmodules.keys())}." + elif quantizers: + problem = f"the model already contains quantizers: {', '.join(quantizers.keys())}." + else: + problem = None + + if problem: + raise RuntimeError( + "QuantizationSimModel can only take base models WITHOUT quantized modules or quantizers, " + "but " + problem + ) if not in_place: model = copy.deepcopy(model)