diff --git a/TrainingExtensions/common/src/python/aimet_common/defs.py b/TrainingExtensions/common/src/python/aimet_common/defs.py index 3c1b56a9c68..ef4be316d49 100644 --- a/TrainingExtensions/common/src/python/aimet_common/defs.py +++ b/TrainingExtensions/common/src/python/aimet_common/defs.py @@ -346,6 +346,7 @@ class AdaroundConstants: class QuantizationDataType(Enum): """ Enumeration of tensor quantizer data types supported """ + undefined = 0 int = 1 float = 2 @@ -361,7 +362,9 @@ class QuantDtypeBwInfo: QuantDtypeBwInfo holds activation dtype/bw and param dtype/bw """ - def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: QuantizationDataType = None, param_bw: int = None): + + def __init__(self, act_dtype: QuantizationDataType, act_bw: int, + param_dtype: QuantizationDataType = QuantizationDataType.undefined, param_bw: int = 0): """ Data class to hold dtype and bw info :param act_dtype: Activation datatype of type QuantizationDataType @@ -375,9 +378,11 @@ def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: Qu self.param_bw = param_bw self._validate_inputs() + def __repr__(self): + return f'(activation:({self.act_dtype}, {self.act_bw}) param:({self.param_dtype}, {self.param_bw}))' + def __str__(self): - return (f'(activation_data_type = {self.act_dtype}, act_bw = {self.act_bw} ' - f'param_data_type = {self.param_dtype} param_bw = {self.param_bw})') + return f'activation:({self.act_dtype}, {self.act_bw}) param:({self.param_dtype}, {self.param_bw})' def __eq__(self, other): return self.act_dtype == other.act_dtype and self.act_bw == other.act_bw and \ @@ -396,7 +401,7 @@ def _validate_inputs(self): raise ValueError( 'float act_dtype can only be used when act_bw is set to 16, not ' + str(self.act_bw)) - def is_same_activation(self, bw: int, dtype: QuantizationDataType): + def is_same_activation(self, dtype: QuantizationDataType, bw: int): """ helper function to check if activation of the object is same as input 
:param bw: bitwidth to verify against @@ -404,10 +409,18 @@ def is_same_activation(self, bw: int, dtype: QuantizationDataType): """ return bw == self.act_bw and dtype == self.act_dtype - def is_same_param(self, bw: int, dtype: QuantizationDataType): + def is_same_param(self, dtype: QuantizationDataType, bw: int): """ helper function to check if param of the object is same as input :param bw: bitwidth to verify against :param dtype: dtype to verify against """ return bw == self.param_bw and dtype == self.param_dtype + + def get_activation(self) -> tuple: + """ getter method for activation candidate""" + return self.act_dtype, self.act_bw + + def get_param(self) -> tuple: + """ getter method for param candidate""" + return self.param_dtype, self.param_bw diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py index 069878ec67b..c637adfb6d6 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py @@ -1681,7 +1681,7 @@ def apply_act_rules(act: Tuple[int, QuantizationDataType], allowed_supported_ker """ if action != SupportedKernelsAction.allow_error: for k in allowed_supported_kernels: - if k.is_same_activation(act[0], act[1]): + if k.is_same_activation(act[1], act[0]): return if action == SupportedKernelsAction.warn_on_error: