# Flag to enable restoring pruned initializers in the ONNX graph.
# Disabled by default because restoring initializers is rarely needed.
RESTORE_ONNX_MODEL_INITIALIZERS = False


@dataclass(frozen=True)
class PrunedInitializerInfo:
    """
    Immutable record pairing a pruned initializer with the Identity node that replaced it.

    When pruned initializers are restored, ``initializer`` is the tensor added to the
    ONNX graph and ``identity_node`` is the node removed from it.
    """
    # Initializer to be added to the ONNX graph
    initializer: onnx.TensorProto
    # Identity node to be removed from the ONNX graph
    identity_node: onnx.NodeProto
def restore_onnx_graph_initializers(original_model_path: str, restored_model_path: str):
    """
    Restore pruned initializers in ONNX graph

    Loads the ONNX model stored at original_model_path. For every node input that is the
    output of an Identity node fed directly by an initializer, a new initializer named after
    that input is added and the Identity node is removed. The result is saved to
    restored_model_path (which may be the same path as original_model_path).

    :param original_model_path: Path where the original ONNX artifact was stored
    :param restored_model_path: Path to store restored ONNX artifact
    """
    # pylint: disable=protected-access, no-member
    model = onnx.load(original_model_path)
    onnx_graph = model.graph

    initializers = OnnxSaver._get_all_initializers(onnx_graph)
    initializer_names = [initializer.name for initializer in initializers]
    pruned_initializer_map = _get_pruned_initializer_map(
        onnx_graph, initializers, initializer_names
    )

    # Restoring removes Identity nodes from onnx_graph.node. Walk a snapshot so that the
    # removal cannot shift the container under the iteration and make it skip nodes.
    restored_tensors = set()
    for node in list(onnx_graph.node):
        for input_tensor in node.input:
            if input_tensor not in pruned_initializer_map:
                continue
            # The same Identity output may feed several inputs. Restore it only once:
            # a second attempt would append a duplicate initializer and then fail to
            # remove the Identity node that is already gone.
            if input_tensor in restored_tensors:
                continue
            _restore_pruned_initializer(
                onnx_graph, input_tensor, pruned_initializer_map
            )
            restored_tensors.add(input_tensor)

    onnx.save(model, restored_model_path)


def _get_pruned_initializer_map(onnx_graph: onnx.GraphProto,
                                initializers: List[onnx.TensorProto],
                                initializer_names: List[str]) -> Dict[str, PrunedInitializerInfo]:
    """
    Find pruned ONNX initializers by iterating Identity nodes

    An Identity node whose first input is the name of an initializer is treated as standing
    in for a pruned initializer. Each entry holds a deep copy of the source initializer, so
    renaming it later does not alter the initializer it was copied from.

    :param onnx_graph: ONNX graph
    :param initializers: List of ONNX initializers
    :param initializer_names: List of model initializer names, index-aligned with initializers
    :return: Dictionary with output of identity node as key and PrunedInitializerInfo as value
    """
    # Build the name lookup once instead of scanning the name list for every node.
    # setdefault keeps the first occurrence of a repeated name, which is the entry
    # list.index() would have selected.
    initializer_by_name = {}
    for name, initializer in zip(initializer_names, initializers):
        initializer_by_name.setdefault(name, initializer)

    pruned_initializer_map = {}
    for node in onnx_graph.node:
        if node.op_type != "Identity" or not node.input:
            continue
        initializer = initializer_by_name.get(node.input[0])
        if initializer is None:
            continue
        pruned_initializer_map[node.output[0]] = PrunedInitializerInfo(
            copy.deepcopy(initializer), node
        )

    return pruned_initializer_map


def _restore_pruned_initializer(onnx_graph: onnx.GraphProto,
                                input_tensor: str,
                                pruned_initializer_map: Dict[str, PrunedInitializerInfo],
                                new_initializer_name: Optional[str] = None):
    """
    Create new Initializer and remove Identity node to restore pruned Initializer

    Does nothing when input_tensor has no entry in pruned_initializer_map. Otherwise the
    graph is modified in place. The Identity node must still be present in the graph:
    callers should restore a given tensor only once.

    :param onnx_graph: ONNX graph
    :param input_tensor: Input tensor name
    :param pruned_initializer_map: Dictionary with output of identity node as key and PrunedInitializerInfo as value
    :param new_initializer_name: Name for new initializer; defaults to input_tensor when not given
    """
    result = pruned_initializer_map.get(input_tensor)
    if result is None:
        return

    new_initializer = result.initializer
    existing_identity_node = result.identity_node

    # Naming the initializer after the consumed tensor keeps the consumer's input valid
    # once the Identity node that used to produce that tensor is removed.
    new_initializer.name = new_initializer_name or input_tensor
    onnx_graph.initializer.append(new_initializer)
    onnx_graph.node.remove(existing_identity_node)
    _logger.info(
        "Added new Initializer `%s` and removed existing Identity node `%s`",
        new_initializer.name,
        existing_identity_node.name,
    )
