diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
index 1e615841460..5e4f00e4849 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
@@ -256,16 +256,27 @@ def use_unsigned_symmetric(self, use_unsigned_symmetric: bool):
             tensor_quantizer.setUnsignedSymmetric(use_unsigned_symmetric)
         self._reset_encodings()
 
-    @property
-    def encodings(self) -> Optional[List[libpymo.TfEncoding]]:
+    def get_encodings(self) -> Optional[List[libpymo.TfEncoding]]:
         """
         Reads the encodings object from the node's QcQuantizeInfo
+        :return: The libpymo.TfEncoding object used to store the node's quantization encoding
         """
         if not self.is_initialized() or self.data_type == QuantizationDataType.float:
             return None
         return self.quant_info.encoding
 
+    @property
+    @deprecated(f"Use {get_encodings.__qualname__} instead")
+    def encodings(self) -> Optional[List[libpymo.TfEncoding]]:
+        """Deprecated. Use :meth:`get_encodings` to read the quantizer encodings.
+
+        Reads the encodings object from the node's QcQuantizeInfo
+
+        :return: The libpymo.TfEncoding object used to store the node's quantization encoding
+        """
+        return self.get_encodings()
+
     def update_quantizer_and_load_encodings(self, encoding: List[libpymo.TfEncoding], is_symmetric: Optional[bool],
                                             is_strict_symmetric: Optional[bool], is_unsigned_symmetric: Optional[bool],
                                             data_type: QuantizationDataType):
@@ -308,7 +319,7 @@ def load_encodings(self, encoding: List[libpymo.TfEncoding]):
         assert isinstance(encoding, (list, tuple))
         assert len(encoding) == len(self._tensor_quantizer)
         if self.data_type == QuantizationDataType.float:
-            raise RuntimeError(f"`load_encodings` is not supported for floating-point quantizers.")
+            raise RuntimeError(f"{type(self).load_encodings.__qualname__} is not supported for floating-point quantizers.")
         for tensor_quantizer in self._tensor_quantizer:
             tensor_quantizer.isEncodingValid = True
         self.op_mode = OpMode.quantizeDequantize
@@ -422,7 +433,7 @@ def get_stats_histogram(self) -> List[List]:
         if self.quant_scheme != QuantScheme.post_training_tf_enhanced:
             raise RuntimeError("get_stats_histogram() can be invoked only when quantization scheme is TF-Enhanced.")
 
-        if not self.encodings:
+        if not self.get_encodings():
             raise RuntimeError("get_stats_histogram() can be invoked only when encoding is computed.")
 
         histogram = []
@@ -469,7 +480,7 @@ def _export_legacy_encodings(self) -> Union[List, None]:
 
         if self.data_type == QuantizationDataType.int:
             encodings = []
-            for encoding in self.encodings:
+            for encoding in self.get_encodings():
                 enc_dict = dict(min=encoding.min,
                                 max=encoding.max,
                                 scale=encoding.delta,
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quant_analyzer.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quant_analyzer.py
index 3666a6304b8..867b88d0b42 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/quant_analyzer.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quant_analyzer.py
@@ -474,23 +474,26 @@ def export_per_layer_encoding_min_max_range(self, sim: QuantizationSimModel, res
             # Get input activations' encodings if starting op
             for index, quantizer in enumerate(input_quantizers):
                 name = f"{op_name}_input_{index}"
-                min_max_range_for_activations_dict[name] = (quantizer.encodings[0].min, quantizer.encodings[0].max)
+                encodings = quantizer.get_encodings()
+                min_max_range_for_activations_dict[name] = (encodings[0].min, encodings[0].max)
 
             # Get output activations' encodings
             for index, quantizer in enumerate(output_quantizers):
                 name = f"{op_name}_output_{index}"
-                min_max_range_for_activations_dict[name] = (quantizer.encodings[0].min, quantizer.encodings[0].max)
+                encodings = quantizer.get_encodings()
+                min_max_range_for_activations_dict[name] = (encodings[0].min, encodings[0].max)
 
             # Get parameters' encodings
             for param_name, quantizer in param_quantizers.items():
                 name = re.sub(r'\W+', '_', f"{op_name}_{param_name}")
-                if len(quantizer.encodings) > 1:  # per-channel
+                encodings = quantizer.get_encodings()
+                if len(encodings) > 1:  # per-channel
                     per_channel_encodings = {}
-                    for index, encoding in enumerate(quantizer.encodings):
+                    for index, encoding in enumerate(encodings):
                         per_channel_encodings[f"{name}_{index}"] = (encoding.min, encoding.max)
                     min_max_range_for_weights_dict[name] = per_channel_encodings
                 else:  # per-tensor
-                    min_max_range_for_weights_dict[name] = (quantizer.encodings[0].min, quantizer.encodings[0].max)
+                    min_max_range_for_weights_dict[name] = (encodings[0].min, encodings[0].max)
 
         create_and_export_min_max_ranges_plot(min_max_range_for_weights_dict, min_max_ranges_dir, title="weights")
         create_and_export_min_max_ranges_plot(min_max_range_for_activations_dict, min_max_ranges_dir, title="activations")
@@ -564,7 +567,7 @@ def _create_and_export_stats_histogram_plot(quantizer: QcQuantizeOp, results_dir
         os.makedirs(results_dir, exist_ok=True)
 
         histograms = quantizer.get_stats_histogram()
-        encodings = quantizer.encodings
+        encodings = quantizer.get_encodings()
         if not isinstance(encodings, List):
             encodings = [encodings]
 
diff --git a/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py b/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
index 340e076e81e..f455408b072 100644
--- a/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
+++ b/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
@@ -851,4 +851,11 @@ def test_export_float_encodings(self):
         encodings = qc_quantize_op.export_encodings("0.6.1")
         assert len(encodings) == 1
         assert encodings[0]["dtype"] == "float"
-        assert encodings[0]["bitwidth"] == 16
\ No newline at end of file
+        assert encodings[0]["bitwidth"] == 16
+
+    def test_load_float_encodings(self):
+        quant_info = libquant_info.QcQuantizeInfo()
+        qc_quantize_op = QcQuantizeOp(quant_info, bitwidth=16, op_mode=OpMode.quantizeDequantize)
+        qc_quantize_op.data_type = QuantizationDataType.float
+        with pytest.raises(RuntimeError):
+            qc_quantize_op.load_encodings([libpymo.TfEncoding()])
\ No newline at end of file
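
For reviewers, a minimal sketch of what this change means at call sites. The construction mirrors the new test; the import paths below are assumptions inferred from the identifiers this patch uses (libpymo, libquant_info, QcQuantizeOp, OpMode, QuantizationDataType), not part of the patch itself:

```python
import pytest
from aimet_common import libpymo, libquant_info          # assumed import locations
from aimet_common.defs import QuantizationDataType
from aimet_onnx.qc_quantize_op import QcQuantizeOp, OpMode

quant_info = libquant_info.QcQuantizeInfo()
quantizer = QcQuantizeOp(quant_info, bitwidth=16, op_mode=OpMode.quantizeDequantize)

# New explicit getter; returns None until encodings are computed
# (or whenever the quantizer is a floating-point quantizer).
encodings = quantizer.get_encodings()

# The old property still works but is now routed through get_encodings()
# and carries a deprecation notice.
encodings = quantizer.encodings

# Floating-point quantizers reject load_encodings() with a RuntimeError,
# as exercised by the new test_load_float_encodings test.
quantizer.data_type = QuantizationDataType.float
with pytest.raises(RuntimeError):
    quantizer.load_encodings([libpymo.TfEncoding()])
```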