Support encoding constraints in aimet onnx
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
quic-mtuttle authored Oct 24, 2024
1 parent b61ad03 commit 67e8601
Showing 5 changed files with 155 additions and 18 deletions.
32 changes: 29 additions & 3 deletions TrainingExtensions/common/src/python/aimet_common/quantsim.py
@@ -96,6 +96,30 @@ def is_non_strict_symmetric(use_symmetric_encodings: bool,
not is_unsigned_symmetric


def create_encoding_from_min_max(min_val: float, max_val: float, bitwidth: int, use_symmetric_encodings: bool,
use_strict_symmetric: bool) -> libpymo.TfEncoding:
"""
Returns a TfEncoding object with the provided min/max/bitwidth/symmetry
:param min_val: Min value of the encoding
:param max_val: Max value of the encoding
:param bitwidth: Encoding bitwidth
:param use_symmetric_encodings: If True, results in encoding with min = -max - delta
:param use_strict_symmetric: If True, results in encoding with min = -max
:return: libpymo.TfEncoding object
"""
delta, offset = calculate_delta_offset(min_val, max_val, bitwidth, use_symmetric_encodings, use_strict_symmetric)

encoding = libpymo.TfEncoding()
encoding.bw = bitwidth
encoding.min = min_val
encoding.max = max_val
encoding.delta = delta
encoding.offset = offset
# Note: need to recompute grid to account for offset rounding
return recompute_grid_params(encoding, bitwidth, use_symmetric_encodings, use_strict_symmetric)
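A minimal usage sketch of the new helper (not part of the diff; the 8-bit asymmetric values are assumed for illustration):

from aimet_common.quantsim import create_encoding_from_min_max

# Build a fixed [0.0, 1.0] encoding for an 8-bit asymmetric quantizer.
encoding = create_encoding_from_min_max(min_val=0.0, max_val=1.0, bitwidth=8,
                                        use_symmetric_encodings=False,
                                        use_strict_symmetric=False)
# For this range, delta is roughly 1/255 and offset is 0.
print(encoding.bw, encoding.delta, encoding.offset)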


def calculate_delta_offset(min_val: float, max_val: float, bitwidth: int, use_symmetric_encodings: bool,
use_strict_symmetric: bool) -> Tuple[float, int]:
"""
@@ -148,12 +172,14 @@ def compute_min_max_given_delta_offset(delta: float, offset: int, bitwidth: int,
return min_val, max_val

def recompute_grid_params(current_encoding: libpymo.TfEncoding, bitwidth: int,
use_symmetric_encoding: bool) -> libpymo.TfEncoding:
use_symmetric_encoding: bool, use_strict_symmetric: bool = False) -> libpymo.TfEncoding:
"""
Recomputed the encoding grid params - min/max/offset and delta.
Recomputes the encoding grid params - min/max/offset and delta.
:param current_encoding: Encoding associated with the quantizer as TfEncoding
:param bitwidth: bit width configured for the quantizer
:param use_symmetric_encoding: symmetric or asymmetric mode
:param use_strict_symmetric: True if using strict symmetric, False otherwise
:return: updated encoding params as libpymo.TfEncoding type.
"""

@@ -167,7 +193,7 @@ def recompute_grid_params(current_encoding: libpymo.TfEncoding, bitwidth: int,
num_positive_steps = (2 ** (bitwidth - 1)) - 1
abs_max_val = max(abs(max_val), abs(min_val))
delta = abs_max_val / num_positive_steps
offset = -(num_positive_steps + 1)
offset = -(num_positive_steps + int(not use_strict_symmetric))
# recompute min/max values
min_val = delta * offset
max_val = delta * num_positive_steps
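To make the strict/non-strict difference concrete, a short worked example with assumed values (bitwidth 8, abs_max_val 1.0):

# num_positive_steps = 2 ** (8 - 1) - 1 = 127
# delta = 1.0 / 127
# non-strict symmetric: offset = -(127 + 1) = -128, so the grid spans [-128 * delta, 127 * delta]
# strict symmetric:     offset = -127,              so the grid spans [-127 * delta, 127 * delta]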
37 changes: 26 additions & 11 deletions TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
@@ -36,15 +36,15 @@
# =============================================================================
""" Custom QcQuantizeOp to quantize weights and activations using ONNXRuntime """

from typing import Union, List, Optional, Dict
from typing import Union, List, Optional, Dict, Tuple
import numpy as np

import aimet_common.libpymo as libpymo
from aimet_common.libpymo import TensorQuantizerOpMode
from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType, EncodingType
from aimet_common import libquant_info
from aimet_common.utils import deprecated
from aimet_common.quantsim import calculate_delta_offset
from aimet_common.quantsim import calculate_delta_offset, create_encoding_from_min_max
from aimet_onnx import lpbq_utils


@@ -101,6 +101,7 @@ def __init__(self, quant_info: libquant_info.QcQuantizeInfo,
self._data_type = QuantizationDataType.int
self.tensor_quantizer_params = tensor_quantizer_params
self._reset_encodings()
self._encoding_min_max_fixed_vals = None

def is_encoding_frozen(self) -> bool:
""" Returns is_encoding_frozen var """
@@ -410,17 +411,23 @@ def compute_encodings(self) -> Optional[List[libpymo.TfEncoding]]:
"""
Compute and return encodings of each tensor quantizer
"""
if not self._is_encoding_frozen:
if self.enabled:
encodings = []
for tensor_quantizer in self._tensor_quantizer:
encodings.append(tensor_quantizer.computeEncoding(self.bitwidth, self.use_symmetric_encodings))
self.load_encodings(encodings)
if self._is_encoding_frozen:
return None

if not self.enabled:
return None

encodings = []
for tensor_quantizer in self._tensor_quantizer:
if self._encoding_min_max_fixed_vals is None:
encodings.append(tensor_quantizer.computeEncoding(self.bitwidth, self.use_symmetric_encodings))
else:
encodings = None
min_val, max_val = self._encoding_min_max_fixed_vals
encodings.append(create_encoding_from_min_max(min_val, max_val, self.bitwidth, self.use_symmetric_encodings,
self.use_strict_symmetric))

return encodings
return None
self.load_encodings(encodings)
return encodings

def get_stats_histogram(self) -> List[List]:
"""
@@ -567,6 +574,14 @@ def clip_and_recompute_encodings(self, clamp_val: float) -> bool:

return is_clipped

def set_fixed_encoding_range(self, fixed_range: Tuple[float, float]):
"""
Set the min/max values to be used when computing encodings
:param fixed_range: Tuple of (min, max) values to use in place of observer statistics when computing encodings
"""
self._encoding_min_max_fixed_vals = fixed_range
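
A hedged sketch of how this method interacts with compute_encodings, mirroring the new unit test further below (qc_op is an assumed, pre-constructed QcQuantizeOp, not shown in the diff):

# Pin the encoding range instead of relying on observer statistics.
qc_op.set_fixed_encoding_range((0.0, 1.0))
encodings = qc_op.compute_encodings()  # uses create_encoding_from_min_max internally
assert encodings[0].min == 0.0 and encodings[0].max == 1.0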


class GroupedBlockQuantizeDequantize(QcQuantizeOp):
""" Class for performing Grouped Block Quantize Dequantize """
@@ -36,7 +36,7 @@
# =============================================================================
""" Utilities for parsing and applying quantsim configurations from json config file """
from abc import abstractmethod
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Union

import onnx
from packaging import version
@@ -393,6 +393,9 @@ def _set_config_for_op(self, op_name, op_to_quantizer: OpToQuantizers, op_config
self._modify_activation_quantize_op(op_to_quantizer.input_quantizers + op_to_quantizer.output_quantizers,
ConfigDictKeys.IS_SYMMETRIC, op_config[ConfigDictKeys.IS_SYMMETRIC],
modified_quantize_ops)
if ConfigDictKeys.ENCODING_CONSTRAINTS in op_config:
self._modify_activation_quantize_op(op_to_quantizer.output_quantizers, ConfigDictKeys.ENCODING_CONSTRAINTS,
op_config[ConfigDictKeys.ENCODING_CONSTRAINTS], modified_quantize_ops)

# Will only see this in the op_type section, not default
if ConfigDictKeys.PARAMS in op_config:
@@ -411,7 +414,7 @@ def _get_param_type(self, op_name: str, param_name: str) -> str:

@staticmethod
def _modify_activation_quantize_op(quantize_ops_to_modify: List[QcQuantizeOp], setting_name: str,
quantizer_setting: bool, modified_quantize_ops: Dict):
quantizer_setting: Union[Dict, bool], modified_quantize_ops: Dict):
"""
Modify the appropriate quantize ops for the given quantizer setting. If a quantize op has already been
modified, compare the old setting with the new setting and assert if the settings conflict.
@@ -430,8 +433,11 @@ def _modify_activation_quantize_op(quantize_ops_to_modify: List[QcQuantizeOp], s
# Tensor quantizer's setting has already been modified
if setting_name in [ConfigDictKeys.IS_INPUT_QUANTIZED, ConfigDictKeys.IS_OUTPUT_QUANTIZED]:
current_setting = quantizer.enabled
else:
elif setting_name == ConfigDictKeys.IS_SYMMETRIC:
current_setting = quantizer.use_symmetric_encodings
else:
current_setting = {ConfigDictKeys.MIN: quantizer._encoding_min_max_fixed_vals[0], # pylint: disable=protected-access
ConfigDictKeys.MAX: quantizer._encoding_min_max_fixed_vals[1]} # pylint: disable=protected-access
if current_setting != quantizer_setting:
logger.error('Conflicting tensor quantizer settings for symmetric encodings')
raise AssertionError('Conflicting tensor quantizer settings for symmetric encodings')
@@ -442,8 +448,11 @@ def _modify_activation_quantize_op(quantize_ops_to_modify: List[QcQuantizeOp], s
else:
quantizer.enabled = True
quantizer.op_mode = OpMode.updateStats
else:
elif setting_name == ConfigDictKeys.IS_SYMMETRIC:
quantizer.use_symmetric_encodings = quantizer_setting
elif setting_name == ConfigDictKeys.ENCODING_CONSTRAINTS:
quantizer.set_fixed_encoding_range((quantizer_setting[ConfigDictKeys.MIN],
quantizer_setting[ConfigDictKeys.MAX]))
if quantizer not in modified_quantize_ops:
modified_quantize_ops[quantizer] = {setting_type}
else:
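For reference, an op_type config entry of the shape this parser now handles (mirroring the unit test further below; shown as a Python dict fragment with assumed values):

op_type_config = {
    "Softmax": {
        "encoding_constraints": {"min": 0.0, "max": 1.0}
    }
}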
34 changes: 34 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -2454,3 +2454,37 @@ def matmul_add_model():
)
onnx.checker.check_model(model, True)
return model

def softmax_model():
model = helper.make_model(
graph=helper.make_graph(
name='SoftmaxModel',
inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[1, 1, 8, 8])],
outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[1, 1, 8, 8])],
initializer=[
numpy_helper.from_array(np.full((1, 1, 8, 8), 25.0).astype('float32'), name='matmul_1.weight'),
],
nodes=[
helper.make_node(
'MatMul',
inputs=['model_input', 'matmul_1.weight'],
outputs=['matmul.output'],
name='matmul'
),
helper.make_node(
'Softmax',
inputs=['matmul.output'],
outputs=['softmax.output'],
name='softmax'
),
helper.make_node(
'Sigmoid',
inputs=['softmax.output'],
outputs=['model_output'],
name='sigmoid'
),
]
)
)
onnx.checker.check_model(model, True)
return model
53 changes: 53 additions & 0 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -1344,6 +1344,59 @@ def test_lpbq_strict(self):
assert quantizer.quant_info.blockSize == 0
assert not quantizer.quant_info.isIntDataType

def test_encoding_constraints(self, tmp_path):
quantsim_config = {
"defaults":
{
"ops":
{
"is_output_quantized": "True"
},
"params":
{
"is_quantized": "True",
"is_symmetric": "True"
},
"per_channel_quantization": "False",
"strict_symmetric": "False",
"unsigned_symmetric": "False"
},
"params": {},
"op_type": {
"Softmax":
{
"encoding_constraints":
{
"min": 0.0,
"max": 1.0
}
},
"Sigmoid":
{
"encoding_constraints":
{
"min": 0.0,
"max": 2.0
}
},
},
"supergroups": [],
"model_input": {},
"model_output": {}
}
config_name = os.path.join(tmp_path, 'quantsim_config.json')
with open(config_name, 'w') as f:
json.dump(quantsim_config, f)
model = models_for_tests.softmax_model()
sim = QuantizationSimModel(model, config_file=config_name)
sim.compute_encodings(lambda sess, _: sess.run(None, make_dummy_input(model)), None)
assert sim.qc_quantize_op_dict['model_output'].encodings[0].max == 2.0
assert sim.qc_quantize_op_dict['model_output'].encodings[0].min == 0.0
assert sim.qc_quantize_op_dict['softmax.output'].encodings[0].max == 1.0
assert sim.qc_quantize_op_dict['softmax.output'].encodings[0].min == 0.0
assert sim.qc_quantize_op_dict['matmul.output'].encodings[0].max not in (1.0, 2.0)
assert sim.qc_quantize_op_dict['matmul.output'].encodings[0].min != 0.0


class TestEncodingPropagation:

