diff --git a/TrainingExtensions/common/src/python/aimet_common/defs.py b/TrainingExtensions/common/src/python/aimet_common/defs.py index 9868ca08de5..52cd626a197 100644 --- a/TrainingExtensions/common/src/python/aimet_common/defs.py +++ b/TrainingExtensions/common/src/python/aimet_common/defs.py @@ -399,3 +399,11 @@ def __init__(self, func: Callable, func_callback_args=None): """ self.func = func self.args = func_callback_args + +class EncodingType(Enum): + """ Encoding type """ + PER_TENSOR = 0 + PER_CHANNEL = 1 + PER_BLOCK = 2 + LPBQ = 3 + VECTOR = 4 diff --git a/TrainingExtensions/common/src/python/aimet_common/quantsim.py b/TrainingExtensions/common/src/python/aimet_common/quantsim.py index d405e68109a..40e744339cd 100644 --- a/TrainingExtensions/common/src/python/aimet_common/quantsim.py +++ b/TrainingExtensions/common/src/python/aimet_common/quantsim.py @@ -54,6 +54,7 @@ # The patching version shall be updated to indicate minor updates to quantization simulation e.g. bug fix etc. encoding_version = '0.6.1' ALLOW_EXPERIMENTAL = False +VALID_ENCODING_VERSIONS = {'0.6.1', '1.0.0'} def gate_min_max(min_val: float, max_val: float) -> Tuple[float, float]: diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py index 7a59c374d4e..f6c15c79fca 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py @@ -48,6 +48,7 @@ from tqdm import tqdm # Import AIMET specific modules +from aimet_common import quantsim from aimet_common.utils import AimetLogger from aimet_common.defs import QuantScheme, QuantizationDataType @@ -335,7 +336,7 @@ def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: Q :param quant_sim: QunatSim that contains the model and Adaround tensor quantizers """ # pylint: disable=protected-access - param_encodings = quant_sim._get_encodings(quant_sim.param_names) + param_encodings = quant_sim._get_encodings(quant_sim.param_names, quantsim.encoding_version) # export encodings to JSON file os.makedirs(os.path.abspath(path), exist_ok=True) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py index 117040f6232..6873a783a36 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py @@ -36,10 +36,10 @@ # ============================================================================= """ Custom QcQuantizeOp to quantize weights and activations using ONNXRuntime """ -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict import aimet_common.libpymo as libpymo from aimet_common.libpymo import TensorQuantizerOpMode -from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType +from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType, EncodingType from aimet_common import libquant_info from aimet_common.utils import deprecated @@ -462,6 +462,9 @@ def export_encodings(self, encoding_version: str = "0.6.1"): if encoding_version == '0.6.1': return self._export_legacy_encodings() + if encoding_version == "1.0.0": + return self._export_1_0_0_encodings() + raise RuntimeError(f"Unsupported encoding export version: {encoding_version}") def _export_legacy_encodings(self) -> Union[List, None]: @@ -490,3 +493,32 @@ def _export_legacy_encodings(self) -> Union[List, None]: return encodings raise RuntimeError(f"Exporting data type {self.data_type} not supported") + + def _encoding_type(self): + if not self.quant_info.usePerChannelMode: + return EncodingType.PER_TENSOR + if not self.quant_info.blockSize: + return EncodingType.PER_CHANNEL + return EncodingType.PER_BLOCK + + def _export_1_0_0_encodings(self) -> Optional[Dict]: + """ + Exports the quantizer's encodings in the "1.0.0" encoding format + """ + if not self.enabled or not self.is_initialized(): + return None + + enc_dict = dict(enc_type=self._encoding_type().name, + dtype="INT" if self.data_type == QuantizationDataType.int else "FLOAT", + bw=self.bitwidth, + ) + + if self.data_type == QuantizationDataType.int: + enc_dict["is_sym"] = self.use_symmetric_encodings + encodings = self.get_encodings() + enc_dict["scale"] = [enc.delta for enc in encodings] + enc_dict["offset"] = [enc.offset for enc in encodings] + if self.quant_info.blockSize > 0: + enc_dict["block_size"] = self.quant_info.blockSize + + return enc_dict diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index b5c3a5d226f..93ef119083d 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -53,10 +53,10 @@ from packaging import version # pylint: disable=wrong-import-order -from aimet_common import libpymo +from aimet_common import libpymo, quantsim from aimet_common import libquant_info from aimet_common.defs import QuantScheme, QuantizationDataType -from aimet_common.quantsim import encoding_version, extract_global_quantizer_args +from aimet_common.quantsim import extract_global_quantizer_args, VALID_ENCODING_VERSIONS from aimet_common.utils import save_json_yaml, AimetLogger from aimet_onnx import utils from aimet_onnx.meta.operations import Op @@ -676,25 +676,35 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args): qc_op.compute_encodings() qc_op.op_mode = OpMode.quantizeDequantize - def _get_encodings(self, quantizer_names) -> Dict: + def _get_encodings(self, quantizer_names, enc_version): encoding_dict = {} for name in quantizer_names: - encoding = self.qc_quantize_op_dict[name].export_encodings(encoding_version) + encoding = self.qc_quantize_op_dict[name].export_encodings(enc_version) if encoding is None: continue encoding_dict[name] = encoding - return encoding_dict - def _export_encodings(self, encoding_file_path): + if version.parse(enc_version) < version.parse("1.0.0"): + return encoding_dict + + for name, encoding in encoding_dict.items(): + encoding["name"] = name + return list(encoding_dict.values()) + + def _export_encodings(self, encoding_file_path, enc_version): """ Export encodings to json and yaml file :param encoding_file_path: path to save the encoding files """ - param_encodings = self._get_encodings(self.param_names) - activation_encodings = self._get_encodings(self.activation_names) + if enc_version not in VALID_ENCODING_VERSIONS: + raise NotImplementedError(f'Encoding version {enc_version} not in set of valid encoding ' + f'versions {VALID_ENCODING_VERSIONS}.') + + param_encodings = self._get_encodings(self.param_names, enc_version) + activation_encodings = self._get_encodings(self.activation_names, enc_version) - encodings_dict = {'version': encoding_version, + encodings_dict = {'version': enc_version, 'activation_encodings': activation_encodings, 'param_encodings': param_encodings, 'quantizer_args': self.quant_args} @@ -733,7 +743,7 @@ def export(self, path: str, filename_prefix: str): :param path: dir to save encoding files :param filename_prefix: filename to save encoding files """ - self._export_encodings(os.path.join(path, filename_prefix) + '.encodings') + self._export_encodings(os.path.join(path, filename_prefix) + '.encodings', quantsim.encoding_version) self.remove_quantization_nodes() if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: # Note: Saving as external data mutates the saved model, removing all initializer data diff --git a/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py b/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py index f455408b072..8cd8e94f609 100644 --- a/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py +++ b/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py @@ -42,7 +42,7 @@ import os import pytest from aimet_common import libpymo -from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType +from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType, EncodingType from aimet_onnx.qc_quantize_op import QcQuantizeOp, OpMode, TensorQuantizerParams from aimet_common import libquant_info from aimet_common.quantsim import calculate_delta_offset @@ -812,7 +812,7 @@ def test_export_per_tensor_int_encodings(self, symmetric, bitwidth, delta, offse encoding.bw = bitwidth encoding.offset = offset encoding.delta = delta - qc_quantize_op.load_encodings([encoding]) + qc_quantize_op.update_quantizer_and_load_encodings([encoding], symmetric, False, False, QuantizationDataType.int) exported_encodings = qc_quantize_op.export_encodings("0.6.1") assert len(exported_encodings) == 1 assert exported_encodings[0]["scale"] == delta @@ -821,11 +821,26 @@ def test_export_per_tensor_int_encodings(self, symmetric, bitwidth, delta, offse assert exported_encodings[0]["dtype"] == "int" assert exported_encodings[0]["is_symmetric"] == str(symmetric) + exported_encodings = qc_quantize_op.export_encodings("1.0.0") + assert isinstance(exported_encodings, dict) + assert exported_encodings.keys() == {"enc_type", "dtype", "bw", "is_sym", "scale", "offset"} + assert exported_encodings["dtype"] == "INT" + assert exported_encodings["enc_type"] == EncodingType.PER_TENSOR.name + assert exported_encodings["bw"] == bitwidth + assert exported_encodings["is_sym"] == symmetric + assert isinstance(exported_encodings["scale"], list) + assert isinstance(exported_encodings["offset"], list) + assert len(exported_encodings["scale"]) == 1 + assert len(exported_encodings["offset"]) == 1 + assert exported_encodings["scale"][0] == delta + assert exported_encodings["offset"][0] == offset + @pytest.mark.parametrize("symmetric, bitwidth, delta, offset", [(True, 8, 0.1, -128),]) def test_export_per_channel_int_encodings(self, symmetric, bitwidth, delta, offset): channel_axis = 0 + block_axis = 1 tensor_shape = [5, 8] - params = TensorQuantizerParams(tensor_shape, channel_axis) + params = TensorQuantizerParams(tensor_shape, channel_axis, block_axis) quant_info = libquant_info.QcQuantizeInfo() quant_info.usePerChannelMode = False @@ -844,6 +859,22 @@ def test_export_per_channel_int_encodings(self, symmetric, bitwidth, delta, offs exported_encodings = qc_quantize_op.export_encodings("0.6.1") assert len(exported_encodings) == tensor_shape[channel_axis] + exported_encodings = qc_quantize_op.export_encodings("1.0.0") + assert exported_encodings.keys() == {"enc_type", "dtype", "bw", "is_sym", "scale", "offset"} + assert exported_encodings["enc_type"] == EncodingType.PER_CHANNEL.name + assert len(exported_encodings["scale"]) == tensor_shape[channel_axis] + assert len(exported_encodings["offset"]) == tensor_shape[channel_axis] + + block_size = 4 + qc_quantize_op._enable_blockwise_quantization(block_size) + encodings = [libpymo.TfEncoding() for _ in range(tensor_shape[channel_axis] * 2)] + qc_quantize_op.load_encodings(encodings) + exported_encodings = qc_quantize_op.export_encodings("1.0.0") + assert exported_encodings.keys() == {"enc_type", "dtype", "bw", "is_sym", "scale", "offset", "block_size"} + assert exported_encodings["enc_type"] == EncodingType.PER_BLOCK.name + assert len(exported_encodings["scale"]) == tensor_shape[channel_axis] * 2 + assert exported_encodings["block_size"] == block_size + def test_export_float_encodings(self): quant_info = libquant_info.QcQuantizeInfo() qc_quantize_op = QcQuantizeOp(quant_info, bitwidth=16, op_mode=OpMode.quantizeDequantize) @@ -853,6 +884,11 @@ def test_export_float_encodings(self): assert encodings[0]["dtype"] == "float" assert encodings[0]["bitwidth"] == 16 + exported_encodings = qc_quantize_op.export_encodings("1.0.0") + assert exported_encodings.keys() == {"enc_type", "dtype", "bw"} + assert exported_encodings["dtype"] == "FLOAT" + assert exported_encodings["bw"] == 16 + def test_load_float_encodings(self): quant_info = libquant_info.QcQuantizeInfo() qc_quantize_op = QcQuantizeOp(quant_info, bitwidth=16, op_mode=OpMode.quantizeDequantize) diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py index 9933a0bd926..07e9b01be29 100644 --- a/TrainingExtensions/onnx/test/python/test_quantsim.py +++ b/TrainingExtensions/onnx/test/python/test_quantsim.py @@ -35,6 +35,8 @@ # @@-COPYRIGHT-END-@@ # ============================================================================= +import contextlib +import itertools import json import os import tempfile @@ -46,8 +48,9 @@ import onnxruntime as ort import pytest +from aimet_common import quantsim from aimet_common import libquant_info -from aimet_common.defs import QuantScheme, QuantizationDataType +from aimet_common.defs import QuantScheme, QuantizationDataType, EncodingType from aimet_common.quantsim_config.utils import get_path_for_per_channel_config from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim, set_blockwise_quantization_for_weights from aimet_onnx.qc_quantize_op import OpMode @@ -103,6 +106,15 @@ def forward(self, inputs): return x +@contextlib.contextmanager +def set_encoding_version(version): + old_version = quantsim.encoding_version + quantsim.encoding_version = version + + yield + + quantsim.encoding_version = old_version + class TestQuantSim: """Tests for QuantizationSimModel""" def test_insert_quantize_op_nodes(self): @@ -225,6 +237,42 @@ def dummy_callback(session, args): param_encodings_keys = list(encoding_data["param_encodings"][param][0].keys()) assert param_encodings_keys == ['bitwidth', 'dtype', 'is_symmetric', 'max', 'min', 'offset', 'scale'] + def test_export_model_1_0_0(self): + """Test to export encodings and model in 1.0.0 format""" + model = build_dummy_model() + with tempfile.TemporaryDirectory() as tempdir: + sim = QuantizationSimModel(model, path=tempdir, config_file=get_path_for_per_channel_config()) + + def dummy_callback(session, _): + session.run(None, make_dummy_input(model)) + + sim.compute_encodings(dummy_callback, None) + with set_encoding_version("1.0.0"): + sim.export(tempdir, 'quant_sim_model') + + with open(os.path.join(tempdir, 'quant_sim_model.encodings'), 'rb') as json_file: + encoding_data = json.load(json_file) + + assert encoding_data["version"] == "1.0.0" + assert isinstance(encoding_data["activation_encodings"], list) + assert isinstance(encoding_data["param_encodings"], list) + + activation_keys = {enc["name"] for enc in encoding_data["activation_encodings"]} + param_keys = {enc["name"] for enc in encoding_data["param_encodings"]} + assert activation_keys == {'4', '5', 'input', 'output'} + assert param_keys == {'conv_w', 'fc_w'} + + for enc in itertools.chain(encoding_data["param_encodings"], encoding_data["activation_encodings"]): + assert isinstance(enc, dict) + assert enc.keys() == {"name", "enc_type", "dtype", "bw", "is_sym", "scale", "offset"} + assert isinstance(enc["scale"], list) + assert enc["dtype"] == "INT" + # Gemm layers do not use per-channel in the default_per_channel_config + if enc["name"] == "conv_w": + assert enc["enc_type"] == EncodingType.PER_CHANNEL.name + else: + assert enc["enc_type"] == EncodingType.PER_TENSOR.name + def test_lstm_gru(self): """Test for LSTM and GRU dummy model""" model = build_lstm_gru_dummy_model() diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantsim/export_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantsim/export_utils.py index 85bccd2dde8..e072b8bc8a5 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantsim/export_utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantsim/export_utils.py @@ -36,28 +36,17 @@ # ============================================================================= """ Export utilities for QuantizationSimModel """ -from enum import Enum import json import os from typing import Dict, List, Tuple from aimet_common.utils import AimetLogger -from aimet_common.defs import QuantizationDataType +from aimet_common.defs import QuantizationDataType, EncodingType from aimet_torch.utils import is_vector_encoding logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) -VALID_ENCODING_VERSIONS = {'0.6.1', '1.0.0'} - -class EncodingType(Enum): - """ Encoding type """ - PER_TENSOR = 0 - PER_CHANNEL = 1 - PER_BLOCK = 2 - LPBQ = 3 - VECTOR = 4 - def _export_to_1_0_0(path: str, filename_prefix: str, tensor_to_activation_encodings: Dict[str, List], diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py index 94f8dc7f0a3..0e85d9ab551 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py @@ -57,7 +57,7 @@ from aimet_common.connected_graph.connectedgraph_utils import CG_SPLIT from aimet_common.utils import AimetLogger, save_json_yaml, log_with_error_and_assert_if_false from aimet_common.defs import QuantScheme, QuantizationDataType, SupportedKernelsAction, QuantDtypeBwInfo -from aimet_common.quantsim import validate_quantsim_inputs, extract_global_quantizer_args +from aimet_common.quantsim import validate_quantsim_inputs, extract_global_quantizer_args, VALID_ENCODING_VERSIONS from aimet_common.quant_utils import get_conv_accum_bounds from aimet_torch.v1.nn.modules.custom import MatMul @@ -76,7 +76,7 @@ from aimet_torch.meta.connectedgraph import ConnectedGraph, Op from aimet_torch.qc_quantize_recurrent import QcQuantizeRecurrent from aimet_torch.quantsim_config.builder import LazyQuantizeWrapper -from aimet_torch.experimental.v2.quantsim.export_utils import VALID_ENCODING_VERSIONS, _export_to_1_0_0 +from aimet_torch.experimental.v2.quantsim.export_utils import _export_to_1_0_0 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)