Commit

Populate supported kernels in config file and changes for backend aware quantization

quic-cgulecha authored Nov 30, 2023
1 parent 35e5882 commit 8960a64
Showing 10 changed files with 299 additions and 787 deletions.
@@ -35,19 +35,19 @@

""" Utilities for backend aware quantization """


# pylint: disable=import-error, no-name-in-module
import copy
import logging
from typing import List, Dict
from dataclasses import dataclass

import json
import torch
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch import onnx_utils
from aimet_torch.translation_mapping import op_to_weight_index_map, backend_datatype_to_aimet_map,\
aimet_op_to_backend_op_name_map
from aimet_common.defs import QuantizationDataType
from aimet_common.quantsim_config.json_config_importer import ConfigDictKeys
from aimet_common.libpymo import ModelOpDefParser


@@ -96,7 +96,51 @@ def merge_constraints_from_xmls(op_name: str, supported_backend_info: SupportedB
else:
merged_op_and_supported_backend_info_map[op_name] = supported_backend_info

# pylint: disable=too-many-locals
def get_activation_constraints(parser: ModelOpDefParser, op_name_in_opdef: str) -> List[Dict]:
"""
Returns activation constraints for the op using parser
:param parser: ModelOpDefParser object
:param op_name_in_opdef: Op name whose activation constraints need to be fetched
:return: List of activation constraints for the op in {bitwidth, dtype} Dict form
"""
datatype_constraints_size = parser.get_size(op_name_in_opdef)

output_datatype_constraints_size = datatype_constraints_size['output_size']

activation_constraints = []
for output_index in range(output_datatype_constraints_size):
try:
datatype_constraints = parser.get_output_datatype(op_name_in_opdef, output_index)
for datatype in datatype_constraints:
if backend_datatype_to_aimet_map[datatype] not in activation_constraints:
activation_constraints.append(backend_datatype_to_aimet_map[datatype])
# pylint: disable=bare-except
except:
# Parser API will throw an appropriate error message if it is not able to get output datatypes
pass

return activation_constraints

def get_weight_constraints(parser: ModelOpDefParser, op_name_in_opdef: str) -> List[Dict]:
"""
Returns weight constraints for the op using parser
:param parser: ModelOpDefParser object
:param op_name_in_opdef: Op name whose weight constraints need to be fetched
:return: List of weight constraints for the op in {bitwidth, dtype} Dict form
"""
weight_constraints = []
if op_name_in_opdef in op_to_weight_index_map.keys():
try:
datatype_constraints = parser.get_input_datatype(op_name_in_opdef, op_to_weight_index_map[op_name_in_opdef])
for datatype in datatype_constraints:
if backend_datatype_to_aimet_map[datatype] not in weight_constraints:
weight_constraints.append(backend_datatype_to_aimet_map[datatype])
# pylint: disable=bare-except
except:
# Parser API will throw an appropriate error message if it is not able to get input datatypes
pass
return weight_constraints
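
For illustration, a minimal sketch of how these helpers might be called; the opdef file paths and op name below are placeholders, and the exact dict keys come from backend_datatype_to_aimet_map:

parser = ModelOpDefParser("/path/to/master_opdef.xml", "/path/to/backend_opdef.xml", ["Conv2d"])
act_constraints = get_activation_constraints(parser, "Conv2d")
weight_constraints = get_weight_constraints(parser, "Conv2d")
# Each entry is expected to be an AIMET-style dict, e.g. {'bitwidth': 8, 'dtype': QuantizationDataType.int}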

def get_backend_info(op_names: List[str], master_opdef_path: str, backend_opdef_paths: List[str]) -> Dict[str, SupportedBackendInfo]:
"""
Returns backend constraints
@@ -113,7 +157,6 @@ def get_backend_info(op_names: List[str], master_opdef_path: str, backend_opdef_
if op_name in aimet_op_to_backend_op_name_map.keys():
op_names_according_to_backend[i] = aimet_op_to_backend_op_name_map[op_name]

# pylint: disable=too-many-nested-blocks
for backend_opdef_path in backend_opdef_paths:
parser = ModelOpDefParser(master_opdef_path, backend_opdef_path, op_names_according_to_backend)

@@ -122,38 +165,15 @@ def get_backend_info(op_names: List[str], master_opdef_path: str, backend_opdef_
for i, op_name in enumerate(op_names):
op_name_in_opdef = op_names_according_to_backend[i]

if op_name not in op_and_supported_backend_info_map:
datatype_constraints_size = parser.get_size(op_name_in_opdef)

output_datatype_constraints_size = datatype_constraints_size['output_size']

activation_constraints = []
for output_index in range(output_datatype_constraints_size):
try:
datatype_constraints = parser.get_output_datatype(op_name_in_opdef, output_index)
for datatype in datatype_constraints:
if backend_datatype_to_aimet_map[datatype] not in activation_constraints:
activation_constraints.append(backend_datatype_to_aimet_map[datatype])
# pylint: disable=bare-except
except:
# Parser API will throw an appropriate error message if it is not able to get output datatypes
pass

weight_constraints = []
if op_name_in_opdef in op_to_weight_index_map.keys():
try:
datatype_constraints = parser.get_input_datatype(op_name_in_opdef, op_to_weight_index_map[op_name_in_opdef])
for datatype in datatype_constraints:
weight_constraints.append(backend_datatype_to_aimet_map[datatype])
# pylint: disable=bare-except
except:
#Parser API will throw appropriate error message if not able to get input datatypes
pass
if op_name_in_opdef not in op_and_supported_backend_info_map:

activation_constraints = get_activation_constraints(parser, op_name_in_opdef)
weight_constraints = get_weight_constraints(parser, op_name_in_opdef)

supported_backend_info = SupportedBackendInfo(activation_constraints, weight_constraints)
op_and_supported_backend_info_map[op_name] = supported_backend_info
op_and_supported_backend_info_map[op_name_in_opdef] = supported_backend_info

merge_constraints_from_xmls(op_name, supported_backend_info, merged_opname_supported_backend_info_map)
merge_constraints_from_xmls(op_name_in_opdef, supported_backend_info, merged_opname_supported_backend_info_map)

return merged_opname_supported_backend_info_map
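
A minimal usage sketch of get_backend_info (the opdef paths are placeholders); the returned map is keyed by backend op name, with constraints merged across the supplied backend opdefs:

backend_info = get_backend_info(list(aimet_op_to_backend_op_name_map.keys()),
                                "/path/to/master_opdef.xml",
                                ["/path/to/backend_opdef.xml"])
for backend_op_name, info in backend_info.items():
    print(backend_op_name, info.activation_constraints, info.weights_constraints)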

@@ -252,6 +272,22 @@ def set_datatype_bitwidth_for_activations(module: torch.nn.Module, backend_act_c
logger.info("Setting datatype and bitwidth of %s input activations to %s and %s according to backend constraints.", module_type,
str(dtype_to_set_for_activation), str(bitwidth_to_set_for_activation))

def set_supported_kernel_for_op(module: torch.nn.Module, op_to_supported_kernels: Dict, supported_kernels_for_op: List[Dict]):
"""
Sets supported kernels for the op
:param module: QcQuantizeWrapper module whose ONNX op types will be populated
:param op_to_supported_kernels: Dict of op to its supported kernels
:param supported_kernels_for_op: Supported kernels for the op
"""

# pylint: disable=protected-access
if type(module._module_to_wrap) in onnx_utils.map_torch_types_to_onnx.keys(): # pylint: disable=unidiomatic-typecheck
onnx_types = onnx_utils.map_torch_types_to_onnx.get(type(module._module_to_wrap))
for op in onnx_types:
if op not in op_to_supported_kernels.keys():
op_to_supported_kernels[op] = supported_kernels_for_op
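
A sketch of the intended effect, assuming wrapped_conv is a QcQuantizeWrapper around a torch.nn.Conv2d and supported_kernels_for_conv was produced by set_and_return_supported_kernels:

op_to_supported_kernels = {}
set_supported_kernel_for_op(wrapped_conv, op_to_supported_kernels, supported_kernels_for_conv)
# Every ONNX type that onnx_utils.map_torch_types_to_onnx lists for Conv2d (e.g. 'Conv')
# now maps to supported_kernels_for_conv, unless an entry for that ONNX type already existed.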

# pylint: disable=too-many-locals
def populate_backend_info(model: torch.nn.Module, module_types: List[str], master_opdef_file_path: str,
backend_opdef_file_paths: List[str], quantsim_info: QuantsimInfo) -> Dict[str, List]:
"""
@@ -272,7 +308,6 @@ def populate_backend_info(model: torch.nn.Module, module_types: List[str], maste
op_to_supported_kernels = {}
op_to_supported_kernels['defaults'] = [get_supported_kernel_in_dict_format(default_act_kernel[0], default_weight_kernel[0])]

# pylint:disable=too-many-nested-blocks
for module in model.modules():
if isinstance(module, QcQuantizeWrapper):
# pylint: disable=protected-access
@@ -295,12 +330,7 @@ def populate_backend_info(model: torch.nn.Module, module_types: List[str], maste
#set module's supported kernels
if is_weight_constraint_present or is_act_constraint_present:
supported_kernels_for_op = set_and_return_supported_kernels(module, backend_act_constraints, backend_weight_constraints, module_type)
# pylint: disable=protected-access
if type(module._module_to_wrap) in onnx_utils.map_torch_types_to_onnx.keys(): # pylint: disable=unidiomatic-typecheck
onnx_types = onnx_utils.map_torch_types_to_onnx.get(type(module._module_to_wrap))
for op in onnx_types:
if op not in op_to_supported_kernels.keys():
op_to_supported_kernels[op] = supported_kernels_for_op
set_supported_kernel_for_op(module, op_to_supported_kernels, supported_kernels_for_op)

#set bitwidth and dtype of module's weights according to supported_kernel
if 'weight' in module.param_quantizers and is_weight_constraint_present:
@@ -311,3 +341,80 @@
set_datatype_bitwidth_for_activations(module, backend_act_constraints, module_type)

return op_to_supported_kernels
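
A hedged usage sketch of populate_backend_info; the module types, file paths and quantsim_info object are placeholders (the QuantsimInfo constructor is not shown in this diff), and model is expected to already contain QcQuantizeWrapper modules, e.g. a QuantizationSimModel's model:

op_to_supported_kernels = populate_backend_info(sim.model,
                                                ['Conv2d', 'Relu'],
                                                "/path/to/master_opdef.xml",
                                                ["/path/to/backend_opdef.xml"],
                                                quantsim_info)
# Maps ONNX op types (plus the 'defaults' key) to their supported kernels.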

def get_constraint_accrording_to_json_config(constraint: Dict) -> Dict:
"""
Returns the supported kernel constraint in the format used by the JSON config file
:param constraint: Activation or Param constraint
:return: Dict form of the supported kernel constraint
"""
constraint_according_json = {}
constraint_according_json[ConfigDictKeys.BITWIDTH] = constraint[ConfigDictKeys.BITWIDTH]
if constraint[ConfigDictKeys.DTYPE] == QuantizationDataType.int:
constraint_according_json[ConfigDictKeys.DTYPE] = "int"
elif constraint[ConfigDictKeys.DTYPE] == QuantizationDataType.float:
constraint_according_json[ConfigDictKeys.DTYPE] = "float"
return constraint_according_json
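
For example, a sketch in which the literal key strings assume ConfigDictKeys.BITWIDTH == "bitwidth" and ConfigDictKeys.DTYPE == "dtype":

constraint = {ConfigDictKeys.BITWIDTH: 8, ConfigDictKeys.DTYPE: QuantizationDataType.int}
json_constraint = get_constraint_accrording_to_json_config(constraint)
# json_constraint -> {"bitwidth": 8, "dtype": "int"}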

def get_supported_kernels_from_backend_info(supported_backend_info: SupportedBackendInfo) -> List[Dict]:
"""
Returns supported kernels for the JSON config file from backend constraints
:param supported_backend_info: Object which stores backend activation and weight constraints
:return: List of supported kernels
"""
supported_kernels = []
for activation_constraint in supported_backend_info.activation_constraints:
if supported_backend_info.weights_constraints:
for weight_constraint in supported_backend_info.weights_constraints:

if weight_constraint[ConfigDictKeys.DTYPE] == activation_constraint[ConfigDictKeys.DTYPE]:
json_act_constraint = get_constraint_accrording_to_json_config(activation_constraint)
json_param_constraint = get_constraint_accrording_to_json_config(weight_constraint)

supported_kernel = {ConfigDictKeys.ACTIVATION: json_act_constraint,
ConfigDictKeys.PARAM: json_param_constraint}

if supported_kernel not in supported_kernels:
supported_kernels.append(supported_kernel)

else:
json_act_constraint = get_constraint_accrording_to_json_config(activation_constraint)
supported_kernel = {ConfigDictKeys.ACTIVATION: json_act_constraint}

if supported_kernel not in supported_kernels:
supported_kernels.append(supported_kernel)

return supported_kernels
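
A sketch of the resulting list, assuming backend_info is the map returned by get_backend_info and "Conv2d" is one of its keys: an activation constraint pairs with every weight constraint sharing its dtype, and becomes an activation-only kernel when the op has no weight constraints:

supported_kernels = get_supported_kernels_from_backend_info(backend_info["Conv2d"])
# Hypothetical result (key strings assume the ConfigDictKeys values "activation"/"param"):
# [{"activation": {"bitwidth": 8,  "dtype": "int"},   "param": {"bitwidth": 8,  "dtype": "int"}},
#  {"activation": {"bitwidth": 16, "dtype": "float"}, "param": {"bitwidth": 16, "dtype": "float"}}]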

def populate_supported_kernels_in_json_config(master_opdef_file_path: str,
backend_opdef_file_paths: List[str],
json_config_file_path: str):
"""
Populates supported kernels on a per-op basis in the JSON config file
:param master_opdef_file_path: Master opdef file path
:param backend_opdef_file_paths: Backend opdef file paths
:param json_config_file_path: Config file in which supported kernels will be populated
"""
supported_kernels_dict = get_backend_info(list(aimet_op_to_backend_op_name_map.keys()), master_opdef_file_path, backend_opdef_file_paths)

op_types_with_no_constraints = [key for key in supported_kernels_dict
if not supported_kernels_dict[key].activation_constraints and
not supported_kernels_dict[key].weights_constraints]

for op_type in op_types_with_no_constraints:
del supported_kernels_dict[op_type]

with open(json_config_file_path) as file:
quantsim_config = json.load(file)

for backend_op_type in supported_kernels_dict:
supported_kernels = get_supported_kernels_from_backend_info(supported_kernels_dict[backend_op_type])
if backend_op_type not in quantsim_config[ConfigDictKeys.OP_TYPE]:
quantsim_config[ConfigDictKeys.OP_TYPE][backend_op_type] = {}

quantsim_config[ConfigDictKeys.OP_TYPE][backend_op_type][ConfigDictKeys.SUPPORTED_KERNELS] = supported_kernels

with open(json_config_file_path, 'w') as file:
json.dump(quantsim_config, file, indent=4)
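
A usage sketch (all paths are placeholders); after the call, the op-type section of the JSON config gains per-op supported kernels, roughly of the shape shown in the comment (the literal "op_type"/"supported_kernels" keys and the "Conv" op name are assumptions):

populate_supported_kernels_in_json_config("/path/to/master_opdef.xml",
                                          ["/path/to/backend_opdef.xml"],
                                          "/path/to/quantsim_config.json")
# e.g. quantsim_config["op_type"]["Conv"]["supported_kernels"] ==
#     [{"activation": {"bitwidth": 8, "dtype": "int"}, "param": {"bitwidth": 8, "dtype": "int"}}]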
11 changes: 6 additions & 5 deletions TrainingExtensions/common/src/python/aimet_common/defs.py
@@ -361,7 +361,7 @@ class QuantDtypeBwInfo:
QuantDtypeBwInfo holds activation dtype/bw and param dtype/bw
"""

def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: QuantizationDataType, param_bw: int):
def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: QuantizationDataType = None, param_bw: int = None):
"""
Data class to hold dtype and bw info
:param act_dtype: Activation datatype of type QuantizationDataType
@@ -387,11 +387,12 @@ def _validate_inputs(self):
"""
Validate inputs
"""
if self.param_dtype == QuantizationDataType.float and self.param_bw != 16:
raise ValueError(
'float param_dtype can only be used when param_bw is set to 16, not ' + str(self.param_bw))
if self.param_dtype and self.param_bw:
if self.param_dtype == QuantizationDataType.float and self.param_bw not in [16, 32]:
raise ValueError(
'float param_dtype can only be used when param_bw is set to 16 or 32, not ' + str(self.param_bw))

if self.act_dtype == QuantizationDataType.float and self.act_bw != 16:
if self.act_dtype == QuantizationDataType.float and self.act_bw not in [16, 32]:
raise ValueError(
'float act_dtype can only be used when act_bw is set to 16 or 32, not ' + str(self.act_bw))
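
With param_dtype and param_bw now optional, an activation-only instance becomes valid; a minimal sketch:

act_only_info = QuantDtypeBwInfo(QuantizationDataType.float, 16)
full_info = QuantDtypeBwInfo(QuantizationDataType.int, 8, QuantizationDataType.int, 8)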

@@ -127,14 +127,15 @@ def _validate_supported_kernels(supported_kernels: List):
if supported_kernels:
for supported_kernel in supported_kernels:
if supported_kernel["activation"]["dtype"] == QuantizationDataType.float and \
supported_kernel["activation"]["bitwidth"] != 16:
supported_kernel["activation"]["bitwidth"] not in [16, 32]:
logger.error('Activation dtype:float is only supported with bitwidth:16 or bitwidth:32')
raise NotImplementedError('Activation dtype:float is only supported with bitwidth:16 or bitwidth:32')

if supported_kernel["param"]["dtype"] == QuantizationDataType.float and \
supported_kernel["param"]["bitwidth"] != 16:
logger.error('Param dtype:float is only supported with bitwidth:16')
raise NotImplementedError('Param dtype:float is only supported with bitwidth:16')
if "param" in supported_kernel:
if supported_kernel["param"]["dtype"] == QuantizationDataType.float and \
supported_kernel["param"]["bitwidth"] not in [16, 32]:
logger.error('Param dtype:float is only supported with bitwidth:16 or bitwidth:32')
raise NotImplementedError('Param dtype:float is only supported with bitwidth:16 or bitwidth:32')
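
With the "param" entry now optional, a supported kernel carrying only an activation constraint (as emitted for ops without weights) passes validation; a sketch, assuming dtypes were already converted to QuantizationDataType:

_validate_supported_kernels([{"activation": {"dtype": QuantizationDataType.float, "bitwidth": 16}}])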

def _validate_semantics(quantsim_config: ConfigDictType):
"""
@@ -216,10 +217,11 @@ def _convert_str_to_quantization_data_type_helper(supported_kernels: List):
else:
supported_kernel["activation"]["dtype"] = QuantizationDataType.int

if supported_kernel["param"]["dtype"] == "float":
supported_kernel["param"]["dtype"] = QuantizationDataType.float
else:
supported_kernel["param"]["dtype"] = QuantizationDataType.int
if "param" in supported_kernel:
if supported_kernel["param"]["dtype"] == "float":
supported_kernel["param"]["dtype"] = QuantizationDataType.float
else:
supported_kernel["param"]["dtype"] = QuantizationDataType.int


def _convert_dtype_to_quantization_data_type(quantsim_config: ConfigDictType):