Skip to content

Commit

Permalink
API Doc for ONNX Auto Quant V2 (#2617)
Browse files Browse the repository at this point in the history
Signed-off-by: Raj Gite <quic_rgite@quicinc.com>
  • Loading branch information
quic-rgite authored Dec 26, 2023
1 parent de7614c commit 229ec6c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 14 deletions.
23 changes: 23 additions & 0 deletions Docs/api_docs/onnx_auto_quant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
:orphan:

.. _api-onnx-auto-quant:

===========================
AIMET ONNX AutoQuant API
===========================

User Guide Link
===============
To learn more about this technique, please see :ref:`AutoQuant<ug-auto-quant>`

Top-level API
=============
.. autoclass:: aimet_onnx.auto_quant_v2.AutoQuant
:members:
:member-order: bysource

Code Examples
===============
.. literalinclude:: ../onnx_code_examples/auto_quant_v2.py
:language: python
:lines: 40-
10 changes: 10 additions & 0 deletions Docs/api_docs/onnx_quantization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
AIMET ONNX Quantization APIs
===============================

.. toctree::
:titlesonly:
:hidden:

Quantization Simulation API<onnx_quantsim>
Cross-Layer Equalization API<onnx_cross_layer_equalization>
Adaptive Rounding API<onnx_adaround>
AutoQuant API<onnx_auto_quant>

AIMET Quantization for ONNX Models provides the following functionality.
- :ref:`Quantization Simulation API<api-onnx-quantsim>`: Allows ability to simulate inference on quantized hardware
- :ref:`Cross-Layer Equalization API<api-onnx-cle>`: Post-training quantization technique to equalize layer parameters
- :ref:`Adaround API<api-onnx-adaround>`: Post-training quantization technique to optimize rounding of weight tensors
- :ref:`AutoQuant API<api-onnx-auto-quant>`: Unified API that integrates the post-training quantization techniques provided by AIMET
109 changes: 109 additions & 0 deletions Docs/onnx_code_examples/auto_quant_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================

""" Code example for AutoQuantV2 """

import math
import onnxruntime as ort
import numpy as np

from aimet_onnx.auto_quant_v2 import AutoQuant
from aimet_onnx.adaround.adaround_weight import AdaroundParameters

# Step 1. Define constants
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 500
BATCH_SIZE = 32

# Step 2. Prepare model and dataloader
onnx_model = Model()

input_shape = (1, 3, 224, 224)
dummy_data = np.random.randn(*input_shape).astype(np.float32)
dummy_input = {'input': dummy_data}

# NOTE: Use your dataloader. It should iterate over unlabelled dataset.
# Its data will be directly fed as input to the onnx model's inference session.
unlabelled_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
iterations=math.ceil(CALIBRATION_DATASET_SIZE / BATCH_SIZE))

# Step 3. Prepare eval callback
# NOTE: In the actual use cases, the users should implement this part to serve
# their own goals, maintaining the function signature.
def eval_callback(session: ort.InferenceSession, num_of_samples: Optional[int] = None) -> float:
data_loader = EvalDataLoader()
if num_of_samples:
iterations = math.ceil(num_of_samples / data_loader.batch_size)
else:
iterations = len(data_loader)
batch_cntr = 1
acc_top1 = 0
acc_top5 = 0
for input_data, target in data_loader:
pred = session.run(None, {'input': input_data})

batch_avg_top_1_5 = accuracy(pred, target, topk=(1, 5))

acc_top1 += batch_avg_top_1_5[0].item()
acc_top5 += batch_avg_top_1_5[1].item()

batch_cntr += 1
if batch_cntr > iterations:
break
acc_top1 /= iterations
acc_top5 /= iterations
return acc_top1

# Step 4. Create AutoQuant object
auto_quant = AutoQuant(onnx_model,
dummy_input,
unlabelled_data_loader,
eval_callback)

# Step 5. (Optional) Set AdaRound params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
iterations=math.ceil(ADAROUND_DATASET_SIZE / BATCH_SIZE))
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)

# Step 6. Run AutoQuant
sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)

print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization): {optimized_accuracy:.4f}")
28 changes: 18 additions & 10 deletions TrainingExtensions/onnx/src/python/aimet_onnx/auto_quant_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@

cache = Cache()

# The number of samples to be used for performance evaluation.
# NOTE: None means "all".
NUM_SAMPLES_FOR_PERFORMANCE_EVALUATION = None


