From c2d61a988a87439a5f51a82879fba3f3be90dc2b Mon Sep 17 00:00:00 2001 From: Raj Gite Date: Wed, 24 Jan 2024 15:30:22 +0530 Subject: [PATCH] Sanitize intermediate layer activation names in ONNX LOG utility (#2661) Signed-off-by: Raj Gite --- .../src/python/aimet_onnx/layer_output_utils.py | 5 ++++- .../onnx/test/python/models/models_for_tests.py | 16 ++++++++-------- .../onnx/test/python/test_layer_output_utils.py | 10 ++++++---- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py index ef4e8741be2..94dd2708fd8 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/layer_output_utils.py @@ -39,6 +39,7 @@ import copy from typing import List, Dict, Tuple, Union +import re import numpy as np import onnxruntime as ort import onnx @@ -122,7 +123,9 @@ def __init__(self, model: ModelProto, providers: List, dir_path: str): LayerOutput.register_activations(self.model, self.activation_names) self.session = QuantizationSimModel.build_session(self.model, providers) - self.sanitized_activation_names = [name[:-len('_updated')] if name.endswith('_updated') else name for name in self.activation_names] + + # Replace special characters with underscore. This gives valid file names to store activation tensors. + self.sanitized_activation_names = [re.sub(r'\W+', "_", name.replace('_updated', '')) for name in self.activation_names] # Save activation names which are in topological order of model graph. This order can be used while comparing layer-outputs. save_layer_output_names(self.sanitized_activation_names, dir_path) diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py index 7944cf36fae..dbd201e254b 100644 --- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py +++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py @@ -1631,28 +1631,28 @@ def build_dummy_model_with_dynamic_input(): shape=['batch_size', 10]) conv_node = helper.make_node('Conv', ['input', 'conv_w', 'conv_b'], - ['3'], + ['conv/output.3'], 'conv', kernel_shape=[3, 3], pads=[1, 1, 1, 1],) relu_node = helper.make_node('Relu', - ['3'], - ['4'], + ['conv/output.3'], + ['relu/output.4'], 'relu') maxpool_node = helper.make_node('MaxPool', - ['4'], - ['5'], + ['relu/output.4'], + ['maxpool/output.5'], 'maxpool', kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2],) flatten_node = helper.make_node('Flatten', - ['5'], - ['6'], + ['maxpool/output.5'], + ['flatten/output.6'], 'flatten') fc_node = helper.make_node('Gemm', - ['6', 'fc_w', 'fc_b'], + ['flatten/output.6', 'fc_w', 'fc_b'], ['output'], 'fc') diff --git a/TrainingExtensions/onnx/test/python/test_layer_output_utils.py b/TrainingExtensions/onnx/test/python/test_layer_output_utils.py index a7eaa804847..9d73a969c4c 100644 --- a/TrainingExtensions/onnx/test/python/test_layer_output_utils.py +++ b/TrainingExtensions/onnx/test/python/test_layer_output_utils.py @@ -61,6 +61,7 @@ def get_original_model_artifacts(): for node in model.graph.node: output_names.extend(node.output) input_dict = make_dummy_input(model) + output_names = [name.replace('/', '_').replace('.', '_') for name in output_names] # sanitization to get valid file names return model, output_names, input_dict @@ -77,6 +78,7 @@ def callback(session, input_dict): for node in quantsim.model.model.graph.node: output_names.extend(node.output) output_names = [name[:-len('_updated')] for name in output_names if name.endswith('_updated')] + output_names = [name.replace('/', '_').replace('.', '_') for name in output_names] # sanitization to get valid file names return quantsim, output_names, input_dict @@ -95,10 +97,10 @@ def test_get_original_model_outputs(self): layer_output = LayerOutput(model, providers, temp_dir_path) output_name_to_output_val_dict = layer_output.get_outputs(input_dict) - # Verify whether outputs are generated for all the layers + # Verify whether all outputs are generated and have sanitized names for name in output_names: assert name in output_name_to_output_val_dict, \ - "Output not generated for " + name + "Output not generated: " + name # Verify whether captured outputs are correct. This can only be checked for final output of the model session = QuantizationSimModel.build_session(model, providers) @@ -120,10 +122,10 @@ def test_get_quantsim_model_outputs(self): layer_output = LayerOutput(quantsim.model.model, providers, temp_dir_path) output_name_to_output_val_dict = layer_output.get_outputs(input_dict) - # Verify whether outputs are generated for all the layers + # Verify whether all outputs are generated and have sanitized names for name in output_names: assert name in output_name_to_output_val_dict, \ - "Output not generated for " + name + "Output not generated: " + name # Verify whether captured outputs are correct. This can only be checked for final output of the model session = QuantizationSimModel.build_session(quantsim.model.model, providers)