diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2.py b/TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2.py
new file mode 100644
index 00000000000..0037ec10715
--- /dev/null
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2.py
@@ -0,0 +1,1177 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+# pylint: disable=too-many-lines
+
+"""Automatic Post-Training Quantization V2"""
+
+import copy
+from collections import OrderedDict, defaultdict
+from dataclasses import dataclass
+import functools
+import math
+import traceback
+import os
+import sys
+import io
+from unittest.mock import patch
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Mapping, Iterable
+import jinja2
+from tqdm import tqdm
+
+import onnx
+import onnxruntime as ort
+from onnxruntime.quantization.onnx_quantizer import ONNXModel
+import numpy as np
+
+from aimet_onnx import utils
+from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters
+from aimet_onnx.cross_layer_equalization import equalize_model
+from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight
+from aimet_onnx.quantsim import QuantizationSimModel
+
+from aimet_common.auto_quant import Diagnostics
+from aimet_common.cache import Cache
+from aimet_common.defs import QuantScheme
+from aimet_common.utils import AimetLogger, Spinner
+from aimet_common.quantsim import validate_quantsim_inputs
+
+
+_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.AutoQuant)
+
+cache = Cache()
+
+
+@dataclass(frozen=True)
+class _QuantSchemePair:
+ param_quant_scheme: QuantScheme
+ output_quant_scheme: QuantScheme
+ param_percentile: Optional[float] = None
+ output_percentile: Optional[float] = None
+
+ def __str__(self):
+ def scheme_to_str(quant_scheme, percentile):
+ if quant_scheme == QuantScheme.post_training_percentile:
+ return f"{percentile}%ile"
+ if quant_scheme in (QuantScheme.post_training_tf,
+ QuantScheme.training_range_learning_with_tf_init):
+ return "tf"
+ if quant_scheme in (QuantScheme.post_training_tf_enhanced,
+ QuantScheme.training_range_learning_with_tf_enhanced_init):
+ return "tf-enhanced"
+            raise ValueError(f"Unsupported quant scheme: {quant_scheme}")
+
+ param_str = scheme_to_str(self.param_quant_scheme, self.param_percentile)
+ output_str = scheme_to_str(self.output_quant_scheme, self.output_percentile)
+ return f"W@{param_str} / A@{output_str}"
+
+
+_QUANT_SCHEME_CANDIDATES = (
+ # Weight: tf
+ # Activation: tf
+ _QuantSchemePair(QuantScheme.post_training_tf,
+ QuantScheme.post_training_tf),
+
+ # Weight: tf_enhanced
+ # Activation: tf
+ _QuantSchemePair(QuantScheme.post_training_tf_enhanced,
+ QuantScheme.post_training_tf),
+
+ # Weight: tf_enhanced
+ # Activation: tf_enhanced
+ _QuantSchemePair(QuantScheme.post_training_tf_enhanced,
+ QuantScheme.post_training_tf_enhanced),
+
+ # TODO: Enable below candidates once we figure out how to set percentile value in QcQuantizeOp's Tensor Quantizer
+
+ # Weight: tf_enhanced
+ # Activation: percentile(99.9)
+ # _QuantSchemePair(QuantScheme.post_training_tf_enhanced,
+ # QuantScheme.post_training_percentile,
+ # output_percentile=99.9),
+
+ # Weight: tf_enhanced
+ # Activation: percentile(99.99)
+ # _QuantSchemePair(QuantScheme.post_training_tf_enhanced,
+ # QuantScheme.post_training_percentile,
+ # output_percentile=99.99),
+)
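+
+# Each candidate above is evaluated by AutoQuant._choose_default_quant_scheme;
+# the pair with the highest eval score becomes the default quant scheme.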
+
+
+def _validate_inputs(model: Union[onnx.ModelProto, ONNXModel], # pylint: disable=too-many-arguments
+ data_loader: Iterable[Union[np.ndarray, List[np.ndarray]]],
+ eval_callback: Callable[[ort.InferenceSession], float],
+ dummy_input: Dict[str, np.ndarray],
+ results_dir: str,
+ strict_validation: bool,
+ quant_scheme: QuantScheme,
+ param_bw: int,
+ output_bw: int,
+ rounding_mode: str):
+ """
+ Confirms inputs are of the correct type
+ :param model: Model to be quantized
+ :param data_loader: A collection that iterates over an unlabeled dataset, used for computing encodings
+ :param eval_callback: Function that calculates the evaluation score
+ :param dummy_input: Dummy input for the model
+ :param results_dir: Directory to save the results of PTQ techniques
+    :param strict_validation: Flag set to True by default. When False, AutoQuant will proceed with execution and try to handle errors internally if possible. This may produce suboptimal or unintuitive results.
+ :param quant_scheme: Quantization scheme
+ :param param_bw: Parameter bitwidth
+ :param output_bw: Output bitwidth
+ :param rounding_mode: Rounding mode
+ """
+ if not isinstance(model, (onnx.ModelProto, ONNXModel)):
+ raise ValueError('Model must be of type onnx.ModelProto or ONNXModel, not ' + str(type(model).__name__))
+
+ if not isinstance(data_loader, Iterable):
+ raise ValueError('data_loader must be of type Iterable, not ' + str(
+ type(data_loader).__name__))
+
+ if not isinstance(eval_callback, Callable):
+ raise ValueError('eval_callback must be of type Callable, not ' + str(type(eval_callback).__name__))
+
+ if not isinstance(dummy_input, Dict):
+ raise ValueError(
+ 'dummy_input must be of type Dict, not ' + str(type(dummy_input).__name__))
+
+ if not isinstance(results_dir, str):
+ raise ValueError('results_dir must be of type str, not ' + str(type(results_dir).__name__))
+
+ results_dir = os.path.abspath(results_dir)
+ os.makedirs(results_dir, exist_ok=True)
+
+ if not isinstance(strict_validation, bool):
+ raise ValueError('strict_validation must be of type bool, not ' + str(type(strict_validation).__name__))
+
+ validate_quantsim_inputs(quant_scheme, rounding_mode, output_bw, param_bw)
+
+
+class AutoQuant: # pylint: disable=too-many-instance-attributes
+ """
+ Integrate and apply post-training quantization techniques.
+
+ AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization,
+ and 3) Adaround.
+ These techniques will be applied in a best-effort manner until the model
+ meets the evaluation goal given as allowed_accuracy_drop.
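+
+    A minimal usage sketch (`data_loader` and `eval_fn` stand in for
+    user-provided objects; see :meth:`__init__` for their expected forms)::
+
+        auto_quant = AutoQuant(model, dummy_input, data_loader, eval_fn)
+        sim, acc = auto_quant.run_inference()
+        model, accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)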
+ """
+
+ def __init__( # pylint: disable=too-many-arguments, too-many-locals
+ self,
+ model: Union[onnx.ModelProto, ONNXModel],
+ dummy_input: Dict[str, np.ndarray],
+ data_loader: Iterable[Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray]]],
+ eval_callback: Callable[[ort.InferenceSession], float],
+ eval_callback_args=None,
+ param_bw: int = 8,
+ output_bw: int = 8,
+ quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
+ rounding_mode: str = 'nearest',
+ use_cuda: bool = True,
+ device: int = 0,
+ config_file: str = None,
+ results_dir: str = "/tmp",
+ cache_id: str = None,
+ strict_validation: bool = True) -> None:
+ '''
+ :param model: Model to be quantized. Assumes model is on the correct device
+ :param dummy_input: Dummy input for the model. Assumes that dummy_input is on the correct device
+ :param data_loader: A collection that iterates over an unlabeled dataset, used for computing encodings
+ :param eval_callback: Function that calculates the evaluation score given the model session
+ :param eval_callback_args: Extra arguments for eval_callback
+ :param param_bw: Parameter bitwidth
+ :param output_bw: Output bitwidth
+ :param quant_scheme: Quantization scheme
+ :param rounding_mode: Rounding mode
+        :param use_cuda: True if using CUDA to run quantization op. False otherwise.
+        :param device: CUDA device id to run quantization ops on when use_cuda is True
+ :param config_file: Path to configuration file for model quantizers
+ :param results_dir: Directory to save the results of PTQ techniques
+ :param cache_id: ID associated with cache results
+        :param strict_validation: Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.
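+
+        A sketch of the expected eval callback shape (the body is an assumption;
+        any callable taking a session and returning a float score works)::
+
+            def eval_callback(session: ort.InferenceSession, args) -> float:
+                # Run validation data through `session` and return an eval score.
+                ...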
+ '''
+
+ _validate_inputs(model, data_loader, eval_callback, dummy_input, results_dir,
+ strict_validation, quant_scheme, param_bw, output_bw, rounding_mode)
+
+ if not isinstance(model, ONNXModel):
+ model = ONNXModel(model)
+
+ self.fp32_model = model
+ self.dummy_input = dummy_input
+ self.data_loader = data_loader
+ self.eval_callback = eval_callback
+ self.eval_callback_args = eval_callback_args
+
+ self._quantsim_params = dict(
+ param_bw=param_bw,
+ output_bw=output_bw,
+ quant_scheme=_QuantSchemePair(quant_scheme, quant_scheme),
+ rounding_mode=rounding_mode,
+ config_file=config_file,
+ use_cuda=use_cuda,
+ device=device
+ )
+
+ self.results_dir = results_dir
+ self.cache_dir = None
+ if cache_id:
+ self.cache_dir = os.path.join(results_dir, ".auto_quant_cache", cache_id)
+
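+        # Calibration forward pass: run every batch of `data_loader` through the
+        # session and discard the outputs; compute_encodings() only needs the
+        # tensor statistics observed during these runs.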
+ def forward_pass_callback(session, _: Any = None):
+ for input_data in tqdm(data_loader):
+ input_data_dict = utils.create_input_dict(model.model, input_data)
+ _ = session.run(None, input_data_dict)
+
+ self.forward_pass_callback = forward_pass_callback
+
+ # Use at most 2000 samples for AdaRound.
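+        # e.g. with batch_size=32 and 100 batches available:
+        # num_samples = min(100 * 32, 2000) = 2000, num_batches = ceil(2000 / 32) = 63.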
+ input_instance = next(iter(self.data_loader))
+        batch_size = len(input_instance[0]) if isinstance(input_instance, (list, tuple)) else len(input_instance)
+ num_batches = 0
+ for _ in self.data_loader:
+ num_batches += 1
+ num_samples = min(num_batches * batch_size, 2000)
+ num_batches = math.ceil(num_samples / batch_size)
+ self.adaround_params = AdaroundParameters(self.data_loader, num_batches)
+
+ self.eval_manager = _EvalManager(
+ quantsim_factory=self._create_quantsim_and_encodings,
+ eval_func=self._evaluate_model_performance,
+ results_dir=self.results_dir,
+ strict_validation=strict_validation)
+
+ self._quant_scheme_candidates = _QUANT_SCHEME_CANDIDATES
+ self._fp32_acc = None
+
+ def _evaluate_model_performance(self, session) -> float:
+ """
+ Evaluate the model performance.
+ """
+ return self.eval_callback(session, self.eval_callback_args)
+
+ def run_inference(self) -> Tuple[QuantizationSimModel, float]:
+ '''
+        Creates a quantization sim model and performs inference
+
+ :return: QuantizationSimModel, model accuracy as float
+ '''
+ model = self.fp32_model
+
+ # Batchnorm Folding
+ with self.eval_manager.session("Batchnorm Folding - Inference Run") as sess:
+ model, _ = sess.wrap(self._apply_batchnorm_folding)(model)
+ if sess.ptq_result is None:
+ sess.set_ptq_result(model=model,
+ applied_techniques=["batchnorm_folding"])
+
+ sim = self._create_quantsim_and_encodings(model)
+
+ if sess.ptq_result is None:
+ # BN folding failed. Need to measure the eval score
+ acc = self._evaluate_model_performance(sim.session)
+ else:
+ # BN folding success. No need to measure the eval score again
+ acc = sess.ptq_result.accuracy
+
+ return sim, acc
+
+ def optimize(self, allowed_accuracy_drop: float = 0.0) -> Tuple[ONNXModel, float, str]:
+ """
+ Integrate and apply post-training quantization techniques.
+
+ :param allowed_accuracy_drop: Maximum allowed accuracy drop
+ :return: Tuple of (best model, eval score, encoding path)
+ """
+ result = self._optimize_helper(self._optimize_main, allowed_accuracy_drop)
+ return result["model"],\
+ result["accuracy"],\
+ result["encoding_path"]
+
+ def set_adaround_params(self, adaround_params: AdaroundParameters) -> None:
+ """
+ Set Adaround parameters.
+ If this method is not called explicitly by the user, AutoQuant will use
+ `data_loader` (passed to `__init__`) for Adaround.
+
+ :param adaround_params: Adaround parameters.
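+
+        Example (illustrative; the batch count is an assumption)::
+
+            auto_quant.set_adaround_params(AdaroundParameters(data_loader, num_batches=32))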
+ """
+ self.adaround_params = adaround_params
+
+ def _create_quantsim_and_encodings( # pylint: disable=too-many-arguments, too-many-locals, too-many-branches
+ self,
+ fp32_model: ONNXModel,
+ rounding_mode: str = None,
+ output_bw: int = None,
+ output_quant_scheme: QuantScheme = None,
+ output_percentile: float = None,
+ param_bw: int = None,
+ param_quant_scheme: QuantScheme = None,
+ param_percentile: float = None,
+ config_file: str = None,
+ encoding_path: str = None,
+ ) -> QuantizationSimModel:
+ """
+ Create a QuantizationSimModel and compute encoding. If `encoding_path` is not None,
+ it is prioritized over other arguments (`output_bw`, `param_bw`, ...).
+
+ :param fp32_model: Model to quantize.
+ :param rounding_mode: Rounding mode. Defaults to self._quantsim_params["rounding_mode"].
+        :param output_bw: Default bitwidth (4-31) to use for quantizing layer inputs and outputs.
+ Defaults to self._quantsim_params["output_bw"].
+ :param output_quant_scheme: Quantization scheme for output quantizers.
+ Defaults to self._quantsim_params["quant_scheme"].output_quant_scheme.
+ :param output_percentile: Percentile value for outputs.
+ Only valid if output quant scheme is percentile scheme.
+ :param param_bw: Default bitwidth (4-31) to use for quantizing layer parameters.
+ Defaults to self._quantsim_params["param_bw"].
+ :param param_quant_scheme: Quantization scheme for param quantizers.
+ Defaults to self._quantsim_params["quant_scheme"].param_quant_scheme.
+ :param param_percentile: Percentile value for parameters.
+ Only valid if param quant scheme is percentile scheme.
+ :param config_file: Path to configuration file for model quantizers.
+ Defaults to self._quantsim_params["config_file"].
+ :param encoding_path: Path to parameter encodings file.
+ :return: Quantsim model.
+ """
+ if output_bw is not None:
+ assert output_bw <= 32
+
+ if param_bw is not None:
+ assert param_bw <= 32
+
+ if output_quant_scheme is None or param_quant_scheme is None:
+ assert self._quantsim_params["quant_scheme"] is not None
+
+ model = copy.deepcopy(fp32_model)
+ kwargs = dict(
+ rounding_mode=(rounding_mode or self._quantsim_params["rounding_mode"]),
+ default_activation_bw=(output_bw or self._quantsim_params["output_bw"]),
+ default_param_bw=(param_bw or self._quantsim_params["param_bw"]),
+ config_file=(config_file or self._quantsim_params["config_file"]),
+ use_cuda=self._quantsim_params['use_cuda'],
+ device=self._quantsim_params['device']
+ )
+ sim = QuantizationSimModel(model, self.dummy_input, **kwargs)
+
+ param_quantizers, activation_quantizers = sim.get_all_quantizers()
+
+ default_quant_scheme = self._quantsim_params.get("quant_scheme")
+
+ output_quant_scheme = output_quant_scheme or\
+ default_quant_scheme.output_quant_scheme
+ output_percentile = output_percentile or default_quant_scheme.output_percentile
+ param_quant_scheme = param_quant_scheme or\
+ default_quant_scheme.param_quant_scheme
+ param_percentile = param_percentile or default_quant_scheme.param_percentile
+
+ # Set activation quantizers' quant schemes
+ for quantizer in activation_quantizers:
+ quantizer.set_quant_scheme(output_quant_scheme)
+ # TODO: Enable once we figure out how to set percentile value in QcQuantizeOp's Tensor Quantizer
+ # if quantizer.quant_scheme == QuantScheme.post_training_percentile and\
+ # output_percentile is not None:
+ # quantizer.set_percentile_value(output_percentile)
+
+ # Set param quantizers' quant schemes
+ for quantizer in param_quantizers:
+ quantizer.set_quant_scheme(param_quant_scheme)
+ # TODO: Enable once we figure out how to set percentile value in QcQuantizeOp's Tensor Quantizer
+ # if quantizer.quant_scheme == QuantScheme.post_training_percentile and\
+ # param_percentile is not None:
+ # quantizer.set_percentile_value(param_percentile)
+
+ if encoding_path:
+ sim.set_and_freeze_param_encodings(encoding_path)
+
+ # TODO: Other frameworks had this second call to fetch tensor quantizers. Need to check if also required for ONNX.
+ # param_quantizers, activation_quantizers = sim.get_all_quantizers()
+
+ # Disable activation quantizers, using fp32 to simulate int32.
+ if output_bw == 32:
+ for quantizer in activation_quantizers:
+ quantizer.enabled = False
+
+ # Disable param quantizers, using fp32 to simulate int32.
+ if param_bw == 32:
+ for quantizer in param_quantizers:
+ quantizer.enabled = False
+
+ # Skip encoding computation if none of the quantizers are enabled
+        if any(quantizer.enabled
+               for quantizer in param_quantizers + activation_quantizers):
+ sim.compute_encodings(self.forward_pass_callback, None)
+
+ return sim
+
+ @staticmethod
+ @cache.mark("batchnorm_folding")
+    def _apply_batchnorm_folding(model: ONNXModel)\
+            -> Tuple[ONNXModel, List]:
+ """
+ Apply batchnorm folding.
+
+ NOTE: Input model is not mutated.
+
+ :param model: Model to apply batchnorm folding.
+ :return: Output model and folded pairs.
+ """
+ model = copy.deepcopy(model)
+ conv_bns, bn_convs = fold_all_batch_norms_to_weight(model)
+ return model, conv_bns + bn_convs
+
+ @staticmethod
+ @cache.mark("cle")
+    def _apply_cross_layer_equalization(model: ONNXModel) -> ONNXModel:
+ """
+ Apply cross-layer equalization.
+
+ NOTE: Input model is not mutated.
+
+ :param model: Model to apply cross-layer-equalization.
+ :return: Output model.
+ """
+ model = copy.deepcopy(model)
+ equalize_model(model)
+ return model
+
+ @cache.mark("adaround")
+ def _apply_adaround(self, model: ONNXModel) -> Tuple[onnx.ModelProto, str]:
+ """
+ Apply adaround.
+
+ NOTE1: Input model is not mutated.
+ NOTE2: Parameters `param_bw_override_list` and `ignore_quant_ops_list` are always set to None.
+
+ :param model: Model to apply adaround.
+ :return: Output model and the path to the parameter encoding file.
+ """
+ filename_prefix = "adaround"
+ adaround_encoding_path = os.path.join(self.results_dir,
+ "{}.encodings".format(filename_prefix))
+
+ sim = self._create_quantsim_and_encodings(model)
+
+ _, activation_quantizers = sim.get_all_quantizers()
+ for quantizer in activation_quantizers:
+ quantizer.enabled = False
+
+ model = Adaround._apply_adaround(sim, model, self.adaround_params, # pylint: disable=protected-access
+ path=self.results_dir, filename_prefix=filename_prefix)
+
+ return model, adaround_encoding_path
+
+ def _optimize_helper(
+ self,
+ optimize_fn: Callable,
+ allowed_accuracy_drop: float) -> Tuple[ONNXModel, float, str]:
+ """
+ Integrate and apply post-training quantization techniques.
+
+ :param allowed_accuracy_drop: Maximum allowed accuracy drop
+ :return: Tuple of (best model, eval score, encoding path)
+ """
+ allowed_accuracy_drop = float(allowed_accuracy_drop)
+ if allowed_accuracy_drop < 0:
+ raise ValueError(
+ "`allowed_accuracy_drop` must be a positive value. Got {:.2f}"
+ .format(allowed_accuracy_drop)
+ )
+
+ self.eval_manager.clear()
+
+ try:
+ with cache.enable(self.cache_dir):
+ _logger.info("Starting AutoQuant")
+
+ if self._quantsim_params['use_cuda']:
+ providers = [('CUDAExecutionProvider', {'device_id': self._quantsim_params['device']}), 'CPUExecutionProvider']
+ else:
+ providers = ['CPUExecutionProvider']
+ fp32_model_session = QuantizationSimModel.build_session(self.fp32_model.model, providers)
+ self._fp32_acc = self._evaluate_model_performance(fp32_model_session)
+ target_acc = self._fp32_acc - allowed_accuracy_drop
+ _logger.info("Target eval score: %f", target_acc)
+ _logger.info("FP32 eval score (W32A32): %f", self._fp32_acc)
+
+ ret = optimize_fn(self.fp32_model, target_acc)
+
+ acc = ret["accuracy"]
+ if acc is not None:
+ _logger.info("Best eval score: %f", acc)
+
+ if acc < target_acc:
+ _logger.info(
+ "AutoQuant is unable to match the target accuracy. "
+ "Consider Quantization Aware Training."
+ )
+
+ return ret
+ finally:
+ self.eval_manager.export_diagnostics()
+
+ def get_quant_scheme_candidates(self) -> Tuple[_QuantSchemePair, ...]:
+ """
+ Return the candidates for quant scheme search.
+ During :meth:`~AutoQuant.optimize`, the candidate with the highest accuracy
+ will be selected among them.
+
+ :return: Candidates for quant scheme search
+ """
+ return self._quant_scheme_candidates
+
+    def _choose_default_quant_scheme(self):
+        """Choose the quant scheme candidate that yields the best eval score."""
+ def eval_fn(pair: _QuantSchemePair):
+ sim = self._create_quantsim_and_encodings(
+ self.fp32_model,
+ param_quant_scheme=pair.param_quant_scheme,
+ param_percentile=pair.param_percentile,
+ output_quant_scheme=pair.output_quant_scheme,
+ output_percentile=pair.output_percentile,
+ )
+ eval_score = self._evaluate_model_performance(sim.session)
+ _logger.info("Evaluation finished: %s (eval score: %f)", pair, eval_score)
+ return eval_score
+
+ param_bw = self._quantsim_params["param_bw"]
+ output_bw = self._quantsim_params["output_bw"]
+
+ candidates = self.get_quant_scheme_candidates()
+
+ # If the weight representation has sufficient precision (i.e. bitwidth >= 16),
+ # always use tf scheme
+ if param_bw >= 16:
+ candidates = [
+ candidate for candidate in candidates
+ if candidate.param_quant_scheme == QuantScheme.post_training_tf
+ ]
+
+ # If the output representation has sufficient precision (i.e. bitwidth >= 16),
+ # always use tf scheme
+ if output_bw >= 16:
+ candidates = [
+ candidate for candidate in candidates
+ if candidate.output_quant_scheme == QuantScheme.post_training_tf
+ ]
+
+        # If only one candidate is left, there is no need to evaluate
+        # the quant schemes for comparison
+ if len(candidates) == 1:
+ return candidates[0]
+
+ assert candidates
+
+ # Find the quant scheme that yields the best eval score
+ best_quant_scheme = max(candidates, key=eval_fn)
+ _logger.info("Best Quant Scheme: %s", best_quant_scheme)
+
+ return best_quant_scheme
+
+ def _optimize_main(self, fp32_model: ONNXModel, target_acc: float):
+ """
+        Helper function of optimize().
+
+ :param fp32_model: Model to apply PTQ techniques.
+ :param target_acc: Target eval score.
+
+ :raises RuntimeError: If none of the PTQ techniques were finished successfully.
+
+ :return: The best ptq result as a dictionary.
+ """
+
+ # Choose best quant scheme automatically.
+ with self.eval_manager.session("QuantScheme Selection") as sess:
+ self._quantsim_params["quant_scheme"] = sess.wrap(self._choose_default_quant_scheme)()
+
+ # Early exit
+ with self.eval_manager.session(f"W32 Evaluation") as sess:
+ w32_eval_score = sess.wrap(sess.eval)(fp32_model, param_bw=32)
+ _logger.info("Evaluation finished: W32A%d (eval score: %f)",
+ self._quantsim_params["output_bw"], w32_eval_score)
+
+ if w32_eval_score < target_acc:
+ _logger.info(
+ "W32A%d eval score (%f) is lower "
+ "than the target eval score (%f). This means it is unlikely that "
+ "the target eval score can be met using PTQ techniques. "
+ "Please consider finetuning the model using range learning.",
+ self._quantsim_params["output_bw"], w32_eval_score, target_acc
+ )
+
+            # Since the AutoQuant pipeline exited early, all the return values are set to None
+ return {
+ "model": None,
+ "accuracy": None,
+ "encoding_path": None,
+ "applied_techniques": None,
+ }
+
+ sess.result["target_satisfied"] = True
+
+ # Batchnorm Folding
+ with self.eval_manager.session("Batchnorm Folding", ptq=True) as sess:
+ model, _ = sess.wrap(self._apply_batchnorm_folding)(fp32_model)
+ if sess.ptq_result is None:
+ sess.set_ptq_result(model=model,
+ applied_techniques=["batchnorm_folding"])
+
+ best_result = self.eval_manager.get_best_ptq_result()
+ if best_result and best_result.accuracy >= target_acc:
+ sess.result["target_satisfied"] = True
+ return best_result.as_dict()
+
+ # Cross-Layer Equalization
+ with self.eval_manager.session("Cross-Layer Equalization", ptq=True) as sess:
+ model = sess.wrap(self._apply_cross_layer_equalization)(fp32_model)
+ if sess.ptq_result is None:
+ sess.set_ptq_result(model=model,
+ applied_techniques=["cross_layer_equalization"])
+
+ best_result = self.eval_manager.get_best_ptq_result()
+ if best_result and best_result.accuracy >= target_acc:
+ sess.result["target_satisfied"] = True
+ return best_result.as_dict()
+
+ if best_result is None:
+ model = fp32_model
+ applied_techniques = []
+ else:
+ if "cross_layer_equalization" not in best_result.applied_techniques:
+ sess.result["effective"] = False
+ model = best_result.load_model()
+ applied_techniques = best_result.applied_techniques
+
+ # AdaRound
+ with self.eval_manager.session("AdaRound", ptq=True) as sess:
+ model, encoding_path = sess.wrap(self._apply_adaround)(model)
+ if sess.ptq_result is None:
+ sess.set_ptq_result(model=model,
+ encoding_path=encoding_path,
+ applied_techniques=[*applied_techniques, "adaround"])
+
+ best_result = self.eval_manager.get_best_ptq_result()
+ if best_result:
+ if "adaround" not in best_result.applied_techniques:
+ sess.result["effective"] = False
+ if best_result.accuracy >= target_acc:
+ sess.result["target_satisfied"] = True
+ return best_result.as_dict()
+
+ raise RuntimeError("None of Batchnorm Folding, CLE, or Adaround "
+ "has been finished successfully.")
+
+
+@dataclass
+class PtqResult:
+ """
+ Evaluation results.
+    :param model_path: Path to the serialized model.
+    :param encoding_path: Path to the encoding file.
+    :param accuracy: Accuracy of the model.
+    :param applied_techniques: List of PTQ techniques that produced this result.
+ """
+ model_path: str
+ encoding_path: str
+ accuracy: float
+ applied_techniques: List[str]
+
+ def load_model(self) -> ONNXModel:
+ """
+ Load model.
+ :return: Loaded model.
+ """
+ return ONNXModel(onnx.load(self.model_path))
+
+ def as_dict(self):
+ """Convert to dictionary"""
+ return dict(model=self.load_model(),
+ accuracy=self.accuracy,
+ encoding_path=self.encoding_path,
+ applied_techniques=self.applied_techniques)
+
+
+class _EvalManager:
+ """
+ Evaluation manager for AutoQuant.
+ """
+ def __init__(self,
+ quantsim_factory: Callable,
+ eval_func: Callable[[ort.InferenceSession], float],
+ results_dir: str,
+ strict_validation: bool):
+ """
+ :param quantsim_factory: A factory function that returns QuantizationSimModel.
+ :param eval_func: Evaluation function.
+ :param results_dir: Base directory to save the temporary serialized model.
+ """
+ self._quantsim_factory = quantsim_factory
+ self._eval_func = eval_func
+ self._results_dir = results_dir
+ self._strict_validation = strict_validation
+
+ os.makedirs(self._results_dir, exist_ok=True)
+
+ self._all_sessions = OrderedDict() # type: OrderedDict[str, _EvalSession]
+
+ def clear(self):
+ """
+ Clear all the session status saved in the previous run
+ """
+ for sess in self._all_sessions.values():
+ sess.reset_status()
+
+ def get_best_ptq_result(self) -> Optional[PtqResult]:
+ """
+        Get the result with the highest evaluation score among the PTQ results evaluated so far.
+ :return: The best evaluation result so far.
+ """
+ # pylint: disable=protected-access
+ ptq_results = [sess.ptq_result for sess in self._all_sessions.values()
+ if sess.ptq_result is not None and sess._ptq]
+ if not ptq_results:
+ return None
+
+ return max(ptq_results, key=lambda ptq_result: ptq_result.accuracy)
+
+ def session(self, title: str, ptq: bool = False):
+ """
+ Session factory.
+ :param title: Title of the session.
+ :param ptq: True if this session is a ptq session
+ :return: Session object.
+ """
+ if title not in self._all_sessions:
+ session = _EvalSession(title,
+ self._quantsim_factory,
+ self._eval_func,
+ results_dir=os.path.join(self._results_dir, ".trace"),
+ strict_validation=self._strict_validation,
+ ptq=ptq)
+ self._all_sessions[title] = session
+ return self._all_sessions[title]
+
+ HTML_TEMPLATE_FILE = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ "auto_quant_v2_diagnostics_template.html",
+ )
+
+ def export_diagnostics(self) -> str:
+ """
+ Export diagnostics in html format.
+ :return: Diagnostics string in html format.
+ """
+ loader = jinja2.FileSystemLoader(os.path.dirname(self.HTML_TEMPLATE_FILE))
+ env = jinja2.Environment(loader=loader)
+ template = env.get_template(os.path.basename(self.HTML_TEMPLATE_FILE))
+
+ if any(sess.diagnostics.contains_bokeh() for sess in self._all_sessions.values()):
+ from bokeh.resources import CDN
+ head = CDN.render()
+ else:
+ head = ""
+
+ log = io.StringIO()
+ for sess in self._all_sessions.values():
+ if sess.diagnostics.is_empty():
+ continue
+            log.write(
+                f"<h1> {sess.title} </h1>\n"
+            )
+ content = "\n".join(
+ line.get_html_elem() for line in sess.diagnostics
+ )
+ log.write(f"{content}\n")
+
+ result = OrderedDict()
+ result["ptq_techniques"] = OrderedDict()
+
+ for sess in self._all_sessions.values():
+ if sess.is_ptq_session():
+ result["ptq_techniques"][sess.title_lowercase] = sess.result
+ else:
+ result[sess.title_lowercase] = sess.result
+
+ flowchart_metadata = _build_flowchart_metadata(result)
+
+ html = template.render(head=head, log=log.getvalue(), **flowchart_metadata)
+
+ filename = os.path.join(self._results_dir, "diagnostics.html")
+ with open(filename, "w") as f:
+ f.write(html)
+ return html
+
+
+class _EvalSession: # pylint: disable=too-many-instance-attributes
+ """
+ Evaluation session for AutoQuant.
+
+ Each session object contains a title and diagnostics produced during the session.
+ The collected diagnostics will be exported into a html file by _EvalManager.
+ """
+ def __init__(
+ self,
+ title: str,
+ quantsim_factory: Callable,
+ eval_func: Callable[[ort.InferenceSession], float],
+ results_dir: str,
+ strict_validation: bool,
+ ptq: bool,
+ ):
+ """
+ :param title: Title of the session.
+ :param quantsim_factory: A factory function that returns QuantizationSimModel.
+ :param eval_func: Evaluation function.
+ :param results_dir: Base directory to save the temporary serialized model.
+ :param ptq: True if this session is a ptq session
+ """
+ self.title = title
+ self._quantsim_factory = quantsim_factory
+ self._eval_func = eval_func
+ self._results_dir = results_dir
+ self._strict_validation = strict_validation
+ self._ptq = ptq
+
+ self._spinner = None
+
+ self.result = {
+ "status": None,
+ "error": None,
+ "target_satisfied": False,
+ "effective": True,
+ }
+
+ os.makedirs(self._results_dir, exist_ok=True)
+
+ self.diagnostics = Diagnostics()
+
+ # Map session title to file name.
+ # e.g. title: "Cross-Layer Equalization" -> filename: "cross_layer_equalization"
+ self.title_lowercase = self.title.lower().replace("-", " ")
+ self.title_lowercase = "_".join(self.title_lowercase.split())
+
+ stdout_write = sys.stdout.write
+ self._log = io.StringIO()
+
+ # Redirects stdout to self._log
+ def write_wrapper(*args, **kwargs):
+ self._log.write(*args, **kwargs)
+ return stdout_write(*args, **kwargs)
+
+ self._stdout_redirect = patch.object(sys.stdout, "write", write_wrapper)
+ self._ptq_result = None
+ self._cached_result = None
+
+ def is_ptq_session(self):
+ """
+ Getter method of self._ptq flag
+ """
+ return self._ptq
+
+ def reset_status(self):
+ """
+ Reset the session status saved in the previous run
+ """
+ self.result = {
+ "status": None,
+ "error": None,
+ "target_satisfied": False,
+ "effective": True,
+ }
+
+ def wrap(self, fn):
+ """
+ Return a wrapper function that caches the return value.
+
+ :param fn: Function to wrap.
+ :returns: Function whose return value is cached.
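+
+        Example (as used elsewhere in this module)::
+
+            model, _ = sess.wrap(self._apply_batchnorm_folding)(model)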
+ """
+ import pickle
+ from uuid import uuid4
+
+ results_dir = self._results_dir
+ class CachedResult:
+ """Cached result """
+ def __init__(self, obj):
+ self._filename = os.path.join(results_dir, f".{uuid4()}")
+ while os.path.exists(self._filename):
+ self._filename = os.path.join(results_dir, f".{uuid4()}")
+ with open(self._filename, "wb") as f:
+ pickle.dump(obj, f)
+
+ def load(self):
+ """Load cached result """
+ with open(self._filename, "rb") as f:
+ return pickle.load(f)
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ if self._cached_result:
+ return self._cached_result.load()
+ ret = fn(*args, **kwargs)
+ self._cached_result = CachedResult(ret)
+ return ret
+ return wrapper
+
+ def eval(self, model: ONNXModel, **kwargs):
+ """
+ Evaluate the model.
+ :param model: Model to evaluate.
+ :param **kwargs: Additional arguments to the quantsim factory.
+ :return: Eval score.
+ """
+ sim = self._quantsim_factory(model, **kwargs)
+ acc = self._eval_func(sim.session)
+ return acc
+
+ def __enter__(self):
+ self._spinner = Spinner(self.title)
+ self._spinner.__enter__()
+ self._stdout_redirect.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._ptq_result is not None:
+ _logger.info("Session finished: %s. (eval score: %f)",
+ self.title, self._ptq_result.accuracy)
+
+ self._spinner.__exit__(exc_type, exc_val, exc_tb)
+
+ if exc_val:
+ buffer = io.StringIO()
+ traceback.print_exception(exc_type, exc_val, exc_tb, file=buffer)
+
+ if self._strict_validation:
+ print(buffer.getvalue())
+ else:
+ print(
+ "################################################################\n"
+ "################################################################\n"
+ "################################################################\n"
+ "WARNING: The following exception was raised but ignored:\n\n"
+ f"{buffer.getvalue()}"
+ "################################################################\n"
+ "################################################################\n"
+ "################################################################\n"
+ )
+
+ self._stdout_redirect.stop()
+ self.diagnostics.add(self._log.getvalue())
+
+ self.result["error"] = exc_val
+ if not exc_val:
+ self.result["status"] = "success"
+ elif self._strict_validation:
+ self.result["status"] = "error-failed"
+ else:
+ self.result["status"] = "error-ignored"
+
+ if exc_val and not self._strict_validation:
+ # Return True so that the error doesn't propagate further
+ return True
+ return None
+
+ @property
+ def ptq_result(self) -> Optional[PtqResult]:
+ """Getter of self._ptq_result."""
+ return self._ptq_result
+
+ def set_ptq_result(
+ self,
+ applied_techniques: List[str],
+ model: onnx.ModelProto = None,
+ sim: QuantizationSimModel = None,
+ acc: float = None,
+ **kwargs
+ ) -> None:
+ """
+ Set the result of PTQ. Should be called exactly once inside a with-as block.
+
+        Exactly one of `model` or the (`sim`, `acc`) pair must be specified.
+        1) If `sim` and `acc` are specified, save them as the result of this session.
+        2) If `model` is specified, evaluate the quantized accuracy of the model and save the result.
+
+ :param model: Result of PTQ.
+        :param sim: Result of PTQ. The quantization encoding (compute_encodings()) is
+ assumed to have been computed in advance.
+ :param acc: Eval score.
+ :param **kwargs: Additional arguments to the quantsim factory.
+ :return: None
+ """
+
+ if sim is None:
+ assert acc is None
+ assert model is not None
+ sim = self._quantsim_factory(model, **kwargs)
+ acc = self._eval_func(sim.session)
+ else:
+ assert acc is not None
+ assert model is None
+
+ self._set_ptq_result(sim, acc, applied_techniques)
+
+ def _set_ptq_result(
+ self,
+ sim: QuantizationSimModel,
+ acc: float,
+ applied_techniques: List[str],
+ ) -> PtqResult:
+ """
+ Set the result of PTQ. Should be called exactly once inside a with-as block.
+
+        :param sim: Result of PTQ. The quantization encoding (compute_encodings()) is
+ assumed to have been computed in advance.
+ :param acc: Eval score.
+ :return: PtqResult object.
+ """
+ if self._ptq_result is not None:
+ raise RuntimeError(
+ "sess.eval() can be called only once per each _EvalSession instance."
+ )
+
+ model_path, encoding_path = self._export(sim)
+ self._ptq_result = PtqResult(
+ model_path=model_path,
+ encoding_path=encoding_path,
+ accuracy=acc,
+ applied_techniques=applied_techniques,
+ )
+ return self._ptq_result
+
+ def _export(self, sim: QuantizationSimModel) -> Tuple[str, str]:
+ """
+ Export quantsim.
+ :param sim: QuantizationSimModel object to export.
+ :return: The paths where model and encoding are saved
+ """
+ sim.export(path=self._results_dir,
+ filename_prefix=self.title_lowercase)
+ model_path = os.path.join(self._results_dir, f"{self.title_lowercase}.onnx")
+ encoding_path = os.path.join(self._results_dir, f"{self.title_lowercase}.encodings")
+ _logger.info("The results of %s is saved in %s and %s.",
+ self.title, model_path, encoding_path)
+ return model_path, encoding_path
+
+
+def _build_flowchart_metadata(result: Mapping) -> Dict: # pylint: disable=too-many-return-statements
+ """
+ Build flowchart metadata for the html template of summary report
+
+ :param result: Result of AutoQuant with the following format:
+
+        result := {
+            "quantscheme_selection": _stage_result,
+            "w32_evaluation": _stage_result,
+            "ptq_techniques": {
+                "batchnorm_folding": _stage_result,
+                "cross_layer_equalization": _stage_result,
+                "adaround": _stage_result,
+            },
+        }
+
+ where _stage_result is a dictionary defined as below:
+
+ _stage_result := {
+ "status": str,
+ "error": Exception,
+ "target_satisfied": bool,
+ "effective": bool,
+ }
+
+ :return: Dictionary that contains flowchart metadata for html template
+ """
+ metadata = defaultdict(str)
+ metadata.update(
+ edge_quant_scheme_selection_in='data-visited="true"',
+ )
+ if "quantscheme_selection" in result:
+ status = result['quantscheme_selection']['status']
+ metadata.update(
+ node_quant_scheme_selection=f'data-visited="true" data-stage-result="{status}"',
+ )
+
+ if status == 'error-failed':
+ return metadata
+
+ metadata.update(
+ edge_quant_scheme_selection_out='data-visited="true"',
+ node_test_w32_eval_score='data-visited="true"',
+ )
+
+ if not result["w32_evaluation"]["target_satisfied"]:
+ metadata.update(
+ edge_test_w32_eval_score_if_false='data-visited="true"',
+ node_result_fail='data-visited="true"',
+ )
+ return metadata
+
+ metadata.update(
+ edge_test_w32_eval_score_if_true='data-visited="true"',
+ )
+
+ for ptq_name, ptq_result in result["ptq_techniques"].items():
+ status = ptq_result['status']
+ effective = ptq_result['effective']
+ if status == "success" and not effective:
+ status = "discarded"
+ metadata.update({
+ f"node_{ptq_name}": f'data-visited="true" data-stage-result="{status}"',
+ })
+
+ if status == 'error-failed':
+ return metadata
+
+ metadata.update({
+ f'edge_{ptq_name}_out': 'data-visited="true"',
+ f'node_test_{ptq_name}': 'data-visited="true"',
+ })
+
+ if ptq_result['target_satisfied']:
+ metadata.update({
+ f'edge_test_{ptq_name}_if_true': 'data-visited="true"',
+ 'node_result_success': 'data-visited="true"',
+ })
+ return metadata
+
+ metadata.update({
+ f'edge_test_{ptq_name}_if_false': 'data-visited="true"',
+ })
+
+ metadata.update(
+ node_result_fail='data-visited="true"',
+ )
+
+ return metadata
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2_diagnostics_template.html b/TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2_diagnostics_template.html
new file mode 100644
index 00000000000..5ff95564823
--- /dev/null
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2_diagnostics_template.html
@@ -0,0 +1,331 @@
+<!DOCTYPE html>
+<html>
+<head>
+    {{ head }}
+</head>
+<body>
+    <!-- Flowchart and styling markup omitted here; only the template
+         skeleton around the Jinja variables is shown. -->
+    <div class="log">
+        {{ log }}
+    </div>
+</body>
+</html>
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py b/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
index a4edc0de38d..df949fe32a0 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/batch_norm_fold.py
@@ -41,6 +41,7 @@
import numpy as np
import onnx
from onnx import numpy_helper
+from onnxruntime.quantization.onnx_quantizer import ONNXModel
from packaging import version
from aimet_common.bias_correction import ConvBnPatternHandler
@@ -208,6 +209,8 @@ def fold_all_batch_norms_to_weight(model: ModelProto) -> [List]:
:param model: onnx Model to perform BN fold on
:return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
"""
+ if isinstance(model, ONNXModel):
+ model = model.model
connected_graph = ConnectedGraph(model)
model = connected_graph.model
conv_bn_pairs, bn_conv_pairs = find_all_batch_norms_to_fold(connected_graph)
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
index 4c43f73e884..6da3783cc07 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -37,7 +37,7 @@
""" Implementation for simulating models running on Quantized hardware """
import os
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Tuple
import json
import numpy as np
import onnx
@@ -542,6 +542,21 @@ def _create_libpymo_encodings(encoding):
self.qc_quantize_op_dict[quantizer_name].use_symmetric_encodings = is_symmetric
self.qc_quantize_op_dict[quantizer_name].freeze_encodings()
+ def get_all_quantizers(self) -> Tuple[List, List]:
+ """
+        Returns all QcQuantizeOps through which TensorQuantizer's attributes can be accessed.
+
+        :return: Tuple of (param quantizers, activation quantizers)
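+
+        A sketch of typical use (mirroring how AutoQuant disables activation
+        quantizers)::
+
+            param_quantizers, activation_quantizers = sim.get_all_quantizers()
+            for quantizer in activation_quantizers:
+                quantizer.enabled = False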
+ """
+ param_quantizers = []
+ activation_quantizers = []
+
+ for param in self.param_names:
+ param_quantizers.append(self.qc_quantize_op_dict[param])
+
+ for activation in self.activation_names:
+ activation_quantizers.append(self.qc_quantize_op_dict[activation])
+
+ return param_quantizers, activation_quantizers
+
def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str):
"""
diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
index 6d10d63b598..7944cf36fae 100644
--- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py
+++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -1760,3 +1760,35 @@ def forward(self, x):
custom_opsets={"my_ops": 2})
model_onnx = ONNXModel(load_model('./simple_custom_model.onnx'))
return model_onnx
+
+
+def conv_relu_model():
+ class ConvReluModel(torch.nn.Module):
+ def __init__(self):
+ super(ConvReluModel, self).__init__()
+ self._conv_0 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3), padding=1)
+ self._relu = torch.nn.ReLU()
+
+ def forward(self, x: torch.Tensor):
+ return self._relu(self._conv_0(x))
+
+ torch.manual_seed(10)
+ model = ConvReluModel().eval()
+ x = torch.randn((1, 3, 8, 8))
+
+ torch.onnx.export(model, # model being run
+ x, # model input (or a tuple for multiple inputs)
+ "./conv_relu.onnx", # where to save the model (can be a file or file-like object),
+ training=torch.onnx.TrainingMode.EVAL,
+ export_params=True, # store the trained parameter weights inside the model file
+ opset_version=12, # the ONNX version to export the model to
+ do_constant_folding=False, # whether to execute constant folding for optimization
+ input_names=['input'], # the model's input names
+ output_names=['output'],
+ dynamic_axes={
+ 'input': {0: 'batch_size'},
+ 'output': {0: 'batch_size'},
+ })
+
+ model = load_model('./conv_relu.onnx')
+ return model
diff --git a/TrainingExtensions/onnx/test/python/test_auto_quant_v2.py b/TrainingExtensions/onnx/test/python/test_auto_quant_v2.py
new file mode 100644
index 00000000000..baa90f90ae0
--- /dev/null
+++ b/TrainingExtensions/onnx/test/python/test_auto_quant_v2.py
@@ -0,0 +1,757 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+
+import contextlib
+from dataclasses import dataclass
+import itertools
+from unittest.mock import patch, MagicMock
+import os
+from bs4 import BeautifulSoup
+import pytest
+import shutil
+from typing import Callable, Dict, List
+
+import onnx
+from onnxruntime.quantization.onnx_quantizer import ONNXModel
+import numpy as np
+
+from aimet_onnx.quantsim import QuantizationSimModel
+from aimet_onnx.adaround.adaround_weight import AdaroundParameters
+from aimet_onnx.auto_quant_v2 import AutoQuant, PtqResult
+
+from aimet_common.defs import QuantScheme, QuantizationDataType
+
+from models.models_for_tests import conv_relu_model
+
+
+@pytest.fixture(scope="function")
+def onnx_model():
+ model = ONNXModel(conv_relu_model())
+ setattr(model, 'applied_bn_folding', False)
+ setattr(model, 'applied_cle', False)
+ setattr(model, 'applied_adaround', False)
+ return model
+
+
+@pytest.fixture(scope="session")
+def dummy_input():
+ return {'input': np.random.randn(1, 3, 8, 8).astype(np.float32)}
+
+
+@pytest.fixture(scope="session")
+def unlabeled_data_loader():
+ data_loader = [np.random.randn(1, 3, 8, 8) for _ in range(10)]
+ return data_loader
+
+
+def assert_html(html_parsed, properties):
+ for id_, prop in properties.items():
+ elem = html_parsed.find(id=id_)
+ assert elem is not None
+ for prop_name, prop_val in prop.items():
+ if prop_val is None:
+ assert prop_name not in elem.attrs
+ else:
+ assert elem[prop_name] == prop_val
+
+
+_VISITED = { 'data-visited': 'true', }
+_NOT_VISITED = { 'data-visited': None, }
+_SUCCESS = {
+ 'data-visited': 'true',
+ 'data-stage-result': 'success'
+}
+_DISCARDED = {
+ 'data-visited': 'true',
+ 'data-stage-result': 'discarded'
+}
+_ERROR_IGNORED = {
+ 'data-visited': 'true',
+ 'data-stage-result': 'error-ignored'
+}
+_ERROR_FAILED = {
+ 'data-visited': 'true',
+ 'data-stage-result': 'error-failed'
+}
+
+
+def assert_applied_techniques(
+ output_model, acc, encoding_path,
+ target_acc, bn_folded_acc, cle_acc, adaround_acc,
+ results_dir,
+):
+ html_path = os.path.join(results_dir, 'diagnostics.html')
+ with open(html_path) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+
+ # Batchnorm folding is always applied.
+ assert output_model.applied_bn_folding
+ assert_html(html_parsed, {
+ 'node_batchnorm_folding': _SUCCESS,
+ 'node_test_batchnorm_folding': _VISITED,
+ })
+
+ # If accuracy is good enough after batchnorm folding
+ if bn_folded_acc >= target_acc:
+ assert acc == bn_folded_acc
+ assert encoding_path.endswith("batchnorm_folding.encodings")
+ assert not output_model.applied_cle
+ assert not output_model.applied_adaround
+
+ assert_html(html_parsed, {
+ 'node_cross_layer_equalization': _NOT_VISITED,
+ 'node_test_cross_layer_equalization': _NOT_VISITED,
+ 'node_adaround': _NOT_VISITED,
+ 'node_test_adaround': _NOT_VISITED,
+ 'node_result_fail': _NOT_VISITED,
+ 'node_result_success': _VISITED,
+ })
+ return
+
+ # CLE should be applied if and only if it brings accuracy gain
+ assert output_model.applied_cle == (bn_folded_acc < cle_acc)
+
+ assert_html(html_parsed, {
+ 'node_cross_layer_equalization': _SUCCESS if output_model.applied_cle else _DISCARDED,
+ 'node_test_cross_layer_equalization': _VISITED,
+ })
+
+ # If accuracy is good enough after cle
+ if cle_acc >= target_acc:
+ assert acc == cle_acc
+ assert encoding_path.endswith("cross_layer_equalization.encodings")
+ assert output_model.applied_cle
+ assert not output_model.applied_adaround
+
+ assert_html(html_parsed, {
+ 'node_adaround': _NOT_VISITED,
+ 'node_test_adaround': _NOT_VISITED,
+ 'node_result_fail': _NOT_VISITED,
+ 'node_result_success': _VISITED,
+ })
+ return
+
+ assert output_model.applied_adaround == (adaround_acc >= max(bn_folded_acc, cle_acc))
+
+ assert_html(html_parsed, {
+ 'node_adaround': _SUCCESS if output_model.applied_adaround else _DISCARDED,
+ 'node_test_adaround': _VISITED,
+ })
+
+ # If accuracy is good enough after adaround
+ if adaround_acc >= target_acc:
+ assert acc == adaround_acc
+ assert encoding_path.endswith("adaround.encodings")
+ assert output_model.applied_adaround
+
+ assert_html(html_parsed, {
+ 'node_result_fail': _NOT_VISITED,
+ 'node_result_success': _VISITED,
+ })
+ return
+
+ assert_html(html_parsed, {
+ 'node_result_fail': _VISITED,
+ 'node_result_success': _NOT_VISITED,
+ })
+
+ assert acc == max(bn_folded_acc, cle_acc, adaround_acc)
+
+ if max(bn_folded_acc, cle_acc, adaround_acc) == bn_folded_acc:
+ assert encoding_path.endswith("batchnorm_folding.encodings")
+ elif max(bn_folded_acc, cle_acc, adaround_acc) == cle_acc:
+ assert encoding_path.endswith("cross_layer_equalization.encodings")
+ else:
+ assert encoding_path.endswith("adaround.encodings")
+
+
+FP32_ACC = 0.8
+W32_ACC = FP32_ACC # Assume W32 accuracy is equal to FP32 accuracy
+RAW_QUANTSIM_ACC = 0.1
+
+
+@contextlib.contextmanager
+def patch_ptq_techniques(bn_folded_acc, cle_acc, adaround_acc, fp32_acc=None, w32_acc=None, raw_quantsim_acc=None):
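+    # Patches BN folding, CLE, AdaRound, and QuantizationSimModel with mocks:
+    # each PTQ technique "applies" by flipping a flag on the model, and
+    # mock_eval_callback returns the accuracy configured for the deepest
+    # technique applied so far.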
+ if fp32_acc is None:
+ fp32_acc = FP32_ACC
+
+ if w32_acc is None:
+ w32_acc = W32_ACC
+
+ if raw_quantsim_acc is None:
+ raw_quantsim_acc = RAW_QUANTSIM_ACC
+
+ def bn_folding(model, *_, **__):
+ setattr(model, 'applied_bn_folding', True)
+ return [], []
+
+ def cle(model, *_, **__):
+ setattr(model, 'applied_bn_folding', True)
+ setattr(model, 'applied_cle', True)
+
+ def adaround(sim, model, *_, **__):
+ setattr(model, 'applied_adaround', True)
+ return model
+
+ class _PtqResult(PtqResult):
+
+ def load_model(self) -> ONNXModel:
+ model = super().load_model()
+            bnf_val = "batchnorm_folding" in self.applied_techniques
+            cle_val = "cross_layer_equalization" in self.applied_techniques
+            if cle_val:
+                bnf_val = True
+            ada_val = "adaround" in self.applied_techniques
+            setattr(model, "applied_bn_folding", bnf_val)
+            setattr(model, "applied_cle", cle_val)
+            setattr(model, "applied_adaround", ada_val)
+ return model
+
+ class _QuantizationSimModel(QuantizationSimModel):
+ def __init__(self,
+ model: onnx.ModelProto,
+ dummy_input: Dict[str, np.ndarray] = None,
+ quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
+ rounding_mode: str = 'nearest',
+ default_param_bw: int = 8,
+ default_activation_bw: int = 8,
+ use_symmetric_encodings: bool = False, use_cuda: bool = True,
+ device: int = 0, config_file: str = None, default_data_type: QuantizationDataType = QuantizationDataType.int,
+ user_onnx_libs: List[str] = None):
+ super(_QuantizationSimModel, self).__init__(model, dummy_input, quant_scheme, rounding_mode, default_param_bw, default_activation_bw,
+ use_symmetric_encodings, use_cuda, device, config_file, default_data_type, user_onnx_libs)
+
+ self.session = {'applied_bn_folding': getattr(model, 'applied_bn_folding'),
+ 'applied_cle': getattr(model, 'applied_cle'),
+ 'applied_adaround': getattr(model, 'applied_adaround'),
+                            'qsim_w32': default_param_bw == 32}
+
+ def compute_encodings(self, *_):
+ pass
+
+ def set_and_freeze_param_encodings(self, _):
+ pass
+
+ def mock_eval_callback(session, _):
+ if isinstance(session, MagicMock):
+ # Not quantized: return fp32 accuracy
+ return fp32_acc
+ if session['qsim_w32']:
+ # W32 evaluation for early exit. Return W32 accuracy
+ return w32_acc
+
+ acc = raw_quantsim_acc
+ if session['applied_bn_folding']:
+ acc = bn_folded_acc
+ if session['applied_cle']:
+ acc = cle_acc
+ if session['applied_adaround']:
+ acc = adaround_acc
+ return acc
+
+ @dataclass
+ class Mocks:
+ eval_callback: Callable
+ QuantizationSimModel: MagicMock
+ fold_all_batch_norms: MagicMock
+ equalize_model: MagicMock
+ apply_adaround: MagicMock
+ PtqResult: MagicMock
+
+ with patch("aimet_onnx.auto_quant_v2.QuantizationSimModel", side_effect=_QuantizationSimModel) as mock_qsim,\
+ patch("aimet_onnx.auto_quant_v2.fold_all_batch_norms_to_weight", side_effect=bn_folding) as mock_bn_folding,\
+ patch("aimet_onnx.auto_quant_v2.equalize_model", side_effect=cle) as mock_cle,\
+ patch("aimet_onnx.auto_quant_v2.Adaround._apply_adaround", side_effect=adaround) as mock_adaround, \
+ patch("aimet_onnx.auto_quant_v2.PtqResult", side_effect=_PtqResult) as mock_ptq:
+ try:
+ yield Mocks(
+ eval_callback=mock_eval_callback,
+ QuantizationSimModel=mock_qsim,
+ fold_all_batch_norms=mock_bn_folding,
+ equalize_model=mock_cle,
+ apply_adaround=mock_adaround,
+ PtqResult=mock_ptq
+ )
+ finally:
+ pass
+
+
+class TestAutoQuant:
+ def test_auto_quant_run_inference(self, onnx_model, dummy_input, unlabeled_data_loader):
+ bn_folded_acc = .5
+
+ with patch_ptq_techniques(
+ bn_folded_acc, None, None
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ eval_callback_args=None,
+ results_dir=results_dir)
+ auto_quant.run_inference()
+
+ @pytest.mark.parametrize(
+ "bn_folded_acc, cle_acc, adaround_acc",
+ itertools.permutations([.5, .6, .7])
+ )
+ @pytest.mark.parametrize("allowed_accuracy_drop", [.05, .15])
+ def test_auto_quant(
+ self, onnx_model, dummy_input, unlabeled_data_loader,
+ allowed_accuracy_drop, bn_folded_acc, cle_acc, adaround_acc,
+ ):
+ self._test_auto_quant(
+ onnx_model, dummy_input, unlabeled_data_loader,
+ allowed_accuracy_drop, bn_folded_acc, cle_acc, adaround_acc,
+ )
+
+ def test_consecutive_calls(self, onnx_model, dummy_input, unlabeled_data_loader):
+ bn_folded_acc, cle_acc, adaround_acc = .5, .6, .7
+
+ with patch_ptq_techniques(
+ bn_folded_acc, cle_acc, adaround_acc
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ eval_callback_args=None,
+ results_dir=results_dir)
+
+ # Should return proper model & summary report
+ # regardless of consecutive calls
+ for allowed_accuracy_drop in (.5, .4, .3, .2, .1, .05):
+ self._do_test_optimize_auto_quant(
+ auto_quant, onnx_model,
+ allowed_accuracy_drop, bn_folded_acc, cle_acc, adaround_acc
+ )
+
+ with patch_ptq_techniques(
+ bn_folded_acc, cle_acc, adaround_acc
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ eval_callback_args=None,
+ results_dir=results_dir)
+
+                # When run_inference() and optimize() are called back-to-back,
+                # reusable intermediate results should always be reused.
+ auto_quant.run_inference()
+ auto_quant.optimize()
+ assert mocks.fold_all_batch_norms.call_count == 2
+ assert mocks.equalize_model.call_count == 1
+
+ auto_quant.optimize()
+ assert mocks.fold_all_batch_norms.call_count == 2
+ assert mocks.equalize_model.call_count == 1
+
+ self._do_test_optimize_auto_quant(
+ auto_quant, onnx_model,
+ 0.0, bn_folded_acc, cle_acc, adaround_acc
+ )
+ assert mocks.fold_all_batch_norms.call_count == 2
+ assert mocks.equalize_model.call_count == 1
+
+ def _test_auto_quant(
+ self, model, dummy_input, unlabeled_data_loader,
+ allowed_accuracy_drop, bn_folded_acc, cle_acc, adaround_acc,
+ ):
+ with patch_ptq_techniques(
+ bn_folded_acc, cle_acc, adaround_acc
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+ auto_quant = AutoQuant(model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ eval_callback_args=None,
+ results_dir=results_dir)
+ self._do_test_optimize_auto_quant(
+ auto_quant, model, allowed_accuracy_drop,
+ bn_folded_acc, cle_acc, adaround_acc
+ )
+
+ def _do_test_optimize_auto_quant(
+ self, auto_quant, input_model,
+ allowed_accuracy_drop, bn_folded_acc, cle_acc, adaround_acc,
+ ):
+ target_acc = FP32_ACC - allowed_accuracy_drop
+
+ output_model, acc, encoding_path = auto_quant.optimize(allowed_accuracy_drop)
+
+ assert_applied_techniques(
+ output_model, acc, encoding_path,
+ target_acc, bn_folded_acc, cle_acc, adaround_acc,
+ auto_quant.results_dir,
+ )
+
+ def test_auto_quant_invalid_input(self, onnx_model, dummy_input, unlabeled_data_loader):
+ with pytest.raises(ValueError):
+ AutoQuant(None, dummy_input, unlabeled_data_loader, lambda: None)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, None, unlabeled_data_loader, lambda: None)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, None, lambda: None)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, None)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None, results_dir=None)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None, strict_validation=None)
+
+ # Bitwidth < 4 or bitwidth > 32
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None, param_bw=2)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None, param_bw=64)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None, output_bw=2)
+
+ with pytest.raises(ValueError):
+ AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None, output_bw=64)
+
+ auto_quant = AutoQuant(onnx_model, dummy_input, unlabeled_data_loader, lambda: None)
+ # Allowed accuracy drop < 0
+ with pytest.raises(ValueError):
+ _ = auto_quant.optimize(-1.0)
+
+ def test_auto_quant_inference_fallback(
+ self, onnx_model, dummy_input, unlabeled_data_loader,
+ ):
+ class _Exception(Exception):
+ pass
+
+ def error_fn(*_, **__):
+ raise _Exception
+
+ bn_folded_acc = .4
+ raw_quantsim_acc = bn_folded_acc + 1e-5
+ with patch_ptq_techniques(
+ bn_folded_acc, None, None, raw_quantsim_acc=raw_quantsim_acc
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ strict_validation=False)
+ with patch("aimet_onnx.auto_quant_v2.fold_all_batch_norms_to_weight", side_effect=error_fn):
+ # If BN folding fails, should return the raw quantsim model
+ _, acc = auto_quant.run_inference()
+ assert acc == raw_quantsim_acc
+
+ def test_auto_quant_optimize_fallback(
+ self, onnx_model, dummy_input, unlabeled_data_loader,
+ ):
+ class _Exception(Exception):
+ pass
+
+ def error_fn(*_, **__):
+ raise _Exception
+
+ bn_folded_acc, cle_acc, adaround_acc = .4, .5, .6
+ with patch_ptq_techniques(
+ bn_folded_acc, cle_acc, adaround_acc
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ strict_validation=False)
+ with patch("aimet_onnx.auto_quant_v2.fold_all_batch_norms_to_weight", side_effect=error_fn):
+ # If batchnorm folding fails, should return Adaround results
+ _, acc, _ = auto_quant.optimize()
+ assert acc == adaround_acc
+
+ with open(os.path.join(results_dir, 'diagnostics.html')) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+ assert_html(html_parsed, {
+ 'node_batchnorm_folding': _ERROR_IGNORED,
+ 'node_cross_layer_equalization': _SUCCESS,
+ 'node_adaround': _SUCCESS,
+ })
+
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ strict_validation=False)
+ with patch("aimet_onnx.auto_quant_v2.equalize_model", side_effect=error_fn):
+ # If CLE fails, should return Adaround results
+ _, acc, _ = auto_quant.optimize()
+ assert acc == adaround_acc
+
+ with open(os.path.join(results_dir, 'diagnostics.html')) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+ assert_html(html_parsed, {
+ 'node_batchnorm_folding': _SUCCESS,
+ 'node_cross_layer_equalization': _ERROR_IGNORED,
+ 'node_adaround': _SUCCESS,
+ })
+
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ strict_validation=False)
+ with patch("aimet_onnx.auto_quant_v2.Adaround._apply_adaround", side_effect=error_fn):
+ # If adaround fails, should return CLE results
+ _, acc, _ = auto_quant.optimize()
+ assert acc == cle_acc
+
+ with open(os.path.join(results_dir, 'diagnostics.html')) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+ assert_html(html_parsed, {
+ 'node_batchnorm_folding': _SUCCESS,
+ 'node_cross_layer_equalization': _SUCCESS,
+ 'node_adaround': _ERROR_IGNORED,
+ })
+
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ strict_validation=False)
+ with patch("aimet_onnx.auto_quant_v2.fold_all_batch_norms_to_weight", side_effect=error_fn),\
+ patch("aimet_onnx.auto_quant_v2.equalize_model", side_effect=error_fn),\
+ patch("aimet_onnx.auto_quant_v2.Adaround._apply_adaround", side_effect=error_fn):
+ # If everything fails, should raise an error
+ with pytest.raises(RuntimeError):
+ auto_quant.optimize()
+
+ with open(os.path.join(results_dir, 'diagnostics.html')) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+ assert_html(html_parsed, {
+ 'node_batchnorm_folding': _ERROR_IGNORED,
+ 'node_cross_layer_equalization': _ERROR_IGNORED,
+ 'node_adaround': _ERROR_IGNORED,
+ })
+
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ strict_validation=True)
+ with patch("aimet_onnx.auto_quant_v2.equalize_model", side_effect=error_fn):
+ # Hard stop if strict_validation=True
+ with pytest.raises(_Exception):
+ auto_quant.optimize()
+
+ with open(os.path.join(results_dir, 'diagnostics.html')) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+ assert_html(html_parsed, {
+ 'node_batchnorm_folding': _SUCCESS,
+ 'node_cross_layer_equalization': _ERROR_FAILED,
+ 'node_adaround': _NOT_VISITED,
+ })
+
+ def test_auto_quant_early_exit(self, onnx_model, dummy_input, unlabeled_data_loader):
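+ # If even the W32 evaluation score misses the target accuracy,
+ # optimize() should exit early and return (None, None, None).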
+ allowed_accuracy_drop = 0.1
+ w32_acc = FP32_ACC - (allowed_accuracy_drop * 2)
+
+ with create_tmp_directory() as results_dir:
+ with patch_ptq_techniques(
+ bn_folded_acc=0, cle_acc=0, adaround_acc=0, w32_acc=w32_acc
+ ) as mocks:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir)
+ output_model, acc, encoding_path = auto_quant.optimize(allowed_accuracy_drop)
+
+ assert output_model is None
+ assert acc is None
+ assert encoding_path is None
+
+ with open(os.path.join(results_dir, 'diagnostics.html')) as f:
+ html_parsed = BeautifulSoup(f.read(), features="html.parser")
+ assert_html(html_parsed, {
+ 'node_test_w32_eval_score': _VISITED,
+ 'node_batchnorm_folding': _NOT_VISITED,
+ 'node_cross_layer_equalization': _NOT_VISITED,
+ 'node_adaround': _NOT_VISITED,
+ 'node_result_fail': _VISITED,
+ })
+
+ def test_auto_quant_caching(
+ self, onnx_model, dummy_input, unlabeled_data_loader,
+ ):
+ allowed_accuracy_drop = 0.0
+ bn_folded_acc, cle_acc, adaround_acc = .4, .5, .6
+ cache_id = "unittest"
+
+ with patch_ptq_techniques(
+ bn_folded_acc, cle_acc, adaround_acc
+ ) as mocks:
+ with create_tmp_directory() as results_dir:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ cache_id=cache_id)
+
+ cache_files = [
+ os.path.join(results_dir, ".auto_quant_cache", cache_id, f"{key}.pkl")
+ for key in ("batchnorm_folding", "cle", "adaround")
+ ]
+
+ # No previously cached results
+ auto_quant.optimize(allowed_accuracy_drop)
+
+ for cache_file in cache_files:
+ assert os.path.exists(cache_file)
+
+ assert mocks.fold_all_batch_norms.call_count == 1
+ assert mocks.equalize_model.call_count == 1
+ assert mocks.apply_adaround.call_count == 1
+
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback,
+ results_dir=results_dir,
+ cache_id=cache_id)
+ # Load cached result
+ auto_quant.optimize(allowed_accuracy_drop)
+
+ # PTQ functions should not be called twice.
+ assert mocks.fold_all_batch_norms.call_count == 1
+ assert mocks.equalize_model.call_count == 1
+ assert mocks.apply_adaround.call_count == 1
+
+ def test_auto_quant_scheme_selection(
+ self, onnx_model, dummy_input, unlabeled_data_loader,
+ ):
+ allowed_accuracy_drop = 0.0
+ bn_folded_acc, cle_acc, adaround_acc = .4, .5, .6
+ with patch_ptq_techniques(
+ bn_folded_acc, cle_acc, adaround_acc
+ ) as mocks:
+ def eval_callback(session, _):
+ # Simulate a model whose eval score drops to zero
+ # unless param_quant_scheme == tfe and output_quant_scheme == tf
+ if isinstance(session, MagicMock):
+ return mocks.eval_callback(session, _)
+ if 'param_quant_scheme' in session and session['param_quant_scheme'] != QuantScheme.post_training_tf_enhanced:
+ return 0.0
+ if 'output_quant_scheme' in session and session['output_quant_scheme'] != QuantScheme.post_training_tf:
+ return 0.0
+ return mocks.eval_callback(session, _)
+
+ _optimize = AutoQuant.optimize
+ def optimize(self, *args, **kwargs):
+ # Since all the other candidates (tf-tf, tfe-tfe, and tfe-percentile) yield zero accuracy,
+ # tfe-tf is expected to be selected as the quant scheme for AutoQuant.
+ ret = _optimize(self, *args, **kwargs)
+ assert self._quantsim_params["quant_scheme"].param_quant_scheme == QuantScheme.post_training_tf_enhanced
+ assert self._quantsim_params["quant_scheme"].output_quant_scheme == QuantScheme.post_training_tf
+ return ret
+
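+ # Record the quant schemes passed into _create_quantsim_and_encodings on the
+ # (mocked) session so that eval_callback above can score each candidate.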
+ _mock_create_quantsim_and_encodings = AutoQuant._create_quantsim_and_encodings
+ def mock_create_quantsim_and_encodings(self, *args, **kwargs):
+ sim = _mock_create_quantsim_and_encodings(self, *args, **kwargs)
+ if 'output_quant_scheme' in kwargs:
+ sim.session['output_quant_scheme'] = kwargs['output_quant_scheme']
+ if 'param_quant_scheme' in kwargs:
+ sim.session['param_quant_scheme'] = kwargs['param_quant_scheme']
+ return sim
+
+ with patch("aimet_onnx.auto_quant_v2.AutoQuant.optimize", optimize), \
+ patch("aimet_onnx.auto_quant_v2.AutoQuant._create_quantsim_and_encodings", mock_create_quantsim_and_encodings):
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ eval_callback)
+ auto_quant.optimize(allowed_accuracy_drop)
+
+ def test_set_additional_params(self, onnx_model, dummy_input, unlabeled_data_loader):
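+ # set_adaround_params() should forward the user-provided AdaroundParameters
+ # down to Adaround._apply_adaround.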
+ allowed_accuracy_drop = 0
+ bn_folded_acc = .1
+ cle_acc = .2
+ adaround_acc = .3
+ with patch_ptq_techniques(bn_folded_acc, cle_acc, adaround_acc) as mocks:
+ auto_quant = AutoQuant(onnx_model,
+ dummy_input,
+ unlabeled_data_loader,
+ mocks.eval_callback)
+ adaround_params = AdaroundParameters(unlabeled_data_loader, 1)
+ auto_quant.set_adaround_params(adaround_params)
+ self._do_test_optimize_auto_quant(
+ auto_quant, onnx_model,
+ allowed_accuracy_drop, bn_folded_acc, cle_acc, adaround_acc
+ )
+ adaround_args, _ = mocks.apply_adaround.call_args
+ _, _, actual_adaround_params = adaround_args
+ assert adaround_params == actual_adaround_params
+
+
+@contextlib.contextmanager
+def create_tmp_directory(dirname: str = "/tmp/.aimet_unittest"):
+ os.makedirs(dirname, exist_ok=True)
+ try:
+ yield dirname
+ finally:
+ shutil.rmtree(dirname)