diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py index 5e4f00e4849..117040f6232 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py @@ -88,7 +88,7 @@ def __init__(self, quant_info: libquant_info.QcQuantizeInfo, self.rounding_mode = rounding_mode self._is_encoding_frozen = False self._tensor_quantizer = None - self._set_tensor_quantizer(self._build_tensor_quantizer()) + self._set_tensor_quantizer([self._build_tensor_quantizer()]) self.op_mode = op_mode self.bitwidth = bitwidth self.use_symmetric_encodings = use_symmetric_encodings @@ -126,9 +126,7 @@ def _create_tensor_quantizers(self, num: int): tensor_quantizer.isEncodingValid = False tensor_quantizers.append(tensor_quantizer) - self._tensor_quantizer = tensor_quantizers - self.quant_info.tensorQuantizerRef = [libpymo.PtrToInt64(tensor_quantizer) - for tensor_quantizer in tensor_quantizers] + self._set_tensor_quantizer(tensor_quantizers) self.reset_encoding_stats() @@ -178,14 +176,14 @@ def _build_tensor_quantizer(self): return libpymo.TensorQuantizer(MAP_QUANT_SCHEME_TO_PYMO[self.quant_scheme], MAP_ROUND_MODE_TO_PYMO[self.rounding_mode]) - def _set_tensor_quantizer(self, tensor_quantizer: libpymo.TensorQuantizer): + def _set_tensor_quantizer(self, tensor_quantizers: List[libpymo.TensorQuantizer]): """ Stores tensor_quantizer in self._tensor_quantizer and passes a pointer to the object to the C++ op's QcQuantInfo object - :param tensor_quantizer: The libpymo.TensorQuantizer object to give to the C++ op + :param tensor_quantizers: The list of libpymo.TensorQuantizer objects to give to the C++ op """ - self._tensor_quantizer = [tensor_quantizer] - self.quant_info.tensorQuantizerRef = [libpymo.PtrToInt64(tensor_quantizer)] + self._tensor_quantizer = tensor_quantizers + self.quant_info.tensorQuantizerRef = [libpymo.PtrToInt64(tensor_quantizer) for tensor_quantizer in tensor_quantizers] @property def enabled(self) -> bool: @@ -400,7 +398,7 @@ def set_quant_scheme(self, quant_scheme: QuantScheme): if self.quant_info.usePerChannelMode: self.enable_per_channel_quantization() else: - self._set_tensor_quantizer(self._build_tensor_quantizer()) + self._set_tensor_quantizer([self._build_tensor_quantizer()]) self.reset_encoding_stats() def compute_encodings(self) -> Optional[List[libpymo.TfEncoding]]: