Make save_encodings_to_json compatible with both v1 and v2 (#2715)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored and quic-bharathr committed Sep 13, 2024
1 parent 26e476f commit 69e62ff
Showing 2 changed files with 34 additions and 82 deletions.
103 changes: 27 additions & 76 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -43,7 +43,7 @@
 import copy
 import pickle
 from typing import Tuple, List, Union, Dict, Callable, Optional, Any, runtime_checkable, Protocol
-from collections.abc import Iterable
+from collections import OrderedDict, defaultdict
 import json
 import torch
 import onnx
@@ -59,8 +59,7 @@
 from aimet_torch.quantsim_config.quantsim_config import QuantSimConfigurator
 from aimet_torch.qc_quantize_op import QcQuantizeStandAloneBase, QcQuantizeWrapper, QcQuantizeOpMode, \
     StaticGridQuantWrapper, LearnedGridQuantWrapper, NativeTorchQuantWrapper, QUANTIZER_TYPE_INPUT, QUANTIZER_TYPE_OUTPUT
-from aimet_torch.tensor_quantizer import StaticGridTensorQuantizer, LearnedGridTensorQuantizer, \
-    initialize_learned_grid_quantizer_attributes
+from aimet_torch.tensor_quantizer import initialize_learned_grid_quantizer_attributes
 from aimet_torch.qc_quantize_op import get_encoding_by_quantizer as _get_encoding_by_quantizer
 from aimet_torch import torchscript_utils, utils, transformer_utils, onnx_utils
 from aimet_torch.onnx_utils import OnnxSaver, OnnxExportApiArgs, CustomMarker, get_pytorch_name_from_onnx_name
@@ -616,86 +615,38 @@ def get_activation_param_encodings(self):
         :return: Tuple of activation and param encodings dictionaries mapping torch module names to encodings
         """
-        activation_encodings = {}
-        param_encodings = {}
-        for layer_name, layer in QuantizationSimModel._get_qc_quantized_layers(self.model):
-            if isinstance(layer.input_quantizers, list):
-                for index, quantizer in enumerate(layer.input_quantizers):
-                    self._save_activation_encoding(layer_name, index, quantizer, QUANTIZER_TYPE_INPUT,
-                                                   activation_encodings)
-            else:
-                for quantizer_name, quantizer in layer.input_quantizers.items():
-                    self._save_activation_encoding(layer_name, quantizer_name, quantizer, QUANTIZER_TYPE_INPUT,
-                                                   activation_encodings)
+        activation_encodings = OrderedDict()
+        param_encodings = OrderedDict()
 
-            if isinstance(layer.output_quantizers, list):
-                for index, quantizer in enumerate(layer.output_quantizers):
-                    self._save_activation_encoding(layer_name, index, quantizer, QUANTIZER_TYPE_OUTPUT,
-                                                   activation_encodings)
-            else:
-                for quantizer_name, quantizer in layer.output_quantizers.items():
-                    self._save_activation_encoding(layer_name, quantizer_name, quantizer, QUANTIZER_TYPE_OUTPUT,
-                                                   activation_encodings)
+        for module_name, module in self.model.named_modules():
+            if not isinstance(module, ExportableQuantModule):
+                continue
 
-            for orig_param_name, param_quantizer in layer.param_quantizers.items():
-                self._save_param_encoding(layer_name, orig_param_name, param_quantizer, param_encodings)
-        return activation_encodings, param_encodings
+            activation_encodings[module_name] = defaultdict(OrderedDict)
 
-    @staticmethod
-    def _save_activation_encoding(layer_name: str, quantizer_identifier: Union[str, int],
-                                  quantizer: Union[StaticGridTensorQuantizer, LearnedGridTensorQuantizer],
-                                  quantizer_type: str, activation_encodings: Dict):
-        """
-        Save activation encoding into a dictionary.
+            for i, encoding in enumerate(module.export_input_encodings()):
+                if not encoding:
+                    continue
+                if len(encoding) == 1:
+                    encoding = encoding[0]
+                activation_encodings[module_name]['input'][i] = encoding
 
-        :param layer_name: Name of layer
-        :param quantizer_identifier: Identifier for the quantizer. Typically either an index or quantizer name
-        :param quantizer: Quantizer to save encoding for
-        :param quantizer_type: 'input' or 'output' depending on position of the quantizer
-        :param activation_encodings: Dictionary to save activation encoding to
-        """
-        if not quantizer.enabled:
-            return
+            for i, encoding in enumerate(module.export_output_encodings()):
+                if not encoding:
+                    continue
+                if len(encoding) == 1:
+                    encoding = encoding[0]
+                activation_encodings[module_name]['output'][i] = encoding
 
-        quantizer_encoding = _get_encoding_by_quantizer(quantizer)
-        encoding = QuantizationSimModel._create_encoding_dict(quantizer_encoding,
-                                                              quantizer,
-                                                              propagate_encodings=False)
-        if layer_name not in activation_encodings:
-            activation_encodings[layer_name] = {}
-        if quantizer_type not in activation_encodings[layer_name]:
-            activation_encodings[layer_name][quantizer_type] = {}
-        activation_encodings[layer_name][quantizer_type][quantizer_identifier] = encoding
+            if not activation_encodings[module_name]:
+                del activation_encodings[module_name]
 
-    @staticmethod
-    def _save_param_encoding(layer_name: str, param_name: str,
-                             param_quantizer: Union[StaticGridTensorQuantizer, LearnedGridTensorQuantizer],
-                             param_encodings: Dict):
-        """
-        Save param encoding into a dictionary.
+            for param_name, encoding in module.export_param_encodings().items():
+                if not encoding:
+                    continue
+                param_encodings[f'{module_name}.{param_name}'] = encoding
 
-        :param layer_name: Name of layer
-        :param param_name: Name of the parameter
-        :param param_quantizer: Quantizer to save encoding for
-        :param param_encodings: Dictionary to save param encoding to
-        """
-        if not param_quantizer.enabled:
-            return
-
-        param_name = layer_name + '.' + param_name
-        if isinstance(param_quantizer.encoding, Iterable):
-            param_encodings[param_name] = []
-            quantizer_encoding = _get_encoding_by_quantizer(param_quantizer)
-            for encoding in quantizer_encoding:
-                enc_dict = QuantizationSimModel._create_encoding_dict(encoding,
-                                                                      param_quantizer,
-                                                                      propagate_encodings=False)
-                param_encodings[param_name].append(enc_dict)
-        else:
-            quantizer_encoding = _get_encoding_by_quantizer(param_quantizer)
-            enc_dict = QuantizationSimModel._create_encoding_dict(quantizer_encoding, param_quantizer,
-                                                                  propagate_encodings=False)
-            param_encodings[param_name] = [enc_dict]
+        return activation_encodings, param_encodings
 
     def exclude_layers_from_quantization(self, layers_to_exclude: List[torch.nn.Module]):
         """
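The rewritten method above no longer branches on v1 list-style versus v2 dict-style quantizer containers; it relies only on modules exposing the export interface selected by the isinstance(module, ExportableQuantModule) check. Below is a minimal sketch of what that interface appears to look like, inferred from the three calls made in the new loop; the real ExportableQuantModule lives in aimet_torch and may differ in detail.

from typing import Dict, List, Protocol, runtime_checkable

@runtime_checkable
class _ExportableQuantModuleSketch(Protocol):
    """Hypothetical stand-in for aimet_torch's ExportableQuantModule,
    inferred from the calls made in get_activation_param_encodings()."""

    def export_input_encodings(self) -> List[List[Dict]]:
        """One (possibly empty) list of encoding dicts per input tensor."""

    def export_output_encodings(self) -> List[List[Dict]]:
        """One (possibly empty) list of encoding dicts per output tensor."""

    def export_param_encodings(self) -> Dict[str, List[Dict]]:
        """Lists of encoding dicts keyed by parameter name, e.g. 'weight'."""

Any module exposing this shape, whether a v1 wrapper or a v2 quantized module, feeds the same aggregation code: activation_encodings ends up keyed by module name with 'input'/'output' sub-dictionaries indexed by tensor position, and param_encodings is keyed by '<module_name>.<param_name>'.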
Second changed file (7 additions and 6 deletions): quantization simulator unit tests
@@ -2489,17 +2489,18 @@ def test_export_to_onnx_for_multiple_p_relu_model(self, num_parameters, config_f
         assert len(param_encodings[param_name]) == 1
         onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = False
 
-    @pytest.mark.skip("save_encodings_to_json not supported yet")
     def test_save_encodings_to_json(self):
         model = ModelWithTwoInputsOneToAdd()
         dummy_input = (torch.rand(32, 1, 100, 100), torch.rand(32, 10, 22, 22))
         qsim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf)
         qsim.compute_encodings(lambda m, _: m(*dummy_input), None)
-        qsim.save_encodings_to_json('./data', 'saved_encodings')
-        with open('./data/saved_encodings.json') as encodings_file:
-            encodings = json.load(encodings_file)
-        assert len(encodings['activation_encodings']) == 14
-        assert len(encodings['param_encodings']) == 5
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            qsim.save_encodings_to_json(tmp_dir, 'saved_encodings')
+            with open(f'{tmp_dir}/saved_encodings.json') as encodings_file:
+                encodings = json.load(encodings_file)
+                assert len(encodings['activation_encodings']) == 14
+                assert len(encodings['param_encodings']) == 5
 
     @pytest.mark.skip('compute_encodings_for_sims not supported yet')
     def test_compute_encodings_for_multiple_sims(self):
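The un-skipped test now writes into a temporary directory and asserts only on the two top-level keys of the saved file. A rough sketch of how save_encodings_to_json plausibly assembles that file from the tuple returned by get_activation_param_encodings; the helper name and serialization details below are assumptions for illustration, not the library's actual implementation.

import json
import os

def save_encodings_to_json_sketch(sim, path: str, filename_prefix: str) -> None:
    # Per-module activation encodings and per-parameter encodings, in the
    # structure built by get_activation_param_encodings() in the diff above.
    activation_encodings, param_encodings = sim.get_activation_param_encodings()

    # The test asserts on exactly these two top-level keys.
    encodings = {
        'activation_encodings': activation_encodings,
        'param_encodings': param_encodings,
    }

    with open(os.path.join(path, filename_prefix + '.json'), 'w') as f:
        json.dump(encodings, f, indent=4)

Called the way the test calls the real API, with a QuantizationSimModel, a temporary directory, and the prefix 'saved_encodings', this would produce saved_encodings.json containing the same two sections the assertions inspect.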
