From dcd4bf1264518255cb7ba69d91436a795eaf4dea Mon Sep 17 00:00:00 2001 From: Geunho Lee Date: Mon, 23 Oct 2023 20:41:22 +0900 Subject: [PATCH] Implement in-memory ONNX initializer restoration method Signed-off-by: Geunho Lee --- .../src/python/aimet_torch/onnx_utils.py | 30 ++++++++++++---- .../torch/test/python/test_onnx_utils.py | 35 +++++++++++++++++-- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py index a0e65b56dd7..dac0dcf1c57 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py @@ -1082,8 +1082,6 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn 'temp_onnx_model_with_all_markers.onnx') cls._export_model_to_onnx(model, dummy_input, temp_file, is_conditional, onnx_export_args) - if RESTORE_ONNX_MODEL_INITIALIZERS: - restore_onnx_graph_initializers(temp_file, temp_file) return cls.load_simply_onnx_model(temp_file) @classmethod @@ -1094,6 +1092,9 @@ def load_simply_onnx_model(cls, filepath) -> onnx.ModelProto: :return: Onnx model with optional simply pass """ onnx_model = onnx.load(filepath) + if RESTORE_ONNX_MODEL_INITIALIZERS: + onnx_model = restore_onnx_graph_initializers(onnx_model, inplace=True) + if simplify_onnx_model: onnx_model_simplified, check = onnxsim.simplify(onnx_model) if check: @@ -1478,15 +1479,32 @@ def _export_model_to_onnx(model: Union[torch.nn.Module, torch.jit.ScriptModule, _logger.warning("ONNX Checker has failed but ONNX graph is still generated.") -def restore_onnx_graph_initializers(original_model_path: str, restored_model_path: str): +def save_initializer_restored_onnx_graph(original_model_path: str, + restored_model_path: str): """ - Restore pruned initializers in ONNX graph + Load original ONNX model path and save restored ONNX model to specific path :param original_model_path: Path where the original ONNX artifact was stored :param restored_model_path: Path to store restored ONNX artifact """ - # pylint: disable=protected-access, no-member model = onnx.load(original_model_path) + restored_model = restore_onnx_graph_initializers(model, inplace=True) + onnx.save(restored_model, restored_model_path) + + +def restore_onnx_graph_initializers(model: onnx.ModelProto, + inplace: bool = False) -> onnx.ModelProto: + """ + Copy original model and restore its pruned initializers + + :param model: Original ONNX ModelProto + :param inplace: Whether to modify ModelProto by inplace manner or not + :return: Initializer restored ONNX ModelProto + """ + # pylint: disable=protected-access, no-member + if not inplace: + model = copy.deepcopy(model) + onnx_graph = model.graph initializers = OnnxSaver._get_all_initializers(onnx_graph) @@ -1501,7 +1519,7 @@ def restore_onnx_graph_initializers(original_model_path: str, restored_model_pat onnx_graph, input_tensor, pruned_initializer_map ) - onnx.save(model, restored_model_path) + return model def _get_pruned_initializer_map(onnx_graph: onnx.GraphProto, diff --git a/TrainingExtensions/torch/test/python/test_onnx_utils.py b/TrainingExtensions/torch/test/python/test_onnx_utils.py index 1ef83ea9f25..7a70a2828fd 100644 --- a/TrainingExtensions/torch/test/python/test_onnx_utils.py +++ b/TrainingExtensions/torch/test/python/test_onnx_utils.py @@ -49,7 +49,10 @@ import aimet_torch.elementwise_ops from aimet_common.utils import AimetLogger from aimet_torch import onnx_utils -from aimet_torch.onnx_utils import restore_onnx_graph_initializers +from aimet_torch.onnx_utils import ( + save_initializer_restored_onnx_graph, + restore_onnx_graph_initializers, +) from models.test_models import RoiModel, InputOutputDictModel, MultiplePReluModel @@ -911,7 +914,7 @@ def test_set_node_name_for_multiple_p_relu_model(self, num_parameters): self.check_onnx_node_name_uniqueness(onnx_model) onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = False - def test_restore_onnx_graph_initializers(self): + def test_save_initializer_restored_onnx_graph(self): model = MultiplePReluModel() dummy_input = torch.randn(4, 3, 28, 28) @@ -927,7 +930,7 @@ def test_restore_onnx_graph_initializers(self): assert any([slope in identity_node_outputs for slope in p_relu_slopes]) restored_model_path = f"{tmp_dir}/restored_multiple_p_relu_model.onnx" - restore_onnx_graph_initializers(original_model_path, restored_model_path) + save_initializer_restored_onnx_graph(original_model_path, restored_model_path) restored_model = onnx.load(restored_model_path) # All slope should be initializers in restored ONNX model @@ -935,3 +938,29 @@ def test_restore_onnx_graph_initializers(self): restored_p_relu_slopes = [x.input[1] for x in restored_model.graph.node if x.op_type == "PRelu"] assert all([slope in restored_model_initializers for slope in restored_p_relu_slopes]) + + @pytest.mark.parametrize("inplace", [True, False]) + def test_restore_onnx_graph_initializers(self, inplace): + model = MultiplePReluModel() + dummy_input = torch.randn(4, 3, 28, 28) + + with tempfile.TemporaryDirectory() as tmp_dir: + original_model_path = f"{tmp_dir}/multiple_p_relu_model.onnx" + torch.onnx.export(model, dummy_input, original_model_path) + + original_model = onnx.load(original_model_path) + restored_model = restore_onnx_graph_initializers(original_model, inplace=inplace) + + # Both inplace=True and inplace=False, restored model should have separate initializers + restored_model_initializers = {x.name for x in restored_model.graph.initializer} + restored_p_relu_slopes = [x.input[1] for x in restored_model.graph.node if x.op_type == "PRelu"] + assert all([slope in restored_model_initializers for slope in restored_p_relu_slopes]) + + if inplace: + assert id(original_model) == id(restored_model) + else: + # Original model shouldn't be modified if inplace=False + identity_node_outputs = {x.output[0] for x in original_model.graph.node if x.op_type == "Identity"} + p_relu_slopes = [x.input[1] for x in original_model.graph.node if x.op_type == "PRelu"] + assert any([slope in identity_node_outputs for slope in p_relu_slopes]) + assert id(original_model) != id(restored_model)