Skip to content

Commit

Permalink
Use formal type comparison only (#3453)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored Nov 1, 2024
1 parent 9c0accf commit fea766d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -454,20 +454,16 @@ def find_supported_candidates(quantizer_groups: List[QuantizerGroup],
supported_kernel_types = set()
for supported_kernel_op in quantizer_group.supported_kernel_ops:
module = module_name_to_module_dict[supported_kernel_op]._module_to_wrap
try:
backend_type = aimet_op_to_backend_op_name_map[module.__class__]
except KeyError:
backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
backend_type = aimet_op_to_backend_op_name_map.get(type(module))

if backend_type in supported_kernels:
supported_kernel_types.add(backend_type)
else:
onnx_types = onnx_utils.map_torch_types_to_onnx.get(
type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap), [])
onnx_types = onnx_utils.map_torch_types_to_onnx.get(type(module), [])

if not onnx_types:
logger.warning("No mapping found for %s in the torch to onnx op type mapping dictionary.",
str(type(module_name_to_module_dict[supported_kernel_op]._module_to_wrap)))
type(module))

supported_kernel_types.update(onnx_types)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,7 @@ def _merge_op_types_info(self, op_configs: OpTypeType):
for module, _ in self._named_modules_to_tensor_quantizers_dict.items():

onnx_types = map_torch_types_to_onnx.get(type(module))
try:
backend_type = aimet_op_to_backend_op_name_map[module.__class__]
except KeyError:
backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
backend_type = aimet_op_to_backend_op_name_map.get(type(module))

if onnx_types and backend_type in op_configs and backend_type not in merged_backend_types:
backend_type_op_config = op_configs[backend_type]
Expand All @@ -413,11 +410,7 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
# Set op type configs for named modules
for module, input_output_tensor_quantizers in self._named_modules_to_tensor_quantizers_dict.items():
onnx_types = map_torch_types_to_onnx.get(type(module))

try:
backend_type = aimet_op_to_backend_op_name_map[module.__class__]
except KeyError:
backend_type = aimet_op_to_backend_op_name_map.get(module.__class__.__name__)
backend_type = aimet_op_to_backend_op_name_map.get(type(module))

