diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantization_mixin.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantization_mixin.py
new file mode 100644
index 00000000000..ec5c2a64f3f
--- /dev/null
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantization_mixin.py
@@ -0,0 +1,59 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+# pylint: skip-file
+""" Placeholder for _QuantizationMixin definition, to be deleted/moved/updated """
+
+from abc import ABC, abstractmethod
+
+class _QuantizationMixin(ABC):
+    """ Base class for quantized modules """
+
+    @abstractmethod
+    def export_input_encodings(self):
+        ...
+
+    @abstractmethod
+    def export_output_encodings(self):
+        ...
+
+    @abstractmethod
+    def export_param_encodings(self):
+        ...
+
+    @abstractmethod
+    def get_original_module(self):
+        ...
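
For orientation, the sketch below shows what a minimal concrete implementer of this interface could look like. The class name, the fixed encoding dict, and the list-hiding trick are illustrative assumptions only; they mirror the DummyMixin used in the new export tests further down, not a shipped AIMET class:

    import torch
    from aimet_torch.experimental.v2.quantization.quantization_mixin import _QuantizationMixin

    class SketchQuantizedLinear(_QuantizationMixin, torch.nn.Module):
        """ Hypothetical implementer: wraps a Linear and reports one fixed encoding per tensor """

        def __init__(self, linear: torch.nn.Linear):
            super().__init__()
            # Hide the wrapped module in a list so it is not registered as a child module
            self._wrapped = [linear]
            self._enc = {"min": 0.0, "max": 2.0, "scale": 2 / 255, "offset": 0,
                         "bitwidth": 8, "is_symmetric": "False", "dtype": "int"}

        def forward(self, x):
            return self._wrapped[0](x)

        def export_input_encodings(self):
            # One List[Dict] (or None for a disabled quantizer) per input
            return [[self._enc]]

        def export_output_encodings(self):
            # One List[Dict] (or None) per output
            return [[self._enc]]

        def export_param_encodings(self):
            # {param name: List[Dict] or None}
            return {"weight": [self._enc], "bias": None}

        def get_original_module(self):
            return self._wrapped[0]
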
\ No newline at end of file
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py b/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
index 48612d00cf1..857b4e62af7 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -40,7 +40,7 @@
 # pylint: disable=too-many-lines
 import abc
 from enum import Enum
-from typing import Dict, Tuple, Union, List, Callable, Type, Any
+from typing import Dict, Tuple, Union, List, Callable, Type, Any, Optional
 import os
 import torch
 from torch import nn
@@ -52,7 +52,7 @@
 from aimet_torch.custom import custom_tensor_utils
 from aimet_torch import utils
 from aimet_torch.tensor_quantizer import StaticGridPerTensorQuantizer, StaticGridPerChannelQuantizer, TensorQuantizer, \
-    LearnedGridTensorQuantizer, set_encoding_min_max_gating_threshold
+    LearnedGridTensorQuantizer, set_encoding_min_max_gating_threshold, StaticGridTensorQuantizer
 from aimet_torch.torch_quantizer import TorchQuantizer
 import aimet_torch.quantsim_straight_through_grad as ste
@@ -490,6 +490,30 @@ def should_perform_quant_dequant(tensor: torch.Tensor, tensor_quantizer: TensorQ
             return False
         return True
 
+    def get_original_module(self) -> torch.nn.Module:
+        """
+        Returns the wrapped torch.nn.Module
+        """
+        return self._module_to_wrap
+
+    def export_param_encodings(self) -> Dict[str, List]:
+        """
+        Returns the layer's parameter encodings in an exportable format
+        """
+        return {name: export_quantizer_encoding(quantizer) for name, quantizer in self.param_quantizers.items()}
+
+    def export_output_encodings(self) -> List[List[Dict]]:
+        """
+        Returns the layer's output encodings in an exportable format
+        """
+        return [export_quantizer_encoding(quantizer) for quantizer in self.output_quantizers]
+
+    def export_input_encodings(self) -> List[List[Dict]]:
+        """
+        Returns the layer's input encodings in an exportable format
+        """
+        return [export_quantizer_encoding(quantizer) for quantizer in self.input_quantizers]
+
 
 class StaticGridQuantWrapper(QcQuantizeWrapper):
     """ A custom PyTorch module that derives from QcQuantizeWrapper and quantizes modules """
@@ -1324,3 +1348,37 @@ def backward(ctx, grad):
         dloss_by_db = grad.sum(dim=0)
 
         return dloss_by_dx, dloss_by_dW, dloss_by_db, dloss_by_dmin, dloss_by_dmax, None
+
+
+def get_encoding_by_quantizer(quantizer: Union[StaticGridTensorQuantizer, LearnedGridTensorQuantizer]) \
+        -> Optional[Union[libpymo.TfEncoding, List[libpymo.TfEncoding]]]:
+    """
+    Retrieve encoding object by quantizer type (StaticGridTensorQuantizer or LearnedGridTensorQuantizer)
+    In particular, LearnedGridTensorQuantizer should use get_effective_encoding to obtain the true encoding
+
+    :param quantizer: TensorQuantizer (StaticGridTensorQuantizer or LearnedGridTensorQuantizer)
+    :return: TfEncoding or list of TfEncoding. None if quantizer is not enabled
+    """
+    if isinstance(quantizer, LearnedGridTensorQuantizer):
+        return quantizer.get_effective_encoding()
+
+    return quantizer.encoding
+
+
+def export_quantizer_encoding(quantizer: Union[StaticGridTensorQuantizer, LearnedGridTensorQuantizer]) \
+        -> Optional[List[Dict]]:
+    """
+    Returns the encoding of a quantizer in exportable form.
+
+    :param quantizer: Quantizer from which to export the encoding
+    :return: List of encoding dictionaries for the quantizer
+    """
+    if not quantizer.enabled:
+        return None
+    encoding = get_encoding_by_quantizer(quantizer)
+    if isinstance(encoding, List):
+        encoding = [utils.create_encoding_dict(enc, quantizer, False) for enc in encoding]
+    else:
+        encoding = utils.create_encoding_dict(encoding, quantizer, False)
+        encoding = [encoding] if encoding else None
+    return encoding
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index 9b6fec93ff3..ec292991a09 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -42,7 +42,7 @@
 import io
 import copy
 import pickle
-from typing import Tuple, List, Union, Dict, Callable, Optional, Any
+from typing import Tuple, List, Union, Dict, Callable, Optional, Any, runtime_checkable, Protocol
 from collections.abc import Iterable
 import json
 import torch
@@ -61,6 +61,7 @@
     StaticGridQuantWrapper, LearnedGridQuantWrapper, NativeTorchQuantWrapper, QUANTIZER_TYPE_INPUT, QUANTIZER_TYPE_OUTPUT
 from aimet_torch.tensor_quantizer import StaticGridTensorQuantizer, LearnedGridTensorQuantizer, \
     initialize_learned_grid_quantizer_attributes
+from aimet_torch.qc_quantize_op import get_encoding_by_quantizer as _get_encoding_by_quantizer
 from aimet_torch import torchscript_utils, utils, transformer_utils, onnx_utils
 from aimet_torch.onnx_utils import OnnxSaver, OnnxExportApiArgs, CustomMarker, get_pytorch_name_from_onnx_name
 from aimet_torch.meta.connectedgraph import ConnectedGraph
@@ -88,19 +89,6 @@
 SUPPORTED_KERNELS_ACTION = SupportedKernelsAction.warn_on_error
 
 
-def _get_encoding_by_quantizer(quantizer: Union[StaticGridTensorQuantizer, LearnedGridTensorQuantizer]) \
-        -> Optional[Union[libpymo.TfEncoding, List[libpymo.TfEncoding]]]:
-    """
-    Retrieve encoding object by quantizer type (StaticGridTensorQuantizer or LearnedGridTensorQuantizer)
-    In particular, LearnedGridTensorQuantizer should use get_effective_encoding to achieve true encoding
-    :param quantizer: TensorQuantizer (StaticGridTensorQuantizer or LearnedGridTensorQuantizer)
-    :return: TfEncoding or list of TfEncoding. None if quantizer is not enabled
-    """
-    if isinstance(quantizer, LearnedGridTensorQuantizer):
-        return quantizer.get_effective_encoding()
-
-    return quantizer.encoding
-
 
 class QuantParams:
     """
@@ -131,6 +119,33 @@
         self.config_file = config_file
 
 
+@runtime_checkable
+class ExportableQuantModule(Protocol):
+    """
+    Defines the minimum interface requirements for exporting encodings from a module.
+    """
+
+    def export_input_encodings(self) -> List[List[Dict]]:
+        """
+        Returns a list of input encodings, each represented as a List of Dicts
+        """
+
+    def export_output_encodings(self) -> List[List[Dict]]:
+        """
+        Returns a list of output encodings, each represented as a List of Dicts
+        """
+
+    def export_param_encodings(self) -> Dict[str, List[Dict]]:
+        """
+        Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts
+        """
+
+    def get_original_module(self) -> torch.nn.Module:
+        """
+        Returns the floating point version of the quantized module
+        """
+
+
 class QuantizationSimModel:
     """
     Implements mechanism to add quantization simulations ops to a model. This allows for off-target simulation of
@@ -861,13 +876,15 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
         tensor_to_consumer_map = QuantizationSimModel._get_tensor_to_consumer_map(op_to_io_tensor_map)
         layer_names_not_found = []
 
-        for layer_name, layer in QuantizationSimModel._get_qc_quantized_layers(sim_model):
+        for layer_name, layer in sim_model.named_modules():
+            if not isinstance(layer, (ExportableQuantModule, QcQuantizeRecurrent)):
+                continue
             if not has_valid_encodings(layer):
                 continue
             # TODO: specifically call out dropout layers here since they are specifically switched out during export.
             # These ops should eventually be reworked as part of math invariant ops to ignore quantization altogether.
             # pylint: disable=protected-access
-            if isinstance(layer, QcQuantizeWrapper) and isinstance(layer._module_to_wrap, utils.DROPOUT_TYPES):
+            if isinstance(layer, ExportableQuantModule) and isinstance(layer.get_original_module(), utils.DROPOUT_TYPES):
                 continue
 
             if layer_name not in layers_to_onnx_op_names.keys():
@@ -959,7 +976,7 @@ def _get_layers_in_io_tensor_map(op_to_io_tensor_map: Dict) -> Dict[str, str]:
         return layers_to_onnx_op_names
 
     @staticmethod
-    def _update_param_encodings_dict_for_layer(layer: torch.nn.Module, layer_name: str, param_encodings: Dict,
+    def _update_param_encodings_dict_for_layer(layer: ExportableQuantModule, layer_name: str, param_encodings: Dict,
                                                valid_param_set: set):
         """
         :param layer: layer as torch.nn.Module
@@ -968,28 +985,17 @@ def _update_param_encodings_dict_for_layer(layer: torch.nn.Module, layer_name: s
         :param valid_param_set: a set of valid param input names in model
         """
 
-        for orig_param_name, param_quantizer in layer.param_quantizers.items():
+        for orig_param_name, param_encoding in layer.export_param_encodings().items():
             param_name = layer_name + '.' + orig_param_name
-            if not param_quantizer.enabled:
+            if param_encoding is None:
                 continue
             elif param_name not in valid_param_set:
                 logger.error('Param tensor {%s} not found in valid param set', param_name)
                 continue
-            elif isinstance(param_quantizer.encoding, Iterable):
-                param_encodings[param_name] = []
-                quantizer_encoding = _get_encoding_by_quantizer(param_quantizer)
-                for encoding in quantizer_encoding:
-                    enc_dict = QuantizationSimModel._create_encoding_dict(encoding,
-                                                                          param_quantizer, propagate_encodings=False)
-                    param_encodings[param_name].append(enc_dict)
-            else:
-                quantizer_encoding = _get_encoding_by_quantizer(param_quantizer)
-                enc_dict = QuantizationSimModel._create_encoding_dict(quantizer_encoding, param_quantizer,
-                                                                      propagate_encodings=False)
-                param_encodings[param_name] = [enc_dict]
+            param_encodings[param_name] = param_encoding
 
     @staticmethod
-    def _update_encoding_dicts_for_layer(layer: torch.nn.Module, layer_name: str, activation_encodings_onnx: Dict,
+    def _update_encoding_dicts_for_layer(layer: ExportableQuantModule, layer_name: str, activation_encodings_onnx: Dict,
                                          activation_encodings_torch: Dict, param_encodings: Dict,
                                          op_to_io_tensor_map: Dict, valid_param_set: set, propagate_encodings: bool,
                                          tensor_to_consumer_map: Dict[str, str],
@@ -1010,7 +1016,7 @@ def _update_encoding_dicts_for_layer(layer: torch.nn.Module, layer_name: str, ac
         :param layers_to_onnx_op_names: Dictionary mapping PyTorch layer names to names of corresponding ONNX ops
         """
 
-        if isinstance(layer, QcQuantizeWrapper):
+        if isinstance(layer, ExportableQuantModule):
 
             # --------------------------------------
             # Update encodings for Input activations
@@ -1090,7 +1096,7 @@ def find_op_names_for_layer(layer_name: str, op_to_io_tensor_map: Dict,
         return end_op_names, op_names
 
     @staticmethod
-    def _update_encoding_dict_for_output_activations(layer: torch.nn.Module, layer_name: str, op_to_io_tensor_map: Dict,
+    def _update_encoding_dict_for_output_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
                                                      activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
                                                      propagate_encodings: bool, tensor_to_consumer_map: Dict[str, str],
                                                      layers_to_onnx_op_names: Dict[str, str]):
@@ -1099,63 +1105,51 @@ def _update_encoding_dict_for_output_activations(layer: torch.nn.Module, layer_n
                                                                           op_to_io_tensor_map, tensor_to_consumer_map,
                                                                           layers_to_onnx_op_names)
 
-        num_quantizers = len(layer.output_quantizers)
-        num_outputs = len(output_tensors)
-        if len(output_tensors) != num_quantizers:
+        output_encodings = layer.export_output_encodings()
+
+        if len(output_tensors) != len(output_encodings):
             logger.warning("number of output quantizers: %d available for layer: %s "
-                           "doesn't match with number of output tensors: %d", num_quantizers, layer_name, num_outputs)
+                           "doesn't match with number of output tensors: %d", len(output_encodings), layer_name,
+                           len(output_tensors))
 
-        for index, (output_tensor, quantizer) in enumerate(zip(output_tensors, layer.output_quantizers)):
-            if quantizer.enabled:
-                quantizer_encoding = _get_encoding_by_quantizer(quantizer)
-                enc = QuantizationSimModel._create_encoding_dict(quantizer_encoding,
-                                                                 quantizer,
-                                                                 propagate_encodings=False)
-                activation_encodings_onnx[output_tensor] = [enc]
+        for index, (output_tensor, encoding) in enumerate(zip(output_tensors, output_encodings)):
 
-                # Check if layer exists in the pytorch encoding dictionary
+            if encoding is not None:
+                activation_encodings_onnx[output_tensor] = encoding
                 if layer_name not in activation_encodings_torch:
                    activation_encodings_torch[layer_name] = {}
                 if QUANTIZER_TYPE_OUTPUT not in activation_encodings_torch[layer_name]:
                     activation_encodings_torch[layer_name][QUANTIZER_TYPE_OUTPUT] = {}
-                activation_encodings_torch[layer_name][QUANTIZER_TYPE_OUTPUT][index] = enc
+                activation_encodings_torch[layer_name][QUANTIZER_TYPE_OUTPUT][index] = encoding[0]
 
         if propagate_encodings:
-            enabled_quantizers = [q for q in layer.output_quantizers if q.enabled]
-            if enabled_quantizers:
-                quantizer = enabled_quantizers[0]
+            valid_encodings = [enc for enc in output_encodings if enc is not None]
+            if valid_encodings:
+                encoding = valid_encodings[0]
                 for activation_tensor in propagate_tensors:
-                    quantizer_encoding = _get_encoding_by_quantizer(quantizer)
-                    enc = QuantizationSimModel._create_encoding_dict(quantizer_encoding,
-                                                                     quantizer,
-                                                                     propagate_encodings=True)
-                    activation_encodings_onnx[activation_tensor] = [enc]
+                    activation_encodings_onnx[activation_tensor] = utils.get_propagated_encoding_dict(encoding)
 
     @staticmethod
-    def _update_encoding_dict_for_input_activations(layer: torch.nn.Module, layer_name: str, op_to_io_tensor_map: Dict,
+    def _update_encoding_dict_for_input_activations(layer: ExportableQuantModule, layer_name: str, op_to_io_tensor_map: Dict,
                                                     activation_encodings_onnx: Dict, activation_encodings_torch: Dict,
                                                     layers_to_onnx_op_names: Dict[str, str]):
-
-        # skip layer if all input quantizers are disabled.
-        if all(not quantizer.enabled for quantizer in layer.input_quantizers):
+        input_encodings = layer.export_input_encodings()
+        # skip layer if it has no input encodings.
+        if all(encoding is None for encoding in input_encodings):
             return
 
         input_tensors = QuantizationSimModel._get_layer_input_tensors(layer, layer_name, op_to_io_tensor_map,
                                                                       layers_to_onnx_op_names)
-        num_quantizers = len(layer.input_quantizers)
-        num_inputs = len(input_tensors)
-        if len(input_tensors) != num_quantizers:
+
+        if len(input_tensors) != len(input_encodings):
             logger.warning("number of input quantizers: %d available for layer: %s "
-                           "doesn't match with number of input tensors: %d", num_quantizers, layer_name, num_inputs)
+                           "doesn't match with number of input tensors: %d", len(input_encodings), layer_name,
+                           len(input_tensors))
 
-        for index, (input_tensor, quantizer) in enumerate(zip(input_tensors, layer.input_quantizers)):
-            if quantizer.enabled:
-                quantizer_encoding = _get_encoding_by_quantizer(quantizer)
-                encoding = QuantizationSimModel._create_encoding_dict(quantizer_encoding,
-                                                                      quantizer,
-                                                                      propagate_encodings=False)
-                activation_encodings_onnx[input_tensor] = [encoding]
+        for index, (input_tensor, encoding) in enumerate(zip(input_tensors, input_encodings)):
+            if encoding is not None:
+                activation_encodings_onnx[input_tensor] = encoding
                 # Check if layer exists in the pytorch encoding dictionary
                 if layer_name not in activation_encodings_torch:
                     activation_encodings_torch[layer_name] = {}
@@ -1163,7 +1157,7 @@ def _update_encoding_dict_for_input_activations(layer: torch.nn.Na
                     activation_encodings_torch[layer_name][QUANTIZER_TYPE_INPUT] = {}
                 # Store encodings for a particular index so that they can be used to check if a quantizer was
                 # enabled or not
-                activation_encodings_torch[layer_name][QUANTIZER_TYPE_INPUT][index] = encoding
+                activation_encodings_torch[layer_name][QUANTIZER_TYPE_INPUT][index] = encoding[0]
 
     @staticmethod
     def _get_layer_input_tensors(layer: torch.nn.Module, layer_name: str,
                                  op_to_io_tensor_map: Dict,
         :return: list of input tensor names.
         """
 
-        param_inputs = [layer_name + '.' + param_name for param_name in layer.param_quantizers]
+        param_inputs = [layer_name + '.' + param_name for param_name, _ in layer.named_parameters()]
         if version.parse(torch.__version__) < version.parse("1.13.0") or not onnx_utils.EXPORT_TO_ONNX_DIRECT:
             start_op_names = [key for key in op_to_io_tensor_map
                               if (key.startswith(layer_name) and '#0' in key) or key == layer_name]
@@ -1414,26 +1408,7 @@ def _create_encoding_dict(encoding: libpymo.TfEncoding, quantizer, propagate_enc
             ops.
         :return: Encoding Dictionary
         """
-        data_type, bitwidth = quantizer.data_type, quantizer.bitwidth
-
-        if data_type == QuantizationDataType.float:
-            enc_dict = {'bitwidth': bitwidth, 'dtype': "float"}
-        else:
-            if encoding:
-                if propagate_encodings:
-                    # Shortened encodings will be filled into a layer that only exists due to expansion of PyTorch ops
-                    # into multiple ONNX ops so that it's necessarily to use the same bitwidth and type
-                    enc_dict = {'bitwidth': encoding.bw, 'dtype': "int"}
-                else:
-                    encoding_min, encoding_max, bw, scale, offset = encoding.min, encoding.max, encoding.bw, \
-                                                                    encoding.delta, encoding.offset
-                    is_symmetric = quantizer.use_symmetric_encodings
-
-                    enc_dict = {'min': encoding_min, 'max': encoding_max, 'scale': scale, 'offset': int(offset),
-                                'bitwidth': bw, 'is_symmetric': str(is_symmetric), 'dtype': "int"}
-            else:
-                enc_dict = None
-        return enc_dict
+        return utils.create_encoding_dict(encoding, quantizer, propagate_encodings)
 
     @classmethod
     def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclude):
@@ -1448,10 +1423,10 @@ def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclu
 
             # If modules is in the exclude list, remove the wrapper
             if module_ref in list_of_modules_to_exclude:
-                if isinstance(module_ref, QcQuantizeWrapper):
+                if isinstance(module_ref, ExportableQuantModule):
                     # Remove the wrapper, gets auto-deleted
                     # pylint: disable=protected-access
-                    setattr(starting_module, module_name, module_ref._module_to_wrap)
+                    setattr(starting_module, module_name, module_ref.get_original_module())
 
                 elif isinstance(module_ref, QcQuantizeStandAloneBase):
                     setattr(starting_module, module_name, torch.nn.Identity())
@@ -1475,17 +1450,23 @@ def get_original_model(model: torch.nn.Module):
         QuantizationSimModel._remove_quantization_wrappers(original_model, all_modules_in_original_model)
         return original_model
 
-    def _add_inputs_hook(self, hooks):
+    def _get_leaf_module_to_name_map(self):
+        """
+        Returns a mapping from leaf modules to module name, where any ExportableQuantModule is considered a leaf module,
+        and is therefore not further recursed (since we do not want to retrieve all internal quantizers/modules).
+        """
+        def recursively_populate_map(starting_module, module_map, start_str):
+            for name, module in starting_module.named_children():
+                if isinstance(module, ExportableQuantModule) or utils.is_leaf_module(module):
+                    module_map[module] = start_str + name
+                else:
+                    recursively_populate_map(module, module_map, start_str + name + "/")
         module_to_name_map = {}
-        for name, module in self.model.named_modules():
-            if isinstance(module, QcQuantizeWrapper):
-                # pylint: disable=protected-access
-                module_to_name_map[module._module_to_wrap] = name
+        recursively_populate_map(self.model, module_to_name_map, "")
+        return module_to_name_map
 
-        # Add any leaf modules that are not wrapped by QcQuantizeWrapper (like Identity)
-        for name, module in self.model.named_modules():
-            if utils.is_leaf_module(module) and module not in module_to_name_map.keys():
-                module_to_name_map[module] = name
+    def _add_inputs_hook(self, hooks):
+        module_to_name_map = self._get_leaf_module_to_name_map()
 
         def inputs_hook(module_ref, inputs, _):
             # Need to remove hook here, otherwise the jit trace of CustomMarker with module ref will error since the
@@ -1493,28 +1474,22 @@ def inputs_hook(module_ref, inputs, _):
             hooks[module_ref].remove()
             del hooks[module_ref]
             module_name = module_to_name_map[module_ref]
+            if isinstance(module_ref, ExportableQuantModule):
+                module_ref = module_ref.get_original_module()
             marker_layer = torch.jit.trace(CustomMarker(module_ref, module_name, 'True'), inputs)
             self._module_marker_map[module_name] = marker_layer
 
         for name, module in self.model.named_modules():
-            if name not in self._module_marker_map and utils.is_leaf_module(module):
+            if name in module_to_name_map.values():
                 hooks[module] = module.register_forward_hook(inputs_hook)
 
     def _validate_module_marker_map(self):
         """
         Check to make sure all leaf modules have traced Custom Markers associated with them.
         """
-        all_leaf_modules = set()
+        all_leaf_modules = self._get_leaf_module_to_name_map().values()
         missing_inputs_entries = []
 
-        for name, module in self.model.named_modules():
-            if isinstance(module, QcQuantizeWrapper):
-                all_leaf_modules.add(name)
-
-        # Add any modules that are not wrapped by QcQuantizeWrappers (like Identity)
-        for name, module in self.model.named_modules():
-            if utils.is_leaf_module(module) and '_module_to_wrap' not in name:
-                all_leaf_modules.add(name)
 
         for leaf_module in all_leaf_modules:
             if leaf_module not in self._module_marker_map.keys():
@@ -1954,24 +1929,23 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, pytorch_encodin
     quant_sim_model.replace_wrappers_for_quantize_dequantize()
 
 
-def has_valid_encodings(qc_quantize_op: Union[QcQuantizeWrapper, QcQuantizeRecurrent]) -> bool:
+def has_valid_encodings(qc_quantize_op: ExportableQuantModule) -> bool:
    """
    Utility for determining whether a given qc_quantize_op has any valid encodings.
    :param qc_quantize_op: Qc quantize op to evaluate
    :return: True if any input, param, or output quantizers have valid encodings, False otherwise
    """
-    if not isinstance(qc_quantize_op, (QcQuantizeWrapper, QcQuantizeRecurrent)):
+    if not isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent)):
         logger.error("has_valid_encodings only supported for QcQuantizeWrapper and QcQuantizeRecurrent "
                      "modules")
-        assert isinstance(qc_quantize_op, (QcQuantizeWrapper, QcQuantizeRecurrent))
-
-    if isinstance(qc_quantize_op, QcQuantizeWrapper):
-        input_quantizers = qc_quantize_op.input_quantizers
-        output_quantizers = qc_quantize_op.output_quantizers
-    else:
-        input_quantizers = list(qc_quantize_op.input_quantizers.values())
-        output_quantizers = list(qc_quantize_op.output_quantizers.values())
+        assert isinstance(qc_quantize_op, (ExportableQuantModule, QcQuantizeRecurrent))
+    if isinstance(qc_quantize_op, ExportableQuantModule):
+        all_encodings = qc_quantize_op.export_output_encodings() + qc_quantize_op.export_input_encodings() + \
+                        list(qc_quantize_op.export_param_encodings().values())
+        return any([encoding is not None for encoding in all_encodings])
+    input_quantizers = list(qc_quantize_op.input_quantizers.values())
+    output_quantizers = list(qc_quantize_op.output_quantizers.values())
 
     for quantizer in input_quantizers + output_quantizers + list(qc_quantize_op.param_quantizers.values()):
         if quantizer.enabled and (quantizer.encoding is not None or quantizer.data_type is QuantizationDataType.float):
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
index 8f96b111d2f..6c4e6cd73b1 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py
@@ -1044,3 +1044,44 @@ def record_dtypes(module, inputs, outputs):
     run_hook_for_layers_with_given_input(model, input_tensor, record_dtypes)
 
     return inout_dtypes_map
+
+def create_encoding_dict(encoding: libpymo.TfEncoding, quantizer, propagate_encodings: bool) -> Union[Dict, None]:
+    """
+    Create encoding dictionary from encoding object
+    :param encoding: Encoding of the quantizer
+    :param quantizer: Tensor Quantizer
+    :param propagate_encodings: If True, encoding entries for intermediate ops (when one PyTorch op results in
+            multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of
+            ops.
+    :return: Encoding Dictionary
+    """
+    data_type, bitwidth = quantizer.data_type, quantizer.bitwidth
+
+    if data_type == QuantizationDataType.float:
+        enc_dict = {'bitwidth': bitwidth, 'dtype': "float"}
+    else:
+        if encoding:
+            if propagate_encodings:
+                # Shortened encodings will be filled into a layer that only exists due to expansion of PyTorch ops
+                # into multiple ONNX ops, so it is necessary to use the same bitwidth and type
+                enc_dict = {'bitwidth': encoding.bw, 'dtype': "int"}
+            else:
+                encoding_min, encoding_max, bw, scale, offset = encoding.min, encoding.max, encoding.bw, \
+                                                                encoding.delta, encoding.offset
+                is_symmetric = quantizer.use_symmetric_encodings
+
+                enc_dict = {'min': encoding_min, 'max': encoding_max, 'scale': scale, 'offset': int(offset),
+                            'bitwidth': bw, 'is_symmetric': str(is_symmetric), 'dtype': "int"}
+        else:
+            enc_dict = None
+    return enc_dict
+
+def get_propagated_encoding_dict(encoding_dict: List[Dict[str, any]]) -> List[Dict[str, any]]:
+    """
+    Creates encoding dictionary for intermediate ops (when one PyTorch op results in multiple ONNX nodes), which are
+    filled with the same BW and data_type as the output tensor for that series of ops.
+
+    :param encoding_dict: Encoding dictionary for the final output of the op
+    :return: Encoding dictionary for intermediate activations of the op
+    """
+    return [{"bitwidth": encoding_dict[0]["bitwidth"], "dtype": encoding_dict[0]["dtype"]}]
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/models_/models_to_test.py b/TrainingExtensions/torch/test/python/experimental/v2/models_/models_to_test.py
new file mode 100644
index 00000000000..3d4cf58a1e1
--- /dev/null
+++ b/TrainingExtensions/torch/test/python/experimental/v2/models_/models_to_test.py
@@ -0,0 +1,146 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+""" Models for use in unit testing """
+
+import torch
+from torch import nn
+
+class SimpleConditional(torch.nn.Module):
+    """
+    Model using conditional paths
+    Expected input shape = (1, 3)
+    """
+    def __init__(self):
+        super(SimpleConditional, self).__init__()
+        self.prelu1 = torch.nn.PReLU(init=.3)
+        self.prelu2 = torch.nn.PReLU(init=.4)
+        self.linear1 = torch.nn.Linear(3, 2)
+        self.linear2 = torch.nn.Linear(3, 10)
+        self.softmax = torch.nn.Softmax()
+
+    def forward(self, _input, condition):
+        if condition:
+            x = self.linear1(_input)
+            x = x.view(x.size(0), -1)
+            x = self.prelu1(x)
+            return x
+        x = self.linear2(_input)
+        x = self.prelu2(x)
+        x = self.softmax(x)
+        return x
+
+class ModelWithTwoInputs(nn.Module):
+
+    def __init__(self):
+        super(ModelWithTwoInputs, self).__init__()
+        self.conv1_a = nn.Conv2d(1, 10, kernel_size=5)
+        self.maxpool1_a = nn.MaxPool2d(2)
+        self.relu1_a = nn.ReLU()
+
+        self.conv1_b = nn.Conv2d(1, 10, kernel_size=5)
+        self.maxpool1_b = nn.MaxPool2d(2)
+        self.relu1_b = nn.ReLU()
+
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.maxpool2 = nn.MaxPool2d(2)
+        self.relu2 = nn.LeakyReLU()
+        self.flatten = nn.Flatten()
+
+        self.fc1 = nn.Linear(320, 50)
+        self.relu3 = nn.ReLU()
+        self.dropout = nn.Dropout()
+        self.fc2 = nn.Linear(50, 10)
+
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, x1, x2):
+        x1 = self.relu1_a(self.maxpool1_a(self.conv1_a(x1)))
+        x2 = self.relu1_b(self.maxpool1_b(self.conv1_b(x2)))
+        x = x1 + x2
+        x = self.relu2(self.maxpool2(self.conv2(x)))
+        x = self.flatten(x)
+        x = self.relu3(self.fc1(x))
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return self.softmax(x)
+
+
+class FakeMultiOutputOp(torch.autograd.Function):
+    """
+    This function helps create a custom onnx op to simulate a 5 output tensor onnx node
+    Note: the forward pass has some tensor computation to prevent torch onnx export from removing the onnx node.
+    """
+
+    @staticmethod
+    def symbolic(g, inp):
+        """
+        Magic method that helps with exporting a custom ONNX node
+        """
+        return g.op('aimet_torch::FakeMultiOutputOp', inp, outputs=5)
+
+    @staticmethod
+    def forward(ctx, x):  # pylint: disable=arguments-differ
+        return x * 2, x * 4, x * 8, x * 16, x * 32
+
+    @staticmethod
+    def backward(ctx, _grad):  # pylint: disable=arguments-differ
+        raise NotImplementedError()
+
+
+class ModuleWith5Output(torch.nn.Module):
+    def forward(self, x):
+        return FakeMultiOutputOp.apply(x)
+
+
+class ModelWith5Output(torch.nn.Module):
+    def __init__(self):
+        super(ModelWith5Output, self).__init__()
+        self.cust = ModuleWith5Output()
+
+    def forward(self, x):
+        return self.cust(x)
+
+
+class SoftMaxAvgPoolModel(torch.nn.Module):
+    def __init__(self):
+        super(SoftMaxAvgPoolModel, self).__init__()
+        self.sfmax = torch.nn.Softmax(dim=1)
+        self.avgpool = torch.nn.AvgPool2d(3)
+
+    def forward(self, inp):
+        x = self.sfmax(inp)
+        return self.avgpool(x)
\ No newline at end of file
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
new file mode 100644
index 00000000000..f32266e3b6b
--- /dev/null
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
@@ -0,0 +1,326 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+import tempfile
+
+import pytest
+import torch.nn
+import copy
+import os
+import json
+
+from aimet_torch.experimental.v2.quantization.quantization_mixin import _QuantizationMixin
+from aimet_torch.elementwise_ops import Add
+from aimet_torch import onnx_utils
+from aimet_torch.quantsim import QuantizationSimModel, OnnxExportApiArgs
+
+from models_.models_to_test import SimpleConditional, ModelWithTwoInputs, ModelWith5Output, SoftMaxAvgPoolModel
+
+# Key/values don't matter
+dummy_encoding = {"min": 0,
+                  "max": 2,
+                  "scale": 2/255,
+                  "offset": 0,
+                  "bitwidth": 8,
+                  "dtype": "int",
+                  "is_symmetric": "False"}
+
+
+class DummyMixin(_QuantizationMixin, torch.nn.Module):
+    """ Dummy class for testing QuantSim export logic """
+
+    def __init__(self, module, num_inputs, num_outputs, has_input_encodings, has_output_encodings):
+        super(DummyMixin, self).__init__()
+        # Assign a dummy output quantizer (since a real mixin will have child quantizers)
+        self.output_quantizer = torch.nn.Identity()
+        # Hide module inside list so it doesn't show up as a child (we will not actually have a wrapped module)
+        self.module = [copy.deepcopy(module)]
+        self._parameters = self.module[0]._parameters
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.has_input_encodings = has_input_encodings
+        self.has_output_encodings = has_output_encodings
+        self.dummy_encoding = copy.deepcopy(dummy_encoding)
+
+    @classmethod
+    def from_module(cls, module: torch.nn.Module, num_inputs=1, num_outputs=1, has_input_encodings=False, has_output_encodings=True):
+        return cls(module, num_inputs, num_outputs, has_input_encodings, has_output_encodings)
+
+    def forward(self, *inputs):
+        return self.output_quantizer(self.module[0](*inputs))
+
+    def export_input_encodings(self):
+        enc = [self.dummy_encoding] if self.has_input_encodings else None
+        return [enc] * self.num_inputs
+
+    def export_output_encodings(self):
+        enc = [self.dummy_encoding] if self.has_output_encodings else None
+        return [enc] * self.num_outputs
+
+    def export_param_encodings(self):
+        enc_dict = {}
+        for name, param in self.module[0].named_parameters():
+            if name == "weight":
+                enc_dict[name] = [self.dummy_encoding] * param.shape[0]
+            else:
+                enc_dict[name] = None
+        return enc_dict
+
+    def get_original_module(self):
+        return copy.deepcopy(self.module[0])
+
+class DummyModel(torch.nn.Module):
+
+    def __init__(self, in_channels):
+        super(DummyModel, self).__init__()
+        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=10, kernel_size=3, padding=1)
+        self.relu = torch.nn.ReLU()
+        self.conv2 = torch.nn.Conv2d(10, 10, 3, padding=1)
+        self.add = Add()
+        self.softmax = torch.nn.Softmax(dim=1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x_resid = self.relu(x)
+        x = self.conv2(x_resid)
+        x = self.add(x, x_resid)
+        return self.softmax(x)
+
+export_args = {'opset_version': None, 'input_names': None, 'output_names': None}
+
+class TestQuantsimOnnxExport:
+
+    def test_onnx_export(self):
+        export_args = OnnxExportApiArgs(opset_version=10, input_names=["input"], output_names=["output"])
+        input_shape = (1, 10, 32, 32)
+        fname = "test_model"
+        dummy_input = torch.randn(input_shape)
+        model = DummyModel(in_channels=input_shape[1])
+        sim_model = copy.deepcopy(model)
+        for name, module in sim_model.named_children():
+            has_input_encodings = False if name != "conv1" else True
+            has_output_encodings = True if name != "conv1" else False
+            num_inputs = 2 if name == "add" else 1
+            sim_model.__setattr__(name, DummyMixin.from_module(module, num_inputs, 1, has_input_encodings, has_output_encodings))
+
+
+        with tempfile.TemporaryDirectory() as path:
+            QuantizationSimModel.export_onnx_model_and_encodings(path, fname, model, sim_model, dummy_input=dummy_input,
+                                                                 onnx_export_args=export_args,
+                                                                 propagate_encodings=False, module_marker_map={},
+                                                                 is_conditional=False, excluded_layer_names=None,
+                                                                 quantizer_args=None)
+
+            file_path = os.path.join(path, fname + '.encodings')
+
+            assert os.path.exists(file_path)
+
+            with open(file_path) as f:
+                encoding_dict = json.load(f)
+
+            # Format is "/layer_name/OnnxType_{output/input}_{idx}"
+            expected_act_keys = {"input", "/relu/Relu_output_0", "/conv2/Conv_output_0", "/add/Add_output_0", "output"}
+            expected_param_keys = {"conv1.weight", "conv2.weight"}
+
+            assert set(encoding_dict["activation_encodings"].keys()) == expected_act_keys
+            assert set(encoding_dict["param_encodings"].keys()) == expected_param_keys
+
+            for encoding in encoding_dict["activation_encodings"].values():
+                assert encoding[0] == dummy_encoding
+
+
+
+    # From: https://github.com/quic/aimet/blob/ce3dafe75d81893cdb8b45ba8abf53a672c28187/TrainingExtensions/torch/test/python/test_quantizer.py#L2731
+    def test_export_to_onnx_direct(self):
+        model = ModelWithTwoInputs()
+        sim_model = copy.deepcopy(model)
+        dummy_input = (torch.rand(1, 1, 28, 28), torch.rand(1, 1, 28, 28))
+        for name, layer in sim_model.named_children():
+            has_input_encodings = name == "conv1_a"
+            wrapped_layer = DummyMixin.from_module(layer, has_input_encodings=has_input_encodings)
+            setattr(sim_model, name, wrapped_layer)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            onnx_utils.EXPORT_TO_ONNX_DIRECT = True
+            QuantizationSimModel.export_onnx_model_and_encodings(temp_dir, "direct_onnx_export", model,
+                                                                 sim_model, dummy_input, export_args,
+                                                                 propagate_encodings=False)
+            onnx_utils.EXPORT_TO_ONNX_DIRECT = False
+            QuantizationSimModel.export_onnx_model_and_encodings(temp_dir, "onnxsaver_export", model,
+                                                                 sim_model, dummy_input, export_args,
+                                                                 propagate_encodings=False)
+
+            with open(os.path.join(temp_dir, 'direct_onnx_export.encodings')) as direct_onnx_json:
+                direct_onnx_encodings = json.load(direct_onnx_json)
+            with open(os.path.join(temp_dir, 'onnxsaver_export.encodings')) as onnxsaver_json:
+                onnxsaver_encodings = json.load(onnxsaver_json)
+
+            assert len(direct_onnx_encodings['activation_encodings']) == \
+                   len(onnxsaver_encodings['activation_encodings'])
+            assert len(direct_onnx_encodings['param_encodings']) == len(onnxsaver_encodings['param_encodings'])
+            direct_onnx_act_names = direct_onnx_encodings['activation_encodings'].keys()
+            onnxsaver_act_names = onnxsaver_encodings['activation_encodings'].keys()
+            assert direct_onnx_act_names != onnxsaver_act_names
+
+
+    def test_encodings_propagation(self):
+        """
+        Test encodings are propagated correctly when more than
+        one onnx node maps to the same torch module
+        """
+        export_args = OnnxExportApiArgs(opset_version=10, input_names=["input"], output_names=["output"])
+        pixel_shuffle = torch.nn.PixelShuffle(2)
+        model = torch.nn.Sequential(pixel_shuffle)
+        sim_model = torch.nn.Sequential(DummyMixin.from_module(pixel_shuffle, num_inputs=1, has_input_encodings=True,
+                                                               has_output_encodings=True))
+        dummy_input = torch.randn(1, 4, 8, 8)
+
+        # Save encodings
+        with tempfile.TemporaryDirectory() as path:
+            fname_no_prop = "encodings_propagation_false"
+            fname_prop = "encodings_propagation_true"
+            QuantizationSimModel.export_onnx_model_and_encodings(path, fname_no_prop, model, sim_model,
+                                                                 dummy_input=dummy_input,
+                                                                 onnx_export_args=export_args,
+                                                                 propagate_encodings=False)
+            QuantizationSimModel.export_onnx_model_and_encodings(path, fname_prop, model, sim_model,
                                                                 dummy_input=dummy_input,
+                                                                 onnx_export_args=export_args,
+                                                                 propagate_encodings=True)
+            with open(os.path.join(path, fname_no_prop + ".encodings")) as f:
+                encoding_dict_no_prop = json.load(f)["activation_encodings"]
+            with open(os.path.join(path, fname_prop + ".encodings")) as f:
+                encoding_dict_prop = json.load(f)["activation_encodings"]
+
+        assert len(encoding_dict_no_prop) == 2
+        assert len(encoding_dict_prop) == 4
+        filtered_encoding_dict_prop = [{key: val} for key, val in encoding_dict_prop.items() if 'scale' in val[0]]
+        assert len(filtered_encoding_dict_prop) == 2
+
+    # From: https://github.com/quic/aimet/blob/ce3dafe75d81893cdb8b45ba8abf53a672c28187/TrainingExtensions/torch/test/python/test_quantizer.py#L3733
+    def test_multi_output_onnx_op(self):
+        """
+        Test mapping and exporting of output encodings for multiple output onnx op.
+        """
+        model = ModelWith5Output()
+        dummy_input = torch.randn(1, 3, 224, 224)
+        sim_model = copy.deepcopy(model)
+        class DummyMixinWithDisabledOutput(DummyMixin):
+            def export_output_encodings(self):
+                enc = [self.dummy_encoding]
+                return [None] + ([enc] * (self.num_outputs - 1))
+
+        sim_model.cust = DummyMixinWithDisabledOutput.from_module(sim_model.cust, num_inputs=1, num_outputs=5,
+                                                                  has_input_encodings=True, has_output_encodings=True)
+
+        with tempfile.TemporaryDirectory() as path:
+            QuantizationSimModel.export_onnx_model_and_encodings(path, 'module_with_5_output', model, sim_model,
+                                                                 dummy_input,
+                                                                 onnx_export_args=(onnx_utils.OnnxExportApiArgs(opset_version=11)),
+                                                                 propagate_encodings=False)
+            with open(os.path.join(path, "module_with_5_output.encodings")) as json_file:
+                activation_encodings = json.load(json_file)['activation_encodings']
+                assert '7' not in activation_encodings
+                assert set(['8', '9', '10', '11', 't.1']).issubset(activation_encodings.keys())
+                for item in activation_encodings.values():
+                    assert item[0] == sim_model.cust.dummy_encoding
+
+    # From: https://github.com/quic/aimet/blob/ce3dafe75d81893cdb8b45ba8abf53a672c28187/TrainingExtensions/torch/test/python/test_quantizer.py#L1935
+    def test_mapping_encoding_for_torch_module_with_multiple_onnx_ops(self):
+        """
+        Test the input and output encoding map to input/output at subgraph level when a torch module generates
+        multiple onnx ops i.e. a sub-graph
+        """
+        dummy_input = torch.randn(1, 4, 256, 512)
+        model = SoftMaxAvgPoolModel()
+
+        sim_model = copy.deepcopy(model)
+        sim_model.sfmax = DummyMixin.from_module(sim_model.sfmax, 1, 1, True, True)
+        sim_model.avgpool = DummyMixin.from_module(sim_model.avgpool, 1, 1, True, True)
+        with tempfile.TemporaryDirectory() as path:
+            QuantizationSimModel.export_onnx_model_and_encodings(path, "sfmaxavgpool_model", model, sim_model,
+                                                                 dummy_input, export_args, propagate_encodings=False)
+            with open(os.path.join(path, "sfmaxavgpool_model" + ".encodings")) as json_file:
+                encoding_data = json.load(json_file)
+
+            assert len(encoding_data["activation_encodings"]) == 3
+
+
+    def test_conditional_export(self):
+        """ Test exporting a model with conditional paths """
+        model = SimpleConditional()
+        model.eval()
+        inp = torch.randn(1, 3)
+        true_tensor = torch.tensor([1])
+        false_tensor = torch.tensor([0])
+
+        def forward_callback(model, _):
+            model(inp, true_tensor)
+            model(inp, false_tensor)
+
+        sim_model = copy.deepcopy(model)
+
+        qsim = QuantizationSimModel(model, dummy_input=(inp, true_tensor))
+        qsim.compute_encodings(forward_callback, None)
+
+        for name, module in sim_model.named_children():
+            qsim_module = getattr(qsim.model, name)
+            has_input_encodings = qsim_module.input_quantizers[0].enabled
+            has_output_encodings = qsim_module.output_quantizers[0].enabled
+            sim_model.__setattr__(name, DummyMixin.from_module(module, 1, 1, has_input_encodings, has_output_encodings))
+
+        qsim.model = sim_model
+
+        with tempfile.TemporaryDirectory() as path:
+            qsim._export_conditional(path, 'simple_cond', dummy_input=(inp, false_tensor),
+                                     forward_pass_callback=forward_callback, forward_pass_callback_args=None)
+
+            with open(os.path.join(path, 'simple_cond.encodings')) as f:
+                encodings = json.load(f)
+                # verifying the encoding against default eAI HW cfg
+                # activation encodings -- input, linear1 out, prelu1 out, linear2 out, prelu2 out, softmax out
+                assert 6 == len(encodings['activation_encodings'])
+                # param encoding -- linear 1 & 2 weight, prelu 1 & 2 weight
+                assert 4 == len(encodings['param_encodings'])
+
+        expected_encoding_keys = {"/linear1/Add_output_0",
+                                  "/linear2/Add_output_0",
+                                  "/prelu1/CustomMarker_1_output_0",
+                                  "/prelu2/PRelu_output_0",
+                                  "/softmax/CustomMarker_1_output_0",
+                                  "_input.1",
+                                  }
+        assert encodings["activation_encodings"].keys() == expected_encoding_keys
diff --git a/TrainingExtensions/torch/test/python/test_elementwise_ops.py b/TrainingExtensions/torch/test/python/test_elementwise_ops.py
index ea53c0ed200..3cad86418fe 100644
--- a/TrainingExtensions/torch/test/python/test_elementwise_ops.py
+++ b/TrainingExtensions/torch/test/python/test_elementwise_ops.py
@@ -120,6 +120,7 @@ def test_quantsim_export(self):
         encodings.delta = 1
         encodings.offset = 0.2
         sim.model.op1.output_quantizers[0].encoding = encodings
+        sim.model.op1.input_quantizers[1].enabled = False
         sim.model.conv1.output_quantizers[0].encoding = encodings
         sim.model.conv1.param_quantizers['weight'].encoding = encodings
         sim.export(path='./data', filename_prefix='quant_model', dummy_input=dummy_input)
@@ -127,7 +128,7 @@ def test_quantsim_export(self):
         with open('./data/quant_model.encodings') as f:
             data = json.load(f)
 
-        self.assertTrue(len(data['activation_encodings']) == 3)
+        self.assertTrue(len(data['activation_encodings']) == 2)
         self.assertTrue(len(data['param_encodings']) == 1)
 
     def test_subtract_op(self):
diff --git a/TrainingExtensions/torch/test/python/test_qc_quantize_op.py b/TrainingExtensions/torch/test/python/test_qc_quantize_op.py
index 18decee122d..8bf5df9b781 100644
--- a/TrainingExtensions/torch/test/python/test_qc_quantize_op.py
+++ b/TrainingExtensions/torch/test/python/test_qc_quantize_op.py
@@ -1088,3 +1088,46 @@ def test_wrapper_with_kwargs(self, wrapper):
         out3 = wrapper(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3))
         assert out1.shape == out2.shape
         assert out2.shape == out3.shape
+
+    @pytest.mark.parametrize("wrapper",
+                             [StaticGridQuantWrapper(torch.nn.Linear(10, 10),
+                                                     8, 8, 'nearest',
+                                                     QuantScheme.post_training_tf_enhanced,
+                                                     num_inputs=2, num_outputs=2),
+                              LearnedGridQuantWrapper(torch.nn.Linear(10, 10),
+                                                      8, 8, 'nearest',
+                                                      QuantScheme.training_range_learning_with_tf_init,
+                                                      torch.device('cpu'),
+                                                      num_inputs=2, num_outputs=2)
+                              ])
+    def test_export_quantizer_encodings(self, wrapper):
+        wrapper.input_quantizers[0].enabled = True
+        wrapper.input_quantizers[1].enabled = False
+        wrapper.output_quantizers[0].enabled = False
+        wrapper.output_quantizers[1].enabled = True
+        wrapper.param_quantizers["weight"].enabled = True
+        wrapper.param_quantizers["bias"].enabled = False
+        encoding = libpymo.TfEncoding()
+        encoding.max = 127.0
+        encoding.min = -128.0
+        encoding.delta = 1.0
+        encoding.offset = -128
+        encoding.bw = 8
+        encoding_as_dict = {"min": encoding.min, "max": encoding.max, "scale": encoding.delta, "offset": int(encoding.offset),
+                            "bitwidth": encoding.bw, "is_symmetric": "False", "dtype": "int"}
+        wrapper.enable_per_channel_quantization()
+        wrapper.param_quantizers["weight"].encoding = [encoding] * 10
+        wrapper.input_quantizers[0].encoding = encoding
+        wrapper.output_quantizers[1].encoding = encoding
+        input_encodings = wrapper.export_input_encodings()
+        output_encodings = wrapper.export_output_encodings()
+        param_encodings = wrapper.export_param_encodings()
+        assert len(input_encodings) == 2
+        assert input_encodings[0] == [encoding_as_dict]
+        assert input_encodings[1] is None
+        assert len(output_encodings) == 2
+        assert output_encodings[0] is None
+        assert output_encodings[1] == [encoding_as_dict]
+        assert len(param_encodings.items()) == 2
+        assert param_encodings["bias"] is None
+        assert param_encodings["weight"] == [encoding_as_dict] * 10
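
To close the loop, a rough sketch of the consuming side: during export, QuantizationSimModel only needs the four protocol methods of ExportableQuantModule, roughly as below. This is a simplification of _export_encodings_to_files; the mapping from layers to ONNX tensor names via op_to_io_tensor_map is elided, and collect_encodings is a hypothetical helper, not an AIMET API:

    from typing import Dict, Tuple
    import torch
    from aimet_torch.quantsim import ExportableQuantModule

    def collect_encodings(sim_model: torch.nn.Module) -> Tuple[Dict, Dict]:
        """ Gather param and output encodings from every ExportableQuantModule """
        param_encodings, output_encodings = {}, {}
        for layer_name, layer in sim_model.named_modules():
            # runtime_checkable Protocol: isinstance() only checks that the methods exist
            if not isinstance(layer, ExportableQuantModule):
                continue
            for param_name, encoding in layer.export_param_encodings().items():
                if encoding is not None:  # None means the corresponding quantizer was disabled
                    param_encodings[layer_name + '.' + param_name] = encoding
            for index, encoding in enumerate(layer.export_output_encodings()):
                if encoding is not None:
                    output_encodings[(layer_name, index)] = encoding
        return param_encodings, output_encodings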