From e8a8ff7452bf9f5292672c4f1e06cb740827b724 Mon Sep 17 00:00:00 2001 From: Chetan Gulecha Date: Tue, 19 Dec 2023 15:45:01 +0530 Subject: [PATCH] Changes for TF quant analyzer (#2610) Signed-off-by: Chetan Gulecha --- Docs/api_docs/tensorflow_quantization.rst | 1 + .../tensorflow/src/python/aimet_tensorflow/quant_analyzer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Docs/api_docs/tensorflow_quantization.rst b/Docs/api_docs/tensorflow_quantization.rst index 99950a82758..5a09c514268 100644 --- a/Docs/api_docs/tensorflow_quantization.rst +++ b/Docs/api_docs/tensorflow_quantization.rst @@ -4,6 +4,7 @@ AIMET TensorFlow Quantization APIs AIMET Quantization for TensorFlow provides the following functionality - :ref:`Quantization Simulation`: Allows ability to simulate inference and training on quantized hardware + - :ref:`QuantAnalyzer`: Analyzes the model and points out sensitive ops to quantization - :ref:`Adaptive Rounding`: Post-training quantization technique to optimize rounding of weight tensors - :ref:`Cross-Layer Equalization`: Post-training quantization technique to equalize layer parameters - :ref:`Bias Correction`: Post-training quantization technique to correct shift in layer outputs due to quantization noise diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quant_analyzer.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quant_analyzer.py index 099234e09b7..5a5e7f9e0e9 100644 --- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quant_analyzer.py +++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/quant_analyzer.py @@ -402,7 +402,7 @@ def _compute_mse_loss(self, sim: QuantizationSimModel, output_op_name) -> float: # Collect output activation data from quant sim op feed_dict = create_input_feed_dict(sim.session.graph, self._start_op_names, model_inputs) - quant_op = sim.session.graph.get_operation_by_name(output_op_name) + quant_op = sim.session.graph.get_operation_by_name(output_op_name + "_quantized") quantized_out_data = sim.session.run(quant_op.outputs[0], feed_dict=feed_dict) # Calculate MSE loss