Added function to merge op type info
Signed-off-by: Chetan Gulecha <quic_cgulecha@quicinc.com>
quic-cgulecha committed Nov 23, 2023
1 parent d06d7e6 commit 762d718
Showing 3 changed files with 42 additions and 51 deletions.
File 1 of 3:
@@ -409,12 +409,12 @@ def populate_supported_kernels_in_json_config(master_opdef_file_path: str,
     with open(json_config_file_path) as file:
         quantsim_config = json.load(file)

-    for qnn_op_type in supported_kernels_dict:
-        supported_kernels = get_supported_kernels_from_backend_info(supported_kernels_dict[qnn_op_type])
-        if qnn_op_type not in quantsim_config[ConfigDictKeys.OP_TYPE]:
-            quantsim_config[ConfigDictKeys.OP_TYPE][qnn_op_type] = {}
+    for backend_op_type in supported_kernels_dict:
+        supported_kernels = get_supported_kernels_from_backend_info(supported_kernels_dict[backend_op_type])
+        if backend_op_type not in quantsim_config[ConfigDictKeys.OP_TYPE]:
+            quantsim_config[ConfigDictKeys.OP_TYPE][backend_op_type] = {}

-        quantsim_config[ConfigDictKeys.OP_TYPE][qnn_op_type][ConfigDictKeys.SUPPORTED_KERNELS] = supported_kernels
+        quantsim_config[ConfigDictKeys.OP_TYPE][backend_op_type][ConfigDictKeys.SUPPORTED_KERNELS] = supported_kernels

     with open(json_config_file_path, 'w') as file:
         json.dump(quantsim_config, file, indent=4)
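For orientation, a sketch of the JSON this loop writes, assuming ConfigDictKeys.OP_TYPE resolves to "op_type" and ConfigDictKeys.SUPPORTED_KERNELS to "supported_kernels"; the "Conv" entry and its kernel values below are hypothetical:

    # Hypothetical shape of quantsim_config after the loop above; each backend
    # op type under "op_type" gains a "supported_kernels" entry.
    quantsim_config = {
        "op_type": {
            "Conv": {
                "supported_kernels": [
                    {"activation": {"bitwidth": 8, "dtype": "int"},
                     "param": {"bitwidth": 8, "dtype": "int"}}
                ]
            }
        }
    }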
File 2 of 3:
@@ -36,6 +36,7 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 """ Utilities for parsing and applying quantsim configurations from json config file """
+import copy
 from abc import abstractmethod
 from typing import Dict, List, Tuple, Set, Union
 import torch
@@ -124,7 +125,6 @@ def __init__(self, model, connected_graph: ConnectedGraph, config_file: str, qua
         self._set_quantsim_configs()
         self._generate_and_apply_op_instance_specific_config()

-
     def _create_named_modules_to_tensor_quantizers_dict(self) -> Dict[torch.nn.Module, TensorQuantizersTupleType]:
         """
         For every named module in the graph, associate it with a tuple containing 4 lists:
@@ -370,42 +370,26 @@ def _set_param_configs(self, param_configs: ParamType):
             param_config = param_configs[quantsim_param_name]
             _set_config_for_param(quantsim_wrapper.param_quantizers[param_name], param_config)

-    def _merge_and_set_op_type_config_for_module(self, module: torch.nn.Module, op_configs: OpType, qnn_type: str,
-                                                 onnx_types: List[str], input_output_tensor_quantizers: TensorQuantizersTupleType,
-                                                 modified_tensor_quantizers: Dict[TensorQuantizer, Set]):
+    def _merge_op_types_info(self, op_configs: OpTypeType):
         """
-        Merges and set op type config for module
-        :param module: Module to set config of (will be None for elementwise ops)
-        :param op_config: Configuration for the op
-        :param qnn_type: qnn_type for module
-        :param onnx_types: onnx types for module
-        :param input_output_tensor_quantizers: Tuple of 4 lists containing the following:
-            - List of tensor quantizers to change if op's input quantizer setting is set to True
-            - List of tensor quantizers to change if op's output quantizer setting is set to True
-            - List of tensor quantizers to change if op's input quantizer setting is set to False
-            - List of tensor quantizers to change if op's output quantizer setting is set to False
-        :param modified_tensor_quantizers: Dictionary of tensor quantizers mapping to set of settings that have been
-            changed for that tensor quantizer already.
+        Merges op type info. from ONNX op type to backend op types
+        :param op_configs: Dictionary containing configurations for ops of certain types
         """
-        qnn_type_op_config = op_configs[qnn_type]
-        if not onnx_types:
-            logger.info(' Set op level config for op = {%s}', qnn_type)
-            self._set_config_for_module(input_output_tensor_quantizers, qnn_type_op_config,
-                                        modified_tensor_quantizers, module)
-        else:
-            for onnx_type in onnx_types:
-                if onnx_type in op_configs:
-                    onnx_op_config = op_configs[onnx_type]
-                    if onnx_type != qnn_type:
-                        qnn_type_op_config = merge_op_config(qnn_type_op_config, onnx_op_config)
-                        logger.info(' Set merged op level config for op = {%s}', qnn_type)
-                    self._set_config_for_module(input_output_tensor_quantizers, qnn_type_op_config,
-                                                modified_tensor_quantizers, module)
-                else:
-                    logger.info(' Set op level config for op = {%s}', qnn_type)
-                    self._set_config_for_module(input_output_tensor_quantizers, qnn_type_op_config,
-                                                modified_tensor_quantizers, module)
+        merged_backend_types = []
+
+        for module, _ in self._named_modules_to_tensor_quantizers_dict.items():
+
+            onnx_types = map_torch_types_to_onnx.get(type(module))
+            backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
+
+            if onnx_types and backend_type in op_configs and backend_type not in merged_backend_types:
+                backend_type_op_config = op_configs[backend_type]
+                for onnx_type in onnx_types:
+                    if onnx_type in op_configs and backend_type != onnx_type:
+                        onnx_op_config = op_configs[onnx_type]
+                        op_configs[backend_type] = merge_op_type_config(backend_type_op_config, onnx_op_config)
+                merged_backend_types.append(backend_type)

     def _set_op_type_configs(self, op_configs: OpTypeType):
         """
@@ -414,15 +398,17 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
         """
         modified_tensor_quantizers = {}

+        self._merge_op_types_info(op_configs)
+
         # Set op type configs for named modules
         for module, input_output_tensor_quantizers in self._named_modules_to_tensor_quantizers_dict.items():
             onnx_types = map_torch_types_to_onnx.get(type(module))

-            qnn_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
+            backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)

-            if qnn_type in op_configs:
-                self._merge_and_set_op_type_config_for_module(module, op_configs, qnn_type, onnx_types,
-                                                              input_output_tensor_quantizers, modified_tensor_quantizers)
+            if backend_type in op_configs:
+                self._set_config_for_module(input_output_tensor_quantizers, op_configs[backend_type],
+                                            modified_tensor_quantizers, module)
             else:
                 if not onnx_types:
                     continue
@@ -440,7 +426,6 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
                 logger.info(' Set op level config for elementwise op = {%s}', op.type)
                 self._set_config_for_module(input_output_tensor_quantizers, op_config, modified_tensor_quantizers)

-
     def _set_config_for_module(self, input_output_tensor_quantizers: TensorQuantizersTupleType, op_config: OpType,
                                modified_tensor_quantizers: Dict[TensorQuantizer, Set], module: torch.nn.Module = None):
         """
@@ -711,17 +696,22 @@ def _generate_and_apply_op_instance_specific_config(self):
             if per_channel_quantization:
                 wrapper.enable_per_channel_quantization()

-def merge_op_config(qnn_type_op_config: Dict, onnx_type_op_config: Dict) -> Dict:
+def merge_op_type_config(backend_type_op_config: Dict, onnx_type_op_config: Dict) -> Dict:
     """
-    Merges op_type_info from qnn_type_op_config to onnx_type_op_config
+    Merges op_type_info from onnx_type_op_config to backend_type_op_config
     except supported kernels
-    :param qnn_type_op_config: Op. config generated from backend
+    :param backend_type_op_config: Op. config generated from backend
     :param onnx_type_op_config: ONNX op type config
     :return: Merged op type config
     """
-    if ConfigDictKeys.SUPPORTED_KERNELS in qnn_type_op_config:
-        onnx_type_op_config[ConfigDictKeys.SUPPORTED_KERNELS] = qnn_type_op_config[ConfigDictKeys.SUPPORTED_KERNELS]
-    return onnx_type_op_config
+    backend_type_op_config_supported_kernels = backend_type_op_config.get(ConfigDictKeys.SUPPORTED_KERNELS)
+    backend_type_op_config = copy.deepcopy(onnx_type_op_config)
+
+    if backend_type_op_config_supported_kernels:
+        backend_type_op_config[ConfigDictKeys.SUPPORTED_KERNELS] = backend_type_op_config_supported_kernels
+
+    return backend_type_op_config

 def config_generator_factory(hw_version, supported_kernels, per_channel_quantization):
     """
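A self-contained, runnable restatement of the new helper, with a plain string standing in for ConfigDictKeys.SUPPORTED_KERNELS and made-up configs in the usage lines:

    import copy

    SUPPORTED_KERNELS = "supported_kernels"  # stand-in for ConfigDictKeys.SUPPORTED_KERNELS

    def merge_op_type_config(backend_type_op_config: dict, onnx_type_op_config: dict) -> dict:
        """Return a copy of the ONNX op type config, keeping the backend entry's supported kernels."""
        backend_supported_kernels = backend_type_op_config.get(SUPPORTED_KERNELS)
        merged = copy.deepcopy(onnx_type_op_config)  # copy, so the ONNX entry is never mutated
        if backend_supported_kernels:
            merged[SUPPORTED_KERNELS] = backend_supported_kernels
        return merged

    merged = merge_op_type_config(
        {SUPPORTED_KERNELS: ["int8"]},    # backend entry: kernels only (made up)
        {"is_output_quantized": "True"},  # ONNX entry: quantizer settings (made up)
    )
    assert merged == {"is_output_quantized": "True", SUPPORTED_KERNELS: ["int8"]}

Unlike the old merge_op_config, which wrote the kernels into the ONNX entry in place, the deepcopy keeps the ONNX entry untouched, so merging the same ONNX config into several backend types cannot leak one type's supported kernels into another's.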
File 3 of 3:
@@ -193,5 +193,6 @@
     'FullyConnected': 1,
     'LayerNorm': 1,
     'InstanceNorm': 1,
-    'GroupNorm': 1
+    'GroupNorm': 1,
+    'MatMul': 1
 }
