Add tie observers utility in aimet onnx (quic#3387)
* Add utility to tie quantizers in aimet_onnx

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>

* Incorporate review feedback

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>

* Expose temp API to apply constraints

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>

---------

Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>
quic-hitameht authored and quic-twilkens committed Oct 15, 2024
1 parent 6fef47f commit 6ffe10e
Showing 3 changed files with 337 additions and 32 deletions.
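
For orientation, here is a minimal usage sketch of the temporary constraint API this commit exposes. It is a sketch only: the model file, dummy input, and constructor arguments are placeholder assumptions not shown in this diff, and `_apply_constraints` is a private, temporary entry point per the commit message.

# Hedged sketch: model path, dummy input, and constructor arguments are assumptions,
# not part of this diff; _apply_constraints is the temporary private API added below.
import numpy as np
import onnx

from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel, _apply_constraints

model = onnx.load('model.onnx')  # hypothetical ONNX model
dummy_input = {'input': np.random.randn(1, 3, 224, 224).astype(np.float32)}

# While the context manager is active, QuantizationSimModel ties the input and output
# quantizers of Concat/MaxPool/AveragePool/Resize ops; the flag is restored on exit.
with _apply_constraints(True):
    sim = QuantizationSimModel(model, dummy_input=dummy_input,
                               quant_scheme=QuantScheme.post_training_tf)
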
130 changes: 123 additions & 7 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -36,6 +36,7 @@
# =============================================================================
""" Implementation for simulating models running on Quantized hardware """

import contextlib
import tempfile
from dataclasses import dataclass
from pathlib import Path
@@ -58,6 +59,7 @@
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.quantsim import extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.utils import save_json_yaml, AimetLogger
from aimet_common.connected_graph.product import Product
from aimet_onnx import utils
from aimet_onnx.meta.operations import Op
from aimet_onnx.meta.utils import get_op_given_param_name, get_param_shape_using_connected_graph
@@ -83,8 +85,31 @@

allowed_op_type_for_per_channel = ['Conv', 'Gemm', 'MatMul', 'ConvTranspose']

# List of op types whose input and output quantizers are to be tied
op_types_to_tie_qtzrs = ['Concat', 'MaxPool', 'AveragePool', 'Resize']
_tie_qtzrs = False

data_types_to_quantize = [np.float32]


@contextlib.contextmanager
def _apply_constraints(flag: bool):
"""
Apply runtime specific constraints.
    For the op types in ``op_types_to_tie_qtzrs``, the runtime requires the input and output
    quantizers to share the same encodings.
    NOTE: These constraints are not applied by default.
"""
global _tie_qtzrs # pylint: disable=global-statement
orig_flag = _tie_qtzrs
try:
_tie_qtzrs = flag
yield
finally:
_tie_qtzrs = orig_flag


@dataclass
class EncodingMismatchInfo:
"""
@@ -185,25 +210,27 @@ def __init__(self,
self._path = path if path else tempfile.mkdtemp()
if not os.path.exists(self._path):
os.makedirs(self._path, exist_ok=True)

# Get names of parameters and activations to quantize
self._get_param_names()
self._get_activations_to_quantize(dummy_input)

# Disable bias quantization
self._disable_bias_quantization()

self._add_quantization_nodes()

self.session = QuantizationSimModel.build_session(self.model.model, self.providers,
user_onnx_libs=self._user_onnx_libs, path=self._path)
# Apply configurations based on provided config file.
quantsim_configurator = self._add_configuration_(config_file)

self._hw_version = quantsim_configurator._get_hw_version()
self._supported_kernels = quantsim_configurator.get_supported_kernels()
self._op_to_supported_kernel = quantsim_configurator.get_op_to_supported_kernels()

self.quant_args = extract_global_quantizer_args(quant_scheme, quantsim_configurator)

self._apply_exception_rules()
self._tie_quantizers()

# Build onnxruntime inference session
self.session = QuantizationSimModel.build_session(self.model.model, self.providers,
user_onnx_libs=self._user_onnx_libs, path=self._path)

def get_supported_kernels(self) -> Dict:
"""
@@ -548,7 +575,7 @@ def get_op_quantizers(self, op: Op) -> (List, List, Dict):
if param_name in self.qc_quantize_op_dict:
param_quantizers[param_type] = self.qc_quantize_op_dict[param_name]

return (input_quantizers, output_quantizers, param_quantizers)
return input_quantizers, output_quantizers, param_quantizers

def _apply_exception_rules(self):
"""
@@ -792,6 +819,95 @@ def get_all_quantizers(self) -> Tuple[List, List]:

return param_quantizers, activation_quantizers

def _tie_quantizers(self):
"""
        Tie the input and output quantizers for the given op types.
"""
if not _tie_qtzrs:
return

cg = self.connected_graph

def _set_quant_info(dst_qtzr_node_name: str, src_qtzr: QcQuantizeOp):
"""
            Set the quant_info attribute of the destination quantizer node to point to the source quantizer's libquant_info object.
:param dst_qtzr_node_name: destination quantizer node name in graph.
:param src_qtzr: source quantizer.
"""
for node in self.model.graph().node:
if node.op_type == 'QcQuantizeOp' and node.name == dst_qtzr_node_name:
for atr in node.attribute:
if atr.name == "quant_info":
atr.i = libpymo.PtrToInt64(src_qtzr.quant_info)
return

def _set_qtzr(dst_qtzr: QcQuantizeOp, src_qtzr: QcQuantizeOp):
"""
            Replace the dst quantizer with the src quantizer and update the quant_info attribute (pointer to the
            libquant_info object) in the graph node.
:param dst_qtzr: destination quantizer.
:param src_qtzr: source quantizer
"""
for name, qtzr in self.qc_quantize_op_dict.items():
if dst_qtzr == qtzr:
self.qc_quantize_op_dict[name] = src_qtzr
dst_qtzr_node_name = 'QcQuantizeOp_' + name
# update quant_info attribute (pointer to the libquant_info object) in the graph node.
_set_quant_info(dst_qtzr_node_name, src_qtzr)
return

def _set_src_qtzr(x: Product, consumer: Op, src_qtzr):
producer = x.producer

if not producer:
# ``x`` is a root input (i.e. has no producer).
# In this case, set the input quantizer of the consumer to ``src_qtzr``
i = consumer.inputs.index(x)
inp_qtzr, _, __ = self.get_op_quantizers(consumer)
if i >= len(inp_qtzr):
return

_set_qtzr(dst_qtzr=inp_qtzr[i], src_qtzr=src_qtzr)
return

_, out_qtzr, __ = self.get_op_quantizers(producer)

if out_qtzr:
                # There exists an output quantizer associated with the graph node ``producer``.
                # In this case, set the output quantizer of the producer to ``src_qtzr``
outputs = [producer.output]
i = outputs.index(x)
_set_qtzr(dst_qtzr=out_qtzr[i], src_qtzr=src_qtzr)

if not out_qtzr or producer.type in op_outputs_to_ignore:
# 1. There is no output quantizer associated with the graph node ``producer``, or
# 2. op is a math invariant op (reshape, permute, etc.).
# In these cases, propagate encoding further to the ancestors
for inp in producer.inputs:
_set_src_qtzr(inp, consumer=producer, src_qtzr=src_qtzr)

for op in reversed(cg.ordered_ops):
if op.type not in op_types_to_tie_qtzrs:
continue

_, out_qtzr, __ = self.get_op_quantizers(op)

if not out_qtzr:
msg = 'Encoding propagation is only supported for ops with exactly ' \
'1 output quantizer, but found output_quantizers[0] == []'
raise RuntimeError(msg)

if len(out_qtzr) != 1:
msg = 'Encoding propagation is only supported for ops with exactly ' \
f'1 output quantizer, but found {len(out_qtzr)} ' \
'output quantizers'
raise RuntimeError(msg)

for inp in op.inputs:
_set_src_qtzr(inp, consumer=op, src_qtzr=out_qtzr[0])


def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str, strict=True) -> \
List[EncodingMismatchInfo]:
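
To make the tying logic above concrete, here is a small self-contained illustration (plain Python, not aimet code) of the mechanism used by `_set_qtzr`: two entries of the quantizer dictionary are made to reference the same quantizer object, so encodings computed for one are automatically shared by the other.

# Standalone illustration; Quantizer is a stand-in for QcQuantizeOp, not the real class.
class Quantizer:
    def __init__(self):
        self.encoding = None

qc_quantize_op_dict = {'concat.input': Quantizer(), 'concat.output': Quantizer()}

# Tie the input quantizer to the output quantizer: both keys now point to one object,
# mirroring `self.qc_quantize_op_dict[name] = src_qtzr` in _set_qtzr above.
src_qtzr = qc_quantize_op_dict['concat.output']
qc_quantize_op_dict['concat.input'] = src_qtzr

src_qtzr.encoding = (0.0, 1.0)  # compute encodings once on the shared object...
assert qc_quantize_op_dict['concat.input'].encoding == (0.0, 1.0)  # ...both entries see it
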
48 changes: 26 additions & 22 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -1634,31 +1634,35 @@ def forward(self, inputs):
return x, y


def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
torch.onnx.export(model.eval(),
dummy_input,
filename,
training=torch.onnx.TrainingMode.PRESERVE,
export_params=True,
opset_version=12,
do_constant_folding=False,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(filename))
def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='temp_model.onnx'):
with tempfile.TemporaryDirectory() as tmp_dir:
save_path = os.path.join(tmp_dir, filename)
torch.onnx.export(model.eval(),
dummy_input,
save_path,
training=torch.onnx.TrainingMode.PRESERVE,
export_params=True,
opset_version=12,
do_constant_folding=False,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(save_path))
return model


def _convert_to_onnx(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
torch.onnx.export(model.eval(),
dummy_input,
filename,
training=torch.onnx.TrainingMode.EVAL,
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(filename))
def _convert_to_onnx(model: torch.nn.Module, dummy_input, filename='temp_model.onnx'):
with tempfile.TemporaryDirectory() as tmp_dir:
save_path = os.path.join(tmp_dir, filename)
torch.onnx.export(model.eval(),
dummy_input,
save_path,
training=torch.onnx.TrainingMode.EVAL,
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
model = ONNXModel(load_model(save_path))
return model
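
As a usage note for the updated test helpers: the exported .onnx file now lives in a TemporaryDirectory and is cleaned up automatically, so callers only receive the in-memory ONNXModel. A hedged sketch follows, where SomeTorchModel and the input shape are placeholders rather than names from this diff.

# Sketch only: SomeTorchModel and the input shape are placeholders.
import torch

torch_model = SomeTorchModel()
dummy_input = torch.randn(1, 3, 32, 32)

# No temp_model.onnx is left in the working directory; the file is written to a
# temporary directory, loaded, and discarded when the helper returns.
onnx_model = _convert_to_onnx(torch_model, dummy_input)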