class MultiplePReluModel(nn.Module):
    """ Test model made of three Conv2d layers, each followed by its own PReLU """

    def __init__(self, num_parameters: int = 1):
        """
        :param num_parameters: Passed to every nn.PReLU as its num_parameters argument
        """
        super().__init__()
        in_channels, hidden_channels = 3, 8

        # Sub-modules are created in the order conv1, act1, conv2, act2, conv3, act3
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act1 = nn.PReLU(num_parameters=num_parameters)

        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3)
        self.act2 = nn.PReLU(num_parameters=num_parameters)

        self.conv3 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3)
        self.act3 = nn.PReLU(num_parameters=num_parameters)

    def forward(self, *inputs):
        """
        Apply the three Conv2d -> PReLU stages in order to the first positional input

        :param inputs: Positional inputs; only inputs[0] is used
        :return: Output of the last PReLU
        """
        stages = (
            (self.conv1, self.act1),
            (self.conv2, self.act2),
            (self.conv3, self.act3),
        )
        out = inputs[0]
        for conv, act in stages:
            out = act(conv(out))
        return out
@pytest.mark.parametrize("num_parameters", [1, 8])
def test_set_node_name_for_multiple_p_relu_model(self, num_parameters):
    """
    With RESTORE_ONNX_MODEL_INITIALIZERS enabled, OnnxSaver.set_node_names must produce
    an ONNX model in which every PReLU weight is an initializer and a valid parameter.
    """
    model = MultiplePReluModel(num_parameters)
    dummy_input = torch.randn(4, 3, 28, 28)
    expected_initializer_names = ["act1.weight", "act2.weight", "act3.weight"]

    # The flag is module-global state. Restore it in `finally` so that a failing
    # assertion cannot leave it enabled for tests that run afterwards.
    previous_flag = onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS
    onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = True
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            onnx_path = f"{tmp_dir}/multiple_p_relu_model.onnx"

            onnx_utils.OnnxSaver.set_node_names(onnx_path, model, dummy_input)
            onnx_model = onnx.load(onnx_path)

            actual_initializer_names = {x.name for x in onnx_model.graph.initializer}
            for name in expected_initializer_names:
                assert name in actual_initializer_names

            _, valid_param_set = onnx_utils.OnnxSaver.get_onnx_node_to_io_tensor_names_map(onnx_model)
            for name in expected_initializer_names:
                assert name in valid_param_set

            self.check_onnx_node_name_uniqueness(onnx_model)
    finally:
        onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = previous_flag
def test_restore_onnx_graph_initializers(self):
    """
    restore_onnx_graph_initializers must turn every PRelu slope input into an initializer
    """
    def p_relu_slopes(onnx_graph):
        # The slope is the second input of each PRelu node
        return [node.input[1] for node in onnx_graph.node if node.op_type == "PRelu"]

    model = MultiplePReluModel()
    dummy_input = torch.randn(4, 3, 28, 28)

    with tempfile.TemporaryDirectory() as tmp_dir:
        original_model_path = f"{tmp_dir}/multiple_p_relu_model.onnx"
        torch.onnx.export(model, dummy_input, original_model_path)

        original_graph = onnx.load(original_model_path).graph
        identity_node_outputs = {
            node.output[0] for node in original_graph.node if node.op_type == "Identity"
        }

        # At least one slope is related to the identity node in original ONNX model
        assert not identity_node_outputs.isdisjoint(p_relu_slopes(original_graph))

        restored_model_path = f"{tmp_dir}/restored_multiple_p_relu_model.onnx"
        restore_onnx_graph_initializers(original_model_path, restored_model_path)
        restored_graph = onnx.load(restored_model_path).graph

        # All slope should be initializers in restored ONNX model
        restored_model_initializers = {init.name for init in restored_graph.initializer}
        assert set(p_relu_slopes(restored_graph)) <= restored_model_initializers
@pytest.mark.parametrize("num_parameters", [1, 8])
@pytest.mark.parametrize("config_file", [None, get_path_for_per_channel_config()])
def test_export_to_onnx_for_multiple_p_relu_model(self, num_parameters, config_file):
    """
    With RESTORE_ONNX_MODEL_INITIALIZERS enabled, the exported encodings must contain a
    param encoding for every PReLU weight: num_parameters entries when a per-channel
    config file is given, one entry otherwise.
    """
    model = test_models.MultiplePReluModel(num_parameters)
    dummy_input = torch.randn(4, 3, 28, 28)
    filename_prefix = "multiple_p_relu_model"
    expected_param_names = ["act1.weight", "act2.weight", "act3.weight"]

    with tempfile.TemporaryDirectory() as tmp_dir:
        sim = QuantizationSimModel(model, dummy_input, QuantScheme.post_training_tf,
                                   config_file=config_file)
        sim.compute_encodings(lambda m, _: m(dummy_input), None)

        # The flag is module-global state. Restore it in `finally` so that a failing
        # export or assertion cannot leave it enabled for tests that run afterwards.
        previous_flag = onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS
        onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = True
        try:
            sim.export(tmp_dir, filename_prefix, dummy_input)
            with open(f"{tmp_dir}/{filename_prefix}.encodings") as encodings_file:
                encodings = json.load(encodings_file)

            param_encodings = encodings["param_encodings"]
            for param_name in expected_param_names:
                assert param_name in param_encodings

                if config_file:  # Per-channel
                    assert len(param_encodings[param_name]) == num_parameters
                else:  # Per-tensor
                    assert len(param_encodings[param_name]) == 1
        finally:
            onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = previous_flag