Skip to content

Commit

Permalink
Add getter methods in QuantDtypeBwInfo (#2595)
Browse files Browse the repository at this point in the history
Signed-off-by: yathindra kota <quic_ykota@quicinc.com>
  • Loading branch information
quic-ykota authored Dec 7, 2023
1 parent adfa587 commit 80436ed
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
23 changes: 18 additions & 5 deletions TrainingExtensions/common/src/python/aimet_common/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class AdaroundConstants:

class QuantizationDataType(Enum):
""" Enumeration of tensor quantizer data types supported """
undefined = 0
int = 1
float = 2

Expand All @@ -361,7 +362,9 @@ class QuantDtypeBwInfo:
QuantDtypeBwInfo holds activation dtype/bw and param dtype/bw
"""

def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: QuantizationDataType = None, param_bw: int = None):

def __init__(self, act_dtype: QuantizationDataType, act_bw: int,
param_dtype: QuantizationDataType = QuantizationDataType.undefined, param_bw: int = 0):
"""
Data class to hold dtype and bw info
:param act_dtype: Activation datatype of type QuantizationDataType
Expand All @@ -375,9 +378,11 @@ def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: Qu
self.param_bw = param_bw
self._validate_inputs()

def __repr__(self):
return f'(activation:({self.act_dtype}, {self.act_bw}) param:({self.param_dtype}, {self.param_bw})'

def __str__(self):
return (f'(activation_data_type = {self.act_dtype}, act_bw = {self.act_bw} '
f'param_data_type = {self.param_dtype} param_bw = {self.param_bw})')
return f'activation:({self.act_dtype}, {self.act_bw}) param:({self.param_dtype}, {self.param_bw})'

def __eq__(self, other):
return self.act_dtype == other.act_dtype and self.act_bw == other.act_bw and \
Expand All @@ -396,18 +401,26 @@ def _validate_inputs(self):
raise ValueError(
'float act_dtype can only be used when act_bw is set to 16, not ' + str(self.act_bw))

def is_same_activation(self, bw: int, dtype: QuantizationDataType):
def is_same_activation(self, dtype: QuantizationDataType, bw: int):
"""
helper function to check if activation of the object is same as input
:param bw: bitwidth to verify against
:param dtype: dtype to verify against
"""
return bw == self.act_bw and dtype == self.act_dtype

def is_same_param(self, bw: int, dtype: QuantizationDataType):
def is_same_param(self, dtype: QuantizationDataType, bw: int):
"""
helper function to check if param of the object is same as input
:param bw: bitwidth to verify against
:param dtype: dtype to verify against
"""
return bw == self.param_bw and dtype == self.param_dtype

def get_activation(self) -> tuple:
""" getter method for activation candidate"""
return self.act_dtype, self.act_bw

def get_param(self) -> tuple:
""" getter method for param candidate"""
return self.param_dtype, self.param_bw
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ def apply_act_rules(act: Tuple[int, QuantizationDataType], allowed_supported_ker
"""
if action != SupportedKernelsAction.allow_error:
for k in allowed_supported_kernels:
if k.is_same_activation(act[0], act[1]):
if k.is_same_activation(act[1], act[0]):
return

if action == SupportedKernelsAction.warn_on_error:
Expand Down

0 comments on commit 80436ed

Please sign in to comment.