@dataclass(frozen=True)
class _QuantSchemePair:
Expand Down Expand Up @@ -195,8 +199,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
model: Union[onnx.ModelProto, ONNXModel],
dummy_input: Dict[str, np.ndarray],
data_loader: Iterable[Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray]]],
eval_callback: Callable[[ort.InferenceSession], float],
eval_callback_args=None,
eval_callback: Callable[[ort.InferenceSession, int], float],
param_bw: int = 8,
output_bw: int = 8,
quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
Expand All @@ -208,11 +211,10 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
cache_id: str = None,
strict_validation: bool = True) -> None:
'''
:param model: Model to be quantized. Assumes model is on the correct device
:param dummy_input: Dummy input for the model. Assumes that dummy_input is on the correct device
:param model: Model to be quantized.
:param dummy_input: Dummy input dict for the model.
:param data_loader: A collection that iterates over an unlabeled dataset, used for computing encodings
:param eval_callback: Function that calculates the evaluation score given the model session
:param eval_callback_args: Extra arguments for eval_callback
:param param_bw: Parameter bitwidth
:param output_bw: Output bitwidth
:param quant_scheme: Quantization scheme
Expand All @@ -234,7 +236,6 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.dummy_input = dummy_input
self.data_loader = data_loader
self.eval_callback = eval_callback
self.eval_callback_args = eval_callback_args

self._quantsim_params = dict(
param_bw=param_bw,
Expand Down Expand Up @@ -281,7 +282,7 @@ def _evaluate_model_performance(self, session) -> float:
"""
Evaluate the model performance.
"""
return self.eval_callback(session, self.eval_callback_args)
return self.eval_callback(session, NUM_SAMPLES_FOR_PERFORMANCE_EVALUATION)

def run_inference(self) -> Tuple[QuantizationSimModel, float]:
'''
Expand Down Expand Up @@ -417,9 +418,6 @@ def _create_quantsim_and_encodings( # pylint: disable=too-many-arguments, too-ma
if encoding_path:
sim.set_and_freeze_param_encodings(encoding_path)

# TODO: Other frameworks had this second call to fetch tensor quantizers. Need to check if also required for ONNX.
# param_quantizers, activation_quantizers = sim.get_all_quantizers()

# Disable activation quantizers, using fp32 to simulate int32.
if output_bw == 32:
for quantizer in activation_quantizers:
Expand Down Expand Up @@ -553,6 +551,16 @@ def get_quant_scheme_candidates(self) -> Tuple[_QuantSchemePair, ...]:
"""
return self._quant_scheme_candidates

def set_quant_scheme_candidates(self, candidates: Tuple[_QuantSchemePair, ...]):
"""
Set candidates for quant scheme search.
During :meth:`~AutoQuant.optimize`, the candidate with the highest accuracy
will be selected among them.
:param candidates: Candidates for quant scheme search
"""
self._quant_scheme_candidates = copy.copy(candidates)

def _choose_default_quant_scheme(self):
def eval_fn(pair: _QuantSchemePair):
sim = self._create_quantsim_and_encodings(
Expand Down
4 changes: 0 additions & 4 deletions TrainingExtensions/onnx/test/python/test_auto_quant_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ def test_auto_quant_run_inference(self, onnx_model, dummy_input, unlabeled_data_
dummy_input,
unlabeled_data_loader,
mocks.eval_callback,
eval_callback_args=None,
results_dir=results_dir)
auto_quant.run_inference()

Expand Down Expand Up @@ -351,7 +350,6 @@ def test_consecutive_calls(self, onnx_model, dummy_input, unlabeled_data_loader)
dummy_input,
unlabeled_data_loader,
mocks.eval_callback,
eval_callback_args=None,
results_dir=results_dir)

# Should return proper model & summary report
Expand All @@ -370,7 +368,6 @@ def test_consecutive_calls(self, onnx_model, dummy_input, unlabeled_data_loader)
dummy_input,
unlabeled_data_loader,
mocks.eval_callback,
eval_callback_args=None,
results_dir=results_dir)

# When run_inference() and optimize() are called in back-to-back,
Expand Down Expand Up @@ -403,7 +400,6 @@ def _test_auto_quant(
dummy_input,
unlabeled_data_loader,
mocks.eval_callback,
eval_callback_args=None,
results_dir=results_dir)
self._do_test_optimize_auto_quant(
auto_quant, model, allowed_accuracy_drop,
Expand Down

0 comments on commit 229ec6c

Please sign in to comment.