Skip to content

Commit

Permalink
Add deprecation warning for encoding version
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
  • Loading branch information
quic-klhsieh authored and quic-twilkens committed Oct 15, 2024
1 parent 96b6f92 commit efece8a
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 29 deletions.
12 changes: 11 additions & 1 deletion TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import os
from typing import Dict, List, Union, Tuple, Optional
import json
import warnings
import numpy as np
import onnx

Expand All @@ -58,7 +59,7 @@
from aimet_common import libquant_info
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.quantsim import extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.utils import save_json_yaml, AimetLogger
from aimet_common.utils import save_json_yaml, AimetLogger, _red
from aimet_common.connected_graph.product import Product
from aimet_onnx import utils
from aimet_onnx.meta.operations import Op
Expand Down Expand Up @@ -771,6 +772,15 @@ def export(self, path: str, filename_prefix: str):
:param path: dir to save encoding files
:param filename_prefix: filename to save encoding files
"""
if quantsim.encoding_version == '0.6.1':
msg = _red("Encoding version 0.6.1 will be deprecated in a future release, with version 1.0.0 becoming "
"the default. If your code depends on parsing the exported encodings file, ensure that it is "
"updated to be able to parse 1.0.0 format.\n"
"To swap the encoding version to 1.0.0, run the following lines prior to calling quantsim "
"export:\n\n"
"from aimet_common import quantsim\n"
"quantsim.encoding_version = '1.0.0'")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self._export_encodings(os.path.join(path, filename_prefix) + '.encodings')
self.remove_quantization_nodes()
if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
""" Quantsim for Keras """
from __future__ import annotations

import contextlib
from dataclasses import dataclass
import json
import os
Expand All @@ -47,7 +48,8 @@

from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.utils import AimetLogger, save_json_yaml
from aimet_common.quantsim import encoding_version, extract_global_quantizer_args
from aimet_common import quantsim
from aimet_common.quantsim import extract_global_quantizer_args
from aimet_tensorflow.keras.connectedgraph import ConnectedGraph
from aimet_tensorflow.keras.graphsearchtuils import GraphSearchUtils
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper, QuantizerSettings
Expand Down Expand Up @@ -424,7 +426,7 @@ def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
encoding_dict = self._get_encoding_dict_for_quantizer(output_quantizer)
activation_encodings[tensor_name] = encoding_dict
return {
'version': encoding_version,
'version': quantsim.encoding_version,
'activation_encodings': activation_encodings,
'param_encodings': param_encodings,
'quantizer_args': self.quant_args if hasattr(self, "quant_args") else {}
Expand Down Expand Up @@ -477,6 +479,20 @@ def _set_op_mode_parameters(self, op_mode: libpymo.TensorQuantizerOpMode):
if param_quantizer.is_enabled():
param_quantizer.quant_mode = op_mode

@staticmethod
@contextlib.contextmanager
def _set_encoding_version_to_0_6_1():
    """
    Temporarily force ``quantsim.encoding_version`` to '0.6.1' for the duration of the ``with`` body.

    If the configured version is '1.0.0', an info message is logged stating that the export falls back
    to 0.6.1. The previously configured version is restored when the ``with`` body exits, including
    when the body raises.
    """
    assert quantsim.encoding_version in {'0.6.1', '1.0.0'}
    if quantsim.encoding_version == '1.0.0':
        _logger.info('Exporting to encoding version 1.0.0 is not yet supported. Exporting using version 0.6.1 '
                     'instead.')
    old_encoding_version = quantsim.encoding_version
    quantsim.encoding_version = '0.6.1'

    try:
        yield
    finally:
        # The wrapped export re-raises on conversion failure; without ``finally`` the module-level
        # encoding version would stay overridden after such an error.
        quantsim.encoding_version = old_encoding_version

def export(self, path, filename_prefix, custom_objects=None, convert_to_pb=True):
"""
This method exports out the quant-sim model so it is ready to be run on-target.
Expand All @@ -488,29 +504,30 @@ def export(self, path, filename_prefix, custom_objects=None, convert_to_pb=True)
:param filename_prefix: Prefix to use for filenames of the model pth and encodings files
:param custom_objects: If there are custom objects to load, Keras needs a dict of them to map them
"""
model_path = os.path.join(path, filename_prefix)
with self._set_encoding_version_to_0_6_1():
model_path = os.path.join(path, filename_prefix)

#TF Version 2.4 has bug i.e. save() in tf format don't work for unrolled LSTM.
for layer in self._model_without_wrappers.layers:
if isinstance(layer, tf.keras.layers.LSTM):
break
else:
self._model_without_wrappers.save(model_path)

self._model_without_wrappers.save(model_path + '.h5', save_format='h5')

# Conversion of saved h5 model to pb model for consumption by SNPE/QNN
try:
if convert_to_pb:
convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects)
except ValueError:
_logger.error("Could not convert h5 to frozen pb. "
"Please call export() again with custom_objects defined.")
raise
finally:
encodings_dict = self.get_encodings_dict()
encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
save_json_yaml(encoding_file_path, encodings_dict)
#TF Version 2.4 has bug i.e. save() in tf format don't work for unrolled LSTM.
for layer in self._model_without_wrappers.layers:
if isinstance(layer, tf.keras.layers.LSTM):
break
else:
self._model_without_wrappers.save(model_path)

self._model_without_wrappers.save(model_path + '.h5', save_format='h5')

# Conversion of saved h5 model to pb model for consumption by SNPE/QNN
try:
if convert_to_pb:
convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects)
except ValueError:
_logger.error("Could not convert h5 to frozen pb. "
"Please call export() again with custom_objects defined.")
raise
finally:
encodings_dict = self.get_encodings_dict()
encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
save_json_yaml(encoding_file_path, encodings_dict)

def _compute_and_set_parameter_encodings(self, ops_with_invalid_encodings: List):
# pylint: disable=too-many-nested-blocks
Expand Down
28 changes: 27 additions & 1 deletion TrainingExtensions/tensorflow/test/python/test_quantsim_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import contextlib
import json
import os
import tempfile
Expand All @@ -50,6 +50,7 @@

import aimet_common.utils
from aimet_common.defs import QuantScheme, RANGE_LEARNING_SCHEMES
from aimet_common import quantsim
from aimet_tensorflow.examples.test_models import keras_model
from aimet_tensorflow.keras.utils.quantizer_utils import SaveModelWithoutQuantsimWrappersCallback
from aimet_tensorflow.keras.cross_layer_equalization import equalize_model
Expand Down Expand Up @@ -1632,3 +1633,28 @@ def test_quantizable_lstm_export_encodings():
assert param_name in encodings['param_encodings']
assert encodings['param_encodings'][param_name] == encoding_dict

def test_quantsim_export_to_1_0_0():
    """
    Check that export with encoding version '1.0.0' configured writes a '0.6.1' encodings file
    and leaves the configured version set to '1.0.0' afterwards.
    """
    @contextlib.contextmanager
    def _swap_encoding_version():
        # Switch the module-level encoding version to 1.0.0 for the duration of the with-body.
        old_version = quantsim.encoding_version
        quantsim.encoding_version = '1.0.0'

        try:
            yield
        finally:
            # Restore even if an assertion or the export fails, so the override
            # cannot leak into tests that run afterwards.
            quantsim.encoding_version = old_version

    model = dense_functional()
    rand_inp = np.random.randn(100, 5)

    qsim = QuantizationSimModel(model, quant_scheme='tf')
    qsim.compute_encodings(lambda m, _: m(rand_inp), None)

    with tempfile.TemporaryDirectory() as temp_dir, _swap_encoding_version():
        assert quantsim.encoding_version == '1.0.0'
        qsim.export(temp_dir, 'test_export')
        # Export must put the configured version back once it has finished.
        assert quantsim.encoding_version == '1.0.0'

        # Read the file while the temporary directory still exists.
        with open(os.path.join(temp_dir, 'test_export.encodings'), 'r') as encoding_file:
            encodings = json.load(encoding_file)
        assert encodings['version'] == '0.6.1'
13 changes: 11 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from typing import Tuple, List, Union, Dict, Callable, Optional, Any, runtime_checkable, Protocol, Mapping
from collections import OrderedDict, defaultdict
import json
import warnings
import torch
import onnx
from packaging import version # pylint: disable=wrong-import-order
Expand All @@ -59,6 +60,7 @@
from aimet_common.defs import QuantScheme, QuantizationDataType, SupportedKernelsAction, QuantDtypeBwInfo
from aimet_common.quantsim import validate_quantsim_inputs, extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.quant_utils import get_conv_accum_bounds
from aimet_common.utils import deprecated, _red

from aimet_torch.v1.nn.modules.custom import MatMul
from aimet_torch.quantsim_config.quantsim_config import QuantSimConfigurator
Expand All @@ -67,7 +69,6 @@
from aimet_torch.tensor_quantizer import initialize_learned_grid_quantizer_attributes, TensorQuantizer
from aimet_torch.qc_quantize_op import get_encoding_by_quantizer as _get_encoding_by_quantizer
from aimet_torch import torchscript_utils, utils, onnx_utils
from aimet_torch.utils import deprecated
from aimet_torch.onnx_utils import (
OnnxSaver,
OnnxExportApiArgs,
Expand Down Expand Up @@ -520,7 +521,15 @@ def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tenso
:param filename_prefix_encodings: File name prefix to be used when saving encodings.
If None, then user defaults to filename_prefix value
"""

