Skip to content

Commit

Permalink
Deprecate encodings property
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
  • Loading branch information
quic-mtuttle committed Oct 9, 2024
1 parent 6ba99b1 commit bb03529
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
21 changes: 16 additions & 5 deletions TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,27 @@ def use_unsigned_symmetric(self, use_unsigned_symmetric: bool):
tensor_quantizer.setUnsignedSymmetric(use_unsigned_symmetric)
self._reset_encodings()

@property
def encodings(self) -> Optional[List[libpymo.TfEncoding]]:
def get_encodings(self) -> Optional[List[libpymo.TfEncoding]]:
    """
    Read the encoding objects from the node's QcQuantizeInfo.

    :return: List of libpymo.TfEncoding objects holding the node's quantization
        encodings, or None when the quantizer is uninitialized or configured
        for floating-point (no integer encodings exist in either case).
    """
    # Encodings are only meaningful for an initialized integer quantizer.
    if self.is_initialized() and self.data_type != QuantizationDataType.float:
        return self.quant_info.encoding
    return None

@property
@deprecated(f"Use {get_encodings.__qualname__} instead")
def encodings(self) -> Optional[List[libpymo.TfEncoding]]:
    """Deprecated. Use :meth:`get_encodings` to read the quantizer encodings.
    Reads the encodings object from the node's QcQuantizeInfo
    :return: The libpymo.TfEncoding objects used to store the node's quantization encodings,
        or None if the quantizer is uninitialized or in floating-point mode
    """
    # Thin compatibility shim: delegates to the non-deprecated accessor.
    return self.get_encodings()

def update_quantizer_and_load_encodings(self, encoding: List[libpymo.TfEncoding], is_symmetric: Optional[bool],
is_strict_symmetric: Optional[bool], is_unsigned_symmetric: Optional[bool],
data_type: QuantizationDataType):
Expand Down Expand Up @@ -308,7 +319,7 @@ def load_encodings(self, encoding: List[libpymo.TfEncoding]):
assert isinstance(encoding, (list, tuple))
assert len(encoding) == len(self._tensor_quantizer)
if self.data_type == QuantizationDataType.float:
raise RuntimeError(f"`load_encodings` is not supported for floating-point quantizers.")
raise RuntimeError(f"{type(self).load_encodings.__qualname__} is not supported for floating-point quantizers.")
for tensor_quantizer in self._tensor_quantizer:
tensor_quantizer.isEncodingValid = True
self.op_mode = OpMode.quantizeDequantize
Expand Down Expand Up @@ -422,7 +433,7 @@ def get_stats_histogram(self) -> List[List]:
if self.quant_scheme != QuantScheme.post_training_tf_enhanced:
raise RuntimeError("get_stats_histogram() can be invoked only when quantization scheme is TF-Enhanced.")

if not self.encodings:
if not self.get_encodings():
raise RuntimeError("get_stats_histogram() can be invoked only when encoding is computed.")

histogram = []
Expand Down Expand Up @@ -469,7 +480,7 @@ def _export_legacy_encodings(self) -> Union[List, None]:

if self.data_type == QuantizationDataType.int:
encodings = []
for encoding in self.encodings:
for encoding in self.get_encodings():
enc_dict = dict(min=encoding.min,
max=encoding.max,
scale=encoding.delta,
Expand Down
15 changes: 9 additions & 6 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quant_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,23 +474,26 @@ def export_per_layer_encoding_min_max_range(self, sim: QuantizationSimModel, res
# Get input activations' encodings if starting op
for index, quantizer in enumerate(input_quantizers):
name = f"{op_name}_input_{index}"
min_max_range_for_activations_dict[name] = (quantizer.encodings[0].min, quantizer.encodings[0].max)
encodings = quantizer.get_encodings()
min_max_range_for_activations_dict[name] = (encodings[0].min, encodings[0].max)

# Get output activations' encodings
for index, quantizer in enumerate(output_quantizers):
name = f"{op_name}_output_{index}"
min_max_range_for_activations_dict[name] = (quantizer.encodings[0].min, quantizer.encodings[0].max)
encodings = quantizer.get_encodings()
min_max_range_for_activations_dict[name] = (encodings[0].min, encodings[0].max)

# Get parameters' encodings
for param_name, quantizer in param_quantizers.items():
name = re.sub(r'\W+', '_', f"{op_name}_{param_name}")
if len(quantizer.encodings) > 1: # per-channel
encodings = quantizer.get_encodings()
if len(encodings) > 1: # per-channel
per_channel_encodings = {}
for index, encoding in enumerate(quantizer.encodings):
for index, encoding in enumerate(encodings):
per_channel_encodings[f"{name}_{index}"] = (encoding.min, encoding.max)
min_max_range_for_weights_dict[name] = per_channel_encodings
else: # per-tensor
min_max_range_for_weights_dict[name] = (quantizer.encodings[0].min, quantizer.encodings[0].max)
min_max_range_for_weights_dict[name] = (encodings[0].min, encodings[0].max)

create_and_export_min_max_ranges_plot(min_max_range_for_weights_dict, min_max_ranges_dir, title="weights")
create_and_export_min_max_ranges_plot(min_max_range_for_activations_dict, min_max_ranges_dir, title="activations")
Expand Down Expand Up @@ -564,7 +567,7 @@ def _create_and_export_stats_histogram_plot(quantizer: QcQuantizeOp, results_dir
os.makedirs(results_dir, exist_ok=True)

histograms = quantizer.get_stats_histogram()
encodings = quantizer.encodings
encodings = quantizer.get_encodings()

if not isinstance(encodings, List):
encodings = [encodings]
Expand Down
9 changes: 8 additions & 1 deletion TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,4 +851,11 @@ def test_export_float_encodings(self):
encodings = qc_quantize_op.export_encodings("0.6.1")
assert len(encodings) == 1
assert encodings[0]["dtype"] == "float"
assert encodings[0]["bitwidth"] == 16
assert encodings[0]["bitwidth"] == 16

def test_load_float_encodings(self):
    """load_encodings must raise RuntimeError on a floating-point quantizer."""
    info = libquant_info.QcQuantizeInfo()
    op = QcQuantizeOp(info, bitwidth=16, op_mode=OpMode.quantizeDequantize)
    op.data_type = QuantizationDataType.float
    # Integer encodings cannot be loaded into a float-mode quantizer.
    with pytest.raises(RuntimeError):
        op.load_encodings([libpymo.TfEncoding()])

0 comments on commit bb03529

Please sign in to comment.