Make load_encodings_to_sim compatible with both v1 and v2
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Feb 7, 2024
1 parent d763202 commit 7d50ba6
Showing 8 changed files with 356 additions and 122 deletions.
@@ -34,6 +34,7 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================
#pylint: disable=too-many-lines
"""Fake-quantized modules"""

import contextlib
@@ -82,28 +83,88 @@ def export_input_encodings(self) -> List[List[Dict]]:
Returns a list of input encodings, each represented as a List of Dicts
"""
return [
quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
quantizer.get_legacy_encodings() if isinstance(quantizer, QuantizerBase) else None
for quantizer in _flatten_nn_module_list(self.input_quantizers)
]

def import_input_encodings(self, encodings: Dict[str, Dict]):
"""
Import input encodings represented in the following format:
{
'0': dict,
'1': dict,
...
}
"""
for i, quantizer in enumerate(list(self.input_quantizers)):
encoding = encodings.get(str(i), None)
if not encoding:
self.input_quantizers[i] = None
continue
if quantizer is None:
raise RuntimeError
if isinstance(encoding, dict):
encoding = [encoding]
quantizer.set_legacy_encodings(encoding)

def export_output_encodings(self) -> List[List[Dict]]:
"""
Returns a list of output encodings, each represented as a List of Dicts
"""
return [
quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
quantizer.get_legacy_encodings() if isinstance(quantizer, QuantizerBase) else None
for quantizer in _flatten_nn_module_list(self.output_quantizers)
]

def import_output_encodings(self, encodings: Dict[str, Dict]):
"""
Import output encodings represented in the following format:
{
'0': dict,
'1': dict,
...
}
"""
for i, quantizer in enumerate(list(self.output_quantizers)):
encoding = encodings.get(str(i), None)
if not encoding:
self.output_quantizers[i] = None
continue
if quantizer is None:
raise RuntimeError
if isinstance(encoding, dict):
encoding = [encoding]
quantizer.set_legacy_encodings(encoding)

def export_param_encodings(self) -> Dict[str, List[Dict]]:
"""
Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts
"""
return {
param_name: quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
param_name: quantizer.get_legacy_encodings() if isinstance(quantizer, QuantizerBase) else None
for param_name, quantizer in self.param_quantizers.items()
}

def import_param_encodings(self, encodings: Dict[str, List[Dict]]):
"""
Import parameter encodings represented in the following format:
{
'param_name_0': [dict, dict, ...],
'param_name_1': [dict, dict, ...],
...
}
"""
for param_name, quantizer in dict(self.param_quantizers).items():
encoding = encodings.get(param_name, None)
if not encoding:
self.param_quantizers[param_name] = None
continue
if quantizer is None:
raise RuntimeError
if isinstance(encoding, dict):
encoding = [encoding]
quantizer.set_legacy_encodings(encoding)
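
For reference, a minimal usage sketch of the import hooks above. It is not part of the diff: qmodule stands for any fake-quantized module that mixes in these methods, and every encoding value is invented.

# qmodule is assumed to be a fake-quantized module exposing the hooks above.
# Activation encodings are keyed by quantizer index ('0', '1', ...), parameter
# encodings by parameter name; a missing key replaces that quantizer with None.
qmodule.import_input_encodings({
    '0': {'min': 0.0, 'max': 6.0, 'scale': 6 / 255, 'offset': 0,
          'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'},
})
qmodule.import_param_encodings({
    'weight': [{'min': -0.5, 'max': 0.5, 'scale': 1 / 255, 'offset': -128,
                'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'True'}],
})
# export_input_encodings() / export_param_encodings() return the same
# legacy-format dicts, or None entries where a quantizer has been removed.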

def get_original_module(self) -> nn.Module:
"""
Returns the floating point version of quantized module
@@ -122,7 +122,7 @@ def set_range(self, min: torch.Tensor, max: torch.Tensor):
"""

@torch.no_grad()
def get_encodings(self) -> Optional[List[Dict]]:
def get_legacy_encodings(self) -> Optional[List[Dict]]:
"""
Returns a list of encodings, each represented as a List of Dicts
"""
@@ -146,6 +146,33 @@ def get_encodings(self) -> Optional[List[Dict]]:
for min_, max_, scale_, offset_ in zip(min, max, scale, offset)
]

@torch.no_grad()
def set_legacy_encodings(self, encodings: List[Dict]):
"""
Set encodings represented in the same format as the output of get_legacy_encodings as below:
[
{'min': float, 'max': float, 'scale': float, 'offset': float,
'bitwidth': int, 'dtype': str, 'is_symmetric': str},
{'min': float, 'max': float, 'scale': float, 'offset': float,
'bitwidth': int, 'dtype': str, 'is_symmetric': str},
...
]
"""
def str_to_bool(s: str):
s = s.lower()
if s == "false":
return False
if s == "true":
return True
raise ValueError

self.bitwidth = encodings[0]['bitwidth']
self.symmetric = str_to_bool(encodings[0]['is_symmetric'])
min_ = torch.tensor([e['min'] for e in encodings])
max_ = torch.tensor([e['max'] for e in encodings])
self.set_range(min_, max_)

def extra_repr(self) -> str:
return f'shape={self.shape}, bitwidth={self.bitwidth}, symmetric={self.symmetric}'
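
As a sanity check on the legacy format, a standalone snippet (values invented) showing how a per-channel encoding list reduces to the min/max tensors that set_legacy_encodings passes to set_range:

import torch

# Two per-channel encodings in the get_legacy_encodings() format above.
legacy_encodings = [
    {'min': -1.0, 'max': 1.0, 'scale': 2 / 255, 'offset': -128,
     'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'},
    {'min': -0.5, 'max': 0.5, 'scale': 1 / 255, 'offset': -128,
     'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'},
]
# set_legacy_encodings() keeps bitwidth/symmetry from the first entry only and
# stacks min/max across all entries before delegating to set_range().
min_ = torch.tensor([e['min'] for e in legacy_encodings])  # tensor([-1.0000, -0.5000])
max_ = torch.tensor([e['max'] for e in legacy_encodings])  # tensor([1.0000, 0.5000])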

@@ -72,11 +72,17 @@ def compute_encodings(self):
"""

@abc.abstractmethod
def get_encodings(self) -> Optional[List[Dict]]:
def get_legacy_encodings(self) -> Optional[List[Dict]]:
"""
Returns a list of encodings, each represented as a List of Dicts
"""

@abc.abstractmethod
def set_legacy_encodings(self, encodings: List[Dict]):
"""
Set encodings represented in the same format as the output of get_legacy_encodings.
"""

def register_quantization_parameter(self, name: str, param: nn.Parameter):
"""
Register quantization parameter.
@@ -141,9 +141,23 @@ def is_bfloat16(self):
return self.exponent_bits == _BFLOAT16_EXPONENT_BITS and \
self.mantissa_bits == _BFLOAT16_MANTISSA_BITS

def get_encodings(self) -> Optional[List[Dict]]:
def get_legacy_encodings(self) -> Optional[List[Dict]]:
return [{'bitwidth': self.bitwidth, 'dtype': 'float'}]

def set_legacy_encodings(self, encodings: List[Dict]):
"""
Set encodings represented in the same format as the output of get_legacy_encodings as below:
[
{'bitwidth': int, 'dtype': str},
...
]
"""
if encodings[0]['bitwidth'] != 16:
raise RuntimeError(f"{self.__class__} can only import 16-bit legay encodings.")
self.exponent_bits = 5
self.mantissa_bits = 10
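
A brief, illustrative counterpart for the float quantizer (quantizer is an assumed instance of the class above): only a 16-bit float legacy encoding is accepted, and importing it restores IEEE float16 parameters.

# The only legacy encoding this quantizer re-imports is 16-bit float.
quantizer.set_legacy_encodings([{'bitwidth': 16, 'dtype': 'float'}])
assert quantizer.exponent_bits == 5 and quantizer.mantissa_bits == 10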

@contextlib.contextmanager
def compute_encodings(self):
"""
167 changes: 103 additions & 64 deletions TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -194,7 +194,7 @@ def _quantize_activation(self, tensor_quantizers, tensors_to_quantize):
return outputs


class QcQuantizeWrapper(nn.Module):
class QcQuantizeWrapper(nn.Module): # pylint: disable=too-many-public-methods
"""
Base class for the quantization custom ops
"""
@@ -366,78 +366,34 @@ def set_activation_encoding(self, module_name: str, activation_encodings: Dict):
:param module_name: name of module
:param activation_encodings: activation encodings dictionary
"""
_logger.info("Setting quantization encodings for activation quantizers of: %s", module_name)

def _set_quantizer_encodings(type_of_quantizer: str, quantizers: List[TensorQuantizer]):
"""
Sets bitwidth, symmetric mode and encodings for quantizer of type input or output
:param type_of_quantizer: input or output
:param quantizers: input or output quantizers
"""
if module_name in activation_encodings and type_of_quantizer in activation_encodings[module_name]:
encodings = activation_encodings[module_name][type_of_quantizer]
# The number of quantizers and encodings might not be the same. For example, suppose the 1st output
# quantizer is disabled out of 4. The number of encodings will be 3, but the number of output quantizers
# will still be 4.
# This can occur if a certain quantizer corresponded to a tensor with an unquantizable datatype.
for index, quantizer in enumerate(quantizers):
ind = str(index)
if ind not in encodings:
quantizer.enabled = False
_logger.debug("No encoding loaded for %s quantizer %s of layer %s", type_of_quantizer, ind,
module_name)
continue
if not quantizer.enabled:
raise RuntimeError("The quantsim passed for loading encodings does not have the same "
"configuration as the quantsim which was used to export the encodings")

if quantizer._is_encoding_frozen: # pylint: disable=protected-access
_logger.debug("Encodings are frozen for module %s and quantizer type %s", module_name,
type_of_quantizer)
continue

if encodings[ind]['dtype'] == 'int':
encoding, is_symmetric = utils.create_encoding_from_dict(encodings[ind])
quantizer.bitwidth = encoding.bw
quantizer.use_symmetric_encodings = is_symmetric
quantizer.encoding = encoding
elif encodings[ind]['dtype'] == 'float':
quantizer.bitwidth = encodings[ind]['bitwidth']
quantizer.data_type = QuantizationDataType.float
else:
raise RuntimeError("Unrecognized encodings datatype")
try:
input_encoding = activation_encodings[module_name]['input']
except KeyError:
input_encoding = {}

_logger.info("Setting quantization encodings for activation quantizers of: %s", module_name)
self.import_input_encodings(input_encoding)

try:
output_encoding = activation_encodings[module_name]['output']
except KeyError:
output_encoding = {}

_set_quantizer_encodings(QUANTIZER_TYPE_INPUT, self.input_quantizers)
_set_quantizer_encodings(QUANTIZER_TYPE_OUTPUT, self.output_quantizers)
self.import_output_encodings(output_encoding)
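
For orientation, the shape of the activation_encodings dictionary this method consumes; the module name and all numbers are invented, and wrapper stands for the QcQuantizeWrapper of that module.

activation_encodings = {
    'conv1': {
        'input':  {'0': {'min': 0.0, 'max': 6.0, 'scale': 6 / 255, 'offset': 0,
                         'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'}},
        'output': {'0': {'min': 0.0, 'max': 4.0, 'scale': 4 / 255, 'offset': 0,
                         'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'}},
    },
}
wrapper.set_activation_encoding('conv1', activation_encodings)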

def set_param_encoding(self, module_name: str, param_encodings: Dict):
"""
Set encoding for parameter from encodings dictionary
:param module_name: name of module
:param param_encodings: parameter encodings dictionary
"""
for orig_param_name, param_quantizer in self.param_quantizers.items():
param_name = module_name + '.' + orig_param_name
# pylint: disable=protected-access
if param_name in param_encodings and param_quantizer.enabled and not param_quantizer._is_encoding_frozen:
encodings = []
if param_encodings[param_name][0]['dtype'] == 'int':
is_symmetric = False
for encoding_dict in param_encodings[param_name]:
if encoding_dict['dtype'] == 'int':
encoding, is_symmetric = utils.create_encoding_from_dict(encoding_dict)
encodings.append(encoding)
param_quantizer.bitwidth = encodings[0].bw
param_quantizer.use_symmetric_encodings = is_symmetric
param_quantizer.encoding = encodings
elif param_encodings[param_name][0]['dtype'] == 'float':
param_quantizer.bitwidth = param_encodings[param_name][0]['bitwidth']
param_quantizer.data_type = QuantizationDataType.float
else:
raise RuntimeError("Data type does not match int or float in encodings file")

_logger.info("Setting quantization encodings for parameter: %s", param_name)
param_encoding = {
param_name: param_encodings[f'{module_name}.{param_name}']
for param_name, _ in self.param_quantizers.items()
if f'{module_name}.{param_name}' in param_encodings
}
self.import_param_encodings(param_encoding)
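
Similarly, an invented sketch of the parameter path. Keys follow the "<module name>.<param name>" convention of exported encodings files; set_param_encoding() strips the module prefix and forwards the remainder to import_param_encodings(), whose int/float branches appear further below.

param_encodings = {
    'conv1.weight': [{'min': -0.5, 'max': 0.5, 'scale': 1 / 255, 'offset': -128,
                      'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'True'}],
    'conv1.bias': [{'bitwidth': 16, 'dtype': 'float'}],   # exercises the 'float' branch
}
wrapper.set_param_encoding('conv1', param_encodings)      # wrapper: QcQuantizeWrapper (assumed)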

def freeze_param_encoding(self, module_name: str, param_encodings: Dict):
"""
@@ -496,7 +452,7 @@ def get_original_module(self) -> torch.nn.Module:
"""
return self._module_to_wrap

def export_param_encodings(self) -> Dict[str, List]:
def export_param_encodings(self) -> Dict[str, List[Dict]]:
"""
Returns the layer's parameter encodings in an exportable format
"""
@@ -514,6 +470,89 @@ def export_input_encodings(self) -> List[List[Dict]]:
"""
return [export_quantizer_encoding(quantizer) for quantizer in self.input_quantizers]

def import_param_encodings(self, encodings: Dict[str, List[Dict]]):
"""
Import parameter encodings represented in the following format:
{
'param_name_0': [dict, dict, ...],
'param_name_1': [dict, dict, ...],
...
}
"""
for param_name, quantizer in self.param_quantizers.items():
encoding = encodings.get(param_name, None)
if not encoding:
quantizer.enabled = False
continue

if encoding[0]['dtype'] == 'int':
_, is_symmetric = utils.create_encoding_from_dict(encoding[0])
quantizer.use_symmetric_encodings = is_symmetric
quantizer.bitwidth = encoding[0]['bitwidth']
quantizer.encoding = [utils.create_encoding_from_dict(enc_dict)[0] for enc_dict in encoding]
quantizer.data_type = QuantizationDataType.int
elif encoding[0]['dtype'] == 'float':
quantizer.bitwidth = encoding[0]['bitwidth']
quantizer.data_type = QuantizationDataType.float
else:
raise RuntimeError("Data type does not match int or float in encodings file")

_logger.info("Setting quantization encodings for parameter: %s", param_name)

self.set_mode(QcQuantizeOpMode.ACTIVE)

def import_output_encodings(self, encodings: Dict[str, Dict]):
"""
Import output encodings represented in the following format:
{
'0': dict,
'1': dict,
...
}
"""
self._import_encoding(encodings, self.output_quantizers)

def import_input_encodings(self, encodings: Dict[str, Dict]):
"""
Import input encodings represented in the following format:
{
'0': dict,
'1': dict,
...
}
"""
self._import_encoding(encodings, self.input_quantizers)

def _import_encoding(self, encodings, quantizers):
assert quantizers is self.input_quantizers or quantizers is self.output_quantizers

for i, quantizer in enumerate(quantizers):
encoding = encodings.get(str(i), None)
if not encoding:
quantizer.enabled = False
continue
if not quantizer.enabled:
raise RuntimeError("The quantsim passed for loading encodings does not have the same "
"configuration as the quantsim which was used to export the encodings")
if quantizer._is_encoding_frozen: # pylint: disable=protected-access
type_of_quantizer = 'input' if quantizers is self.input_quantizers else 'output'
_logger.debug("Encodings are frozen for module %s quantizer of %s",
type_of_quantizer, self._module_to_wrap.__class__)
continue

if encoding['dtype'] == 'int':
encoding, is_symmetric = utils.create_encoding_from_dict(encoding)
quantizer.bitwidth = encoding.bw
quantizer.use_symmetric_encodings = is_symmetric
quantizer.encoding = encoding
elif encoding['dtype'] == 'float':
quantizer.bitwidth = encoding['bitwidth']
quantizer.data_type = QuantizationDataType.float
else:
raise RuntimeError("Unrecognized encodings datatype")

self.set_mode(QcQuantizeOpMode.ACTIVE)
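
Finally, a behavioral sketch of _import_encoding with invented values: entries are keyed by quantizer index, and an index missing from the dict disables that quantizer on this wrapper, whereas the quantized-module mixin earlier in the commit replaces it with None.

encodings = {
    '0': {'min': 0.0, 'max': 1.0, 'scale': 1 / 255, 'offset': 0,
          'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'},
    # '1' intentionally absent -> output_quantizers[1].enabled = False
}
wrapper.import_output_encodings(encodings)  # wrapper: QcQuantizeWrapper (assumed)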


class StaticGridQuantWrapper(QcQuantizeWrapper):
""" A custom PyTorch module that derives from QcQuantizeWrapper and quantizes modules """