if quantsim.encoding_version == '0.6.1':
msg = _red("Encoding version 0.6.1 will be deprecated in a future release, with version 1.0.0 becoming "
"the default. If your code depends on parsing the exported encodings file, ensure that it is "
"updated to be able to parse 1.0.0 format.\n"
"To swap the encoding version to 1.0.0, run the following lines prior to calling quantsim "
"export:\n\n"
"from aimet_common import quantsim\n"
"quantsim.encoding_version = '1.0.0'")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
warning_str = 'Exporting encodings to yaml will be deprecated in a future release. Ensure that your ' \
'code can work with the exported files ending in ".encodings" which are saved using json ' \
'format. For the time being, if yaml export is needed, set aimet_common.utils.SAVE_TO_YAML to ' \
Expand Down
2 changes: 1 addition & 1 deletion TrainingExtensions/torch/test/python/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def test_add_quantization_wrappers_with_modulelist_with_layers_to_ignore(self):
assert 'layers_deep.5.0' in sim._excluded_layer_names
assert 'layers_deep.5.1' in sim._excluded_layer_names

with tempfile.TemporaryDirectory() as tmpdir:
with tempfile.TemporaryDirectory() as tmpdir, pytest.warns(DeprecationWarning):
sim.export(tmpdir, 'modulelist_with_layers_to_ignore', dummy_input=torch.rand(1, 3, 12, 12))
with open(os.path.join(tmpdir, "modulelist_with_layers_to_ignore.encodings"), "r") as encodings_file:
encodings = json.load(encodings_file)
Expand Down

0 comments on commit efece8a

Please sign in to comment.