Skip to content

Commit

Permalink
Changes to Pytorch Quantsim to enable output quantizer for specific c…
Browse files Browse the repository at this point in the history
…ast op. (#2637)

* Enable output quantizers for Cast Ops where input is int or bool and output is float

Signed-off-by: Alankar Mahajan <quic_alanmaha@quicinc.com>
  • Loading branch information
quic-alanmaha authored and quic-bharathr committed Sep 13, 2024
1 parent 0029c02 commit 727823f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
25 changes: 25 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(self, model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tupl

# Add quantization layers
num_inout_tensors = utils.find_num_inout_tensors_per_module(self.model, dummy_input)
inout_tensors_dtypes_for_cast_ops = utils.get_inout_tensors_dtypes_for_cast_modules(self.model, dummy_input)

self._add_quantization_wrappers(self.model, num_inout_tensors, default_data_type)
self._set_tensor_quantizers_for_consts()
Expand All @@ -210,6 +211,8 @@ def __init__(self, model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tupl

self.quant_args = extract_global_quantizer_args(quant_scheme, quantsim_configurator)

self._enable_output_quantizers_for_specific_cast_ops(inout_tensors_dtypes_for_cast_ops)

# pylint: disable=protected-access
self._hw_version = quantsim_configurator._get_hw_version()
self._supported_kernels = quantsim_configurator.get_supported_kernels()
Expand Down Expand Up @@ -1820,6 +1823,28 @@ def _validate_torchquantizer(quant_sim_model):
_validate_torchquantizer(quant_sim_model)
OnnxSaver._export_model_to_onnx(quant_sim_model, dummy_input, model_path, is_conditional, onnx_export_args) # pylint: disable=protected-access

def _enable_output_quantizers_for_specific_cast_ops(self, inout_tensors_dtypes: Dict[torch.nn.Module, Tuple[torch.dtype, torch.dtype]]):
    """
    Enable output quantizers for Cast ops whose input tensor is int/bool and whose
    output tensor is float.

    :param inout_tensors_dtypes: map of Cast module -> (data type of input tensor, data type of output tensor)
    """
    # pylint: disable=protected-access
    model_prefix = self.connected_graph._model_name + '.'
    torch_int_dtypes = {torch.int8, torch.int16, torch.int32, torch.int64, torch.bool, torch.uint8}
    torch_float_dtypes = {torch.float16, torch.float32, torch.float64}

    for module, (input_tensor_dtype, output_tensor_dtype) in inout_tensors_dtypes.items():
        # The two dtype sets are disjoint, so membership alone implies the dtypes differ.
        if input_tensor_dtype not in torch_int_dtypes or output_tensor_dtype not in torch_float_dtypes:
            continue

        # Strip only a *leading* model-name prefix; splitting on the prefix would also cut
        # at any later occurrence of the same text inside the qualified name.
        module_name = self.connected_graph._module_to_name[module]
        if module_name.startswith(model_prefix):
            module_name = module_name[len(model_prefix):]

        logger.info("Enabling output quantizer for module %s", module_name)

        # Walk the dotted path one attribute at a time: a single getattr() cannot resolve
        # names such as "block.cast", so nested Cast modules would raise AttributeError.
        wrapped_module = self.model
        for attr_name in module_name.split('.'):
            wrapped_module = getattr(wrapped_module, attr_name)

        for output_quantizer in wrapped_module.output_quantizers:
            output_quantizer.enabled = True


def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
"""
Expand Down
33 changes: 33 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_QUANT_SCHEME_TO_PYMO
from aimet_common.utils import AimetLogger, Handle, log_with_error_and_assert_if_false
import aimet_common.libpymo as libpymo
from aimet_torch import elementwise_ops


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
Expand Down Expand Up @@ -1011,3 +1012,35 @@ def fn(_, inputs):
handle.remove()

return cached_data


def get_inout_tensors_dtypes_for_cast_modules(model: torch.nn.Module, input_tensor: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Dict[torch.nn.Module, Tuple[torch.dtype, torch.dtype]]:
    """
    Get the datatype of input and output tensor of Cast modules in a Pytorch Model.

    :param model: Pytorch Model
    :param input_tensor: Input tensor to run forward pass for the model.
                         A tuple of tensors should be passed if model has multiple inputs
    :return: map of module -> (data type of input tensor, data type of output tensor)
    :raises ValueError: if a Cast module receives inputs that are neither a tensor nor a list/tuple
    """
    inout_dtypes_map = {}

    def record_dtypes(module, inputs, outputs):
        # Only Cast modules are recorded; any other module the hook fires on is ignored.
        if not isinstance(module, elementwise_ops.Cast):
            return

        if isinstance(inputs, (list, tuple)):
            # Only the first positional input is inspected for its dtype.
            input_dtype = inputs[0].dtype
        elif isinstance(inputs, torch.Tensor):
            input_dtype = inputs.dtype
        else:
            raise ValueError(f"Unsupported inputs of type {type(inputs)} for Cast module {module}; "
                             "expected a torch.Tensor or a list/tuple of tensors")

        inout_dtypes_map[module] = (input_dtype, outputs.dtype)

    run_hook_for_layers_with_given_input(model, input_tensor, record_dtypes)
    return inout_dtypes_map

0 comments on commit 727823f

Please sign in to comment.