if backend_type in op_configs:
self._set_config_for_module(input_output_tensor_quantizers, op_configs[backend_type],
Expand Down Expand Up @@ -942,11 +935,9 @@ def generate(self, module: torch.nn.Module, op_type: str) -> Tuple[dict, bool]:
:return: supported_kernels and per_channel_quantization fields
"""
supported_kernels = []
if (module.__class__ in aimet_op_to_backend_op_name_map) or (module.__class__.__name__ in aimet_op_to_backend_op_name_map):
try:
backend_type = aimet_op_to_backend_op_name_map[module.__class__]
except KeyError:
backend_type = aimet_op_to_backend_op_name_map[module.__class__.__name__]
backend_type = aimet_op_to_backend_op_name_map.get(type(module))

if backend_type is not None:
supported_kernels = self.op_type_supported_kernels.get(backend_type)

if not supported_kernels:
Expand Down
121 changes: 2 additions & 119 deletions TrainingExtensions/torch/src/python/aimet_torch/translation_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,124 +75,7 @@
QnnDatatype.QNN_DATATYPE_BOOL_8: {'bitwidth': 8,
'dtype': QuantizationDataType.int}
}
aimet_op_to_backend_op_name_map = {"Conv1d":"Conv1d",
"Conv2d":"Conv2d",
"Conv3d":"Conv3d",
"ConvTranspose1d":"TransposeConv1d",
"ConvTranspose2d":"TransposeConv2d",
"ConvTranspose3d":"TransposeConv3d",
"ReLU":"Relu",
"Tanh":"Tanh",
"Sigmoid":"Sigmoid",
"ELU":"Elu",
"ReLU6":"Relu6",
"Hardtanh":"ReluMinMax",
"Hardswish":"HardSwish",
"Add":"ElementWiseAdd",
"Subtract":"ElementWiseSubtract",
"Multiply":"ElementWiseMultiply",
"Divide":"ElementWiseDivide",
"Mul":"ElementWiseMultiply",
"Div":"ElementWiseDivide",
"Minimum":"ElementWiseMinimum",
"Maximum":"ElementWiseMaximum",
"Pow":"ElementWisePower",
"Remainder":"ElementWiseMod",
"Fmod":"ElementWiseFmod",
"Exponential":"ElementWiseExp",
"Log":"ElementWiseLog",
"Sqrt":"ElementWiseRsqrt",
"Abs":"ElementWiseAbs",
"Neg":"ElementWiseNeg",
"Erf":"Gelu",
"Round":"ElementWiseRound",
"Where":"ElementWiseSelect",
"Equal":"ElementWiseEqual",
"Greater":"ElementWiseGreater",
"Less":"ElementWiseLess",
"GreaterEqual":"ElementWiseGreaterEqual",
"LessEqual":"ElementWiseLessEqual",
"LogicalOr":"ElementWiseOr",
"LogicalAnd":"ElementWiseAnd",
"LogicalNot":"ElementWiseNot",
"Mean":"ReduceMean",
"Sum":"ReduceSum",
"Prod":"ReduceProd",
"ElementwiseCeil":"ElementWiseCeil",
"ElementwiseFloor":"ElementWiseFloor",
"Split":"Split",
"Concat":"Concat",
"MaxPool2d":"PoolMax2d",
"MaxPool3d":"PoolMax3d",
"AvgPool2d":"PoolAvg2d",
"AvgPool3d":"PoolAvg3d",
"LPPool2d":"L2Pool2d",
"Reshape":"Reshape",
"Permute":"Transpose",
"Upsample":"Resize",
"Linear":"FullyConnected",
"Softmax":"Softmax",
"LogSoftmax":"LogSoftmax",
"LayerNorm":"LayerNorm",
"Softplus":"ElementWiseSoftplus",
"PReLU":"Prelu",
"CustomGather":"Gather",
"InstanceNorm1d":"InstanceNorm",
"InstanceNorm2d":"InstanceNorm",
"InstanceNorm3d":"InstanceNorm",
"MatMul":"MatMul",
"CumSum":"CumulativeSum",
"Argmin":"Argmin",
"Argmax":"Argmax",
"Sin":"ElementWiseSin",
"Cos":"ElementWiseCos",
"Asin":"ElementWiseAsin",
"Atan":"ElementWiseAtan",
"Normalize":"L2Norm",
"Gather":"Gather",
"ChannelShuffle":"ChannelShuffle",
"Pad":"Pad",
"ElementwiseUnarySign":"ElementWiseUnary",
"RoIPool":"RoiPooling",
"PixelShuffle":"DepthToSpace",
"DepthToSpaceDCRMode":"DepthToSpace",
"PixelUnshuffle":"SpaceToDepth",
"Min":"ReduceMin",
"Max":"ReduceMax",
"NonZero":"NonZero",
"TopK":"TopK",
"Shape":"Shape",
"Tile":"Tile",
"LocalResponseNorm":"Lrn",
"LSTM":"Lstm",
"ScatterND":"ScatterNd",
"RoiAlign":"RoiAlign",
"NonMaxSuppression":"NonMaxSuppression",
"GatherNd":"GatherNd",
"BatchNorm1d":"Batchnorm",
"BatchNorm2d":"Batchnorm",
"BatchNorm3d":"Batchnorm",
"OneHot":"OneHot",
"ScatterElements":"ScatterElements",
"LeakyReLU":"Prelu",
"GRU":"Gru",
"CustomLayerNorm":"LayerNorm",
"IndexSelect":"Gather",
"Embedding":"Gather",
"Expand":"ElementWiseMultiply",
"Stack":"Pack",
"UnBind":"UnPack",
"SpaceToBatch":"SpaceToBatch",
"BatchToSpace":"BatchToSpace",
"Moments":"Moments",
"CropAndResize":"CropAndResize",
"FloorDivide":"ElementWiseFloorDiv",
"GELU":"Gelu",
"Cast":"Cast",
"StridedSlice":"StridedSlice",
"GroupNorm":"GroupNorm"
}
aimet_op_to_backend_op_name_map.update({
aimet_op_to_backend_op_name_map = {
torch.nn.Conv1d: "Conv1d",
torch.nn.Conv2d: "Conv2d",
torch.nn.Conv3d: "Conv3d",
Expand Down Expand Up @@ -303,7 +186,7 @@
custom.StridedSlice: "StridedSlice",
torch.nn.GroupNorm: "GroupNorm",
custom.GroupNorm: "GroupNorm",
})
}


op_to_weight_index_map = {'Conv1d' : 1,
Expand Down
4 changes: 2 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def wrapper(quantized_cls):
# of the module due to the limitation of v1 implementation.
# Should redefine `aimet_to_to_backend_op_name_map` as `Dict[Type[Module], str]`
from aimet_torch.translation_mapping import aimet_op_to_backend_op_name_map
backend_op_name = aimet_op_to_backend_op_name_map.get(module_cls.__name__, None)
backend_op_name = aimet_op_to_backend_op_name_map.get(module_cls, None)
if backend_op_name:
aimet_op_to_backend_op_name_map[quantized_cls.__name__] = backend_op_name
aimet_op_to_backend_op_name_map[quantized_cls] = backend_op_name

return quantized_cls

Expand Down

0 comments on commit fea766d

Please sign in to comment.