Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getter methods in QuantDtypeBwInfo #2595

Merged
merged 1 commit into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions TrainingExtensions/common/src/python/aimet_common/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class AdaroundConstants:

class QuantizationDataType(Enum):
""" Enumeration of tensor quantizer data types supported """
undefined = 0
int = 1
float = 2

Expand All @@ -361,7 +362,9 @@ class QuantDtypeBwInfo:
QuantDtypeBwInfo holds activation dtype/bw and param dtype/bw
"""

def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: QuantizationDataType = None, param_bw: int = None):

def __init__(self, act_dtype: QuantizationDataType, act_bw: int,
param_dtype: QuantizationDataType = QuantizationDataType.undefined, param_bw: int = 0):
"""
Data class to hold dtype and bw info
:param act_dtype: Activation datatype of type QuantizationDataType
Expand All @@ -375,9 +378,11 @@ def __init__(self, act_dtype: QuantizationDataType, act_bw: int, param_dtype: Qu
self.param_bw = param_bw
self._validate_inputs()

def __repr__(self):
return f'(activation:({self.act_dtype}, {self.act_bw}) param:({self.param_dtype}, {self.param_bw})'

def __str__(self):
return (f'(activation_data_type = {self.act_dtype}, act_bw = {self.act_bw} '
f'param_data_type = {self.param_dtype} param_bw = {self.param_bw})')
return f'activation:({self.act_dtype}, {self.act_bw}) param:({self.param_dtype}, {self.param_bw})'

def __eq__(self, other):
return self.act_dtype == other.act_dtype and self.act_bw == other.act_bw and \
Expand All @@ -396,18 +401,26 @@ def _validate_inputs(self):
raise ValueError(
'float act_dtype can only be used when act_bw is set to 16, not ' + str(self.act_bw))

def is_same_activation(self, bw: int, dtype: QuantizationDataType):
def is_same_activation(self, dtype: QuantizationDataType, bw: int):
"""
helper function to check if activation of the object is same as input
:param bw: bitwidth to verify against
:param dtype: dtype to verify against
"""
return bw == self.act_bw and dtype == self.act_dtype

def is_same_param(self, bw: int, dtype: QuantizationDataType):
def is_same_param(self, dtype: QuantizationDataType, bw: int):
"""
helper function to check if param of the object is same as input
:param bw: bitwidth to verify against
:param dtype: dtype to verify against
"""
return bw == self.param_bw and dtype == self.param_dtype

def get_activation(self) -> tuple:
""" getter method for activation candidate"""
return self.act_dtype, self.act_bw

def get_param(self) -> tuple:
""" getter method for param candidate"""
return self.param_dtype, self.param_bw
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ def apply_act_rules(act: Tuple[int, QuantizationDataType], allowed_supported_ker
"""
if action != SupportedKernelsAction.allow_error:
for k in allowed_supported_kernels:
if k.is_same_activation(act[0], act[1]):
if k.is_same_activation(act[1], act[0]):
return

if action == SupportedKernelsAction.warn_on_error:
Expand Down
Loading