From 727823fca5a2d98e2e3c97ea0344359bbc0aef29 Mon Sep 17 00:00:00 2001
From: Alankar Mahajan
Date: Mon, 8 Jan 2024 16:49:58 +0530
Subject: [PATCH] Changes to Pytorch Quantsim to enable output quantizer for specific cast op. (#2637)

* Enable output quantizers for Cast Ops where input is int or bool and output is float

Signed-off-by: Alankar Mahajan
---
Review note: this patch was reflowed from a whitespace-collapsed copy. The
indentation of unchanged context lines is inferred, so confirm it against the
target tree (or apply with `git apply --ignore-whitespace`). The added code was
revised in review: nested Cast modules are now resolved through their dotted
path, only a leading model-name prefix is stripped, and unsupported hook inputs
raise a descriptive ValueError. Hunk line counts were updated accordingly; the
`index` blob ids below still refer to the original commit.

 .../torch/src/python/aimet_torch/quantsim.py  | 33 +++++++++++++++++++
 .../torch/src/python/aimet_torch/utils.py     | 31 +++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index e8190ef72d9..9b6fec93ff3 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -195,6 +195,7 @@ def __init__(self, model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tupl
 
         # Add quantization layers
         num_inout_tensors = utils.find_num_inout_tensors_per_module(self.model, dummy_input)
+        inout_tensors_dtypes_for_cast_ops = utils.get_inout_tensors_dtypes_for_cast_modules(self.model, dummy_input)
 
         self._add_quantization_wrappers(self.model, num_inout_tensors, default_data_type)
         self._set_tensor_quantizers_for_consts()
@@ -210,6 +211,8 @@ def __init__(self, model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tupl
 
         self.quant_args = extract_global_quantizer_args(quant_scheme, quantsim_configurator)
 
+        self._enable_output_quantizers_for_specific_cast_ops(inout_tensors_dtypes_for_cast_ops)
+
         # pylint: disable=protected-access
         self._hw_version = quantsim_configurator._get_hw_version()
         self._supported_kernels = quantsim_configurator.get_supported_kernels()
@@ -1820,6 +1823,36 @@ def _validate_torchquantizer(quant_sim_model):
         _validate_torchquantizer(quant_sim_model)
         OnnxSaver._export_model_to_onnx(quant_sim_model, dummy_input, model_path, is_conditional, onnx_export_args)  # pylint: disable=protected-access
 
+    def _enable_output_quantizers_for_specific_cast_ops(self, inout_tensors_dtypes: Dict[torch.nn.Module, Tuple[torch.dtype, torch.dtype]]):
+        """
+        Enable output quantizer for Cast Ops where datatype of input tensor is int/bool
+        and data type of output tensor is float.
+
+        :param inout_tensors_dtypes: map of Cast module -> (data type of input tensor, data type of output tensor)
+        """
+        # pylint: disable=protected-access
+        model_prefix = self.connected_graph._model_name + '.'
+        torch_int_dtypes = {torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, torch.uint8}
+        torch_float_dtypes = {torch.float16, torch.float32, torch.float64}
+
+        for module, (input_tensor_dtype, output_tensor_dtype) in inout_tensors_dtypes.items():
+            # The int/bool and float dtype sets are disjoint, so matching both already implies the dtypes differ
+            if input_tensor_dtype not in torch_int_dtypes or output_tensor_dtype not in torch_float_dtypes:
+                continue
+
+            # Strip only the leading model-name prefix; the same text may recur deeper in the module path
+            module_name = self.connected_graph._module_to_name[module]
+            if module_name.startswith(model_prefix):
+                module_name = module_name[len(model_prefix):]
+
+            logger.info("Enabling output quantizer for module %s", module_name)
+            # Walk the dotted path one attribute at a time: a single getattr() cannot resolve nested submodules
+            wrapped_module = self.model
+            for attr_name in module_name.split('.'):
+                wrapped_module = getattr(wrapped_module, attr_name)
+            for output_quantizer in wrapped_module.output_quantizers:
+                output_quantizer.enabled = True
+
 
 def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
     """
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
index c5889e0d1d3..8f96b111d2f 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -53,6 +53,7 @@
 from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_QUANT_SCHEME_TO_PYMO
 from aimet_common.utils import AimetLogger, Handle, log_with_error_and_assert_if_false
 import aimet_common.libpymo as libpymo
+from aimet_torch import elementwise_ops
 
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
 
@@ -1011,3 +1012,33 @@ def fn(_, inputs):
     handle.remove()
 
     return cached_data
+
+
+def get_inout_tensors_dtypes_for_cast_modules(model: torch.nn.Module, input_tensor: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Dict:
+    """
+    Get the datatype of input and output tensor of Cast modules in a Pytorch Model.
+
+    :param model: Pytorch Model
+    :param input_tensor: Input tensor to run forward pass for the model.
+                         A tuple of tensors should be passed if model has multiple inputs
+    :return: map of module -> (data type of input tensor, data type of output tensor)
+    """
+    inout_dtypes_map = {}
+
+    def record_dtypes(module, inputs, outputs):
+        # Only Cast modules are recorded
+        if not isinstance(module, elementwise_ops.Cast):
+            return
+
+        # The first input is taken as the tensor being cast
+        if isinstance(inputs, (list, tuple)):
+            input_dtype = inputs[0].dtype
+        elif isinstance(inputs, torch.Tensor):
+            input_dtype = inputs.dtype
+        else:
+            raise ValueError('Unsupported input type for Cast module: ' + str(type(inputs)))
+
+        inout_dtypes_map[module] = (input_dtype, outputs.dtype)
+
+    run_hook_for_layers_with_given_input(model, input_tensor, record_dtypes)
+    return inout_dtypes_map