Skip to content

Commit

Permalink
Sanitize intermediate layer activation names in ONNX LOG utility (#2661)
Browse files Browse the repository at this point in the history
Signed-off-by: Raj Gite <quic_rgite@quicinc.com>
  • Loading branch information
quic-rgite authored and quic-bharathr committed Sep 13, 2024
1 parent ce0d29d commit c2d61a9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import copy
from typing import List, Dict, Tuple, Union
import re
import numpy as np
import onnxruntime as ort
import onnx
Expand Down Expand Up @@ -122,7 +123,9 @@ def __init__(self, model: ModelProto, providers: List, dir_path: str):
LayerOutput.register_activations(self.model, self.activation_names)

self.session = QuantizationSimModel.build_session(self.model, providers)
self.sanitized_activation_names = [name[:-len('_updated')] if name.endswith('_updated') else name for name in self.activation_names]

# Replace special characters with underscore. This gives valid file names to store activation tensors.
self.sanitized_activation_names = [re.sub(r'\W+', "_", name.replace('_updated', '')) for name in self.activation_names]

# Save activation names which are in topological order of model graph. This order can be used while comparing layer-outputs.
save_layer_output_names(self.sanitized_activation_names, dir_path)
Expand Down
16 changes: 8 additions & 8 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,28 +1631,28 @@ def build_dummy_model_with_dynamic_input():
shape=['batch_size', 10])
conv_node = helper.make_node('Conv',
['input', 'conv_w', 'conv_b'],
['3'],
['conv/output.3'],
'conv',
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],)
relu_node = helper.make_node('Relu',
['3'],
['4'],
['conv/output.3'],
['relu/output.4'],
'relu')
maxpool_node = helper.make_node('MaxPool',
['4'],
['5'],
['relu/output.4'],
['maxpool/output.5'],
'maxpool',
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[2, 2],)

flatten_node = helper.make_node('Flatten',
['5'],
['6'],
['maxpool/output.5'],
['flatten/output.6'],
'flatten')
fc_node = helper.make_node('Gemm',
['6', 'fc_w', 'fc_b'],
['flatten/output.6', 'fc_w', 'fc_b'],
['output'],
'fc')

Expand Down
10 changes: 6 additions & 4 deletions TrainingExtensions/onnx/test/python/test_layer_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def get_original_model_artifacts():
for node in model.graph.node:
output_names.extend(node.output)
input_dict = make_dummy_input(model)
output_names = [name.replace('/', '_').replace('.', '_') for name in output_names] # sanitization to get valid file names
return model, output_names, input_dict


Expand All @@ -77,6 +78,7 @@ def callback(session, input_dict):
for node in quantsim.model.model.graph.node:
output_names.extend(node.output)
output_names = [name[:-len('_updated')] for name in output_names if name.endswith('_updated')]
output_names = [name.replace('/', '_').replace('.', '_') for name in output_names] # sanitization to get valid file names

return quantsim, output_names, input_dict

Expand All @@ -95,10 +97,10 @@ def test_get_original_model_outputs(self):
layer_output = LayerOutput(model, providers, temp_dir_path)
output_name_to_output_val_dict = layer_output.get_outputs(input_dict)

# Verify whether outputs are generated for all the layers
# Verify whether all outputs are generated and have sanitized names
for name in output_names:
assert name in output_name_to_output_val_dict, \
"Output not generated for " + name
"Output not generated: " + name

# Verify whether captured outputs are correct. This can only be checked for final output of the model
session = QuantizationSimModel.build_session(model, providers)
Expand All @@ -120,10 +122,10 @@ def test_get_quantsim_model_outputs(self):
layer_output = LayerOutput(quantsim.model.model, providers, temp_dir_path)
output_name_to_output_val_dict = layer_output.get_outputs(input_dict)

# Verify whether outputs are generated for all the layers
# Verify whether all outputs are generated and have sanitized names
for name in output_names:
assert name in output_name_to_output_val_dict, \
"Output not generated for " + name
"Output not generated: " + name

# Verify whether captured outputs are correct. This can only be checked for final output of the model
session = QuantizationSimModel.build_session(quantsim.model.model, providers)
Expand Down

0 comments on commit c2d61a9

Please sign in to comment.