Skip to content

Commit

Permalink
Implement in-memory ONNX initializer restoration method
Browse files Browse the repository at this point in the history
Signed-off-by: Geunho Lee <quic_geunlee@quicinc.com>
  • Loading branch information
quic-geunlee committed Nov 14, 2023
1 parent 6d91070 commit dcd4bf1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
30 changes: 24 additions & 6 deletions TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,8 +1082,6 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
'temp_onnx_model_with_all_markers.onnx')

cls._export_model_to_onnx(model, dummy_input, temp_file, is_conditional, onnx_export_args)
if RESTORE_ONNX_MODEL_INITIALIZERS:
restore_onnx_graph_initializers(temp_file, temp_file)
return cls.load_simply_onnx_model(temp_file)

@classmethod
Expand All @@ -1094,6 +1092,9 @@ def load_simply_onnx_model(cls, filepath) -> onnx.ModelProto:
:return: Onnx model with optional simply pass
"""
onnx_model = onnx.load(filepath)
if RESTORE_ONNX_MODEL_INITIALIZERS:
onnx_model = restore_onnx_graph_initializers(onnx_model, inplace=True)

if simplify_onnx_model:
onnx_model_simplified, check = onnxsim.simplify(onnx_model)
if check:
Expand Down Expand Up @@ -1478,15 +1479,32 @@ def _export_model_to_onnx(model: Union[torch.nn.Module, torch.jit.ScriptModule,
_logger.warning("ONNX Checker has failed but ONNX graph is still generated.")


def restore_onnx_graph_initializers(original_model_path: str, restored_model_path: str):
def save_initializer_restored_onnx_graph(original_model_path: str,
restored_model_path: str):
"""
Restore pruned initializers in ONNX graph
Load original ONNX model path and save restored ONNX model to specific path
:param original_model_path: Path where the original ONNX artifact was stored
:param restored_model_path: Path to store restored ONNX artifact
"""
# pylint: disable=protected-access, no-member
model = onnx.load(original_model_path)
restored_model = restore_onnx_graph_initializers(model, inplace=True)
onnx.save(restored_model, restored_model_path)


def restore_onnx_graph_initializers(model: onnx.ModelProto,
inplace: bool = False) -> onnx.ModelProto:
"""
Copy original model and restore its pruned initializers
:param model: Original ONNX ModelProto
:param inplace: Whether to modify ModelProto by inplace manner or not
:return: Initializer restored ONNX ModelProto
"""
# pylint: disable=protected-access, no-member
if not inplace:
model = copy.deepcopy(model)

onnx_graph = model.graph

initializers = OnnxSaver._get_all_initializers(onnx_graph)
Expand All @@ -1501,7 +1519,7 @@ def restore_onnx_graph_initializers(original_model_path: str, restored_model_pat
onnx_graph, input_tensor, pruned_initializer_map
)

onnx.save(model, restored_model_path)
return model


def _get_pruned_initializer_map(onnx_graph: onnx.GraphProto,
Expand Down
35 changes: 32 additions & 3 deletions TrainingExtensions/torch/test/python/test_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
import aimet_torch.elementwise_ops
from aimet_common.utils import AimetLogger
from aimet_torch import onnx_utils
from aimet_torch.onnx_utils import restore_onnx_graph_initializers
from aimet_torch.onnx_utils import (
save_initializer_restored_onnx_graph,
restore_onnx_graph_initializers,
)
from models.test_models import RoiModel, InputOutputDictModel, MultiplePReluModel


Expand Down Expand Up @@ -911,7 +914,7 @@ def test_set_node_name_for_multiple_p_relu_model(self, num_parameters):
self.check_onnx_node_name_uniqueness(onnx_model)
onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = False

def test_restore_onnx_graph_initializers(self):
def test_save_initializer_restored_onnx_graph(self):
model = MultiplePReluModel()
dummy_input = torch.randn(4, 3, 28, 28)

Expand All @@ -927,11 +930,37 @@ def test_restore_onnx_graph_initializers(self):
assert any([slope in identity_node_outputs for slope in p_relu_slopes])

restored_model_path = f"{tmp_dir}/restored_multiple_p_relu_model.onnx"
restore_onnx_graph_initializers(original_model_path, restored_model_path)
save_initializer_restored_onnx_graph(original_model_path, restored_model_path)
restored_model = onnx.load(restored_model_path)

# All slope should be initializers in restored ONNX model
restored_model_initializers = {x.name for x in restored_model.graph.initializer}
restored_p_relu_slopes = [x.input[1] for x in restored_model.graph.node if x.op_type == "PRelu"]

assert all([slope in restored_model_initializers for slope in restored_p_relu_slopes])

@pytest.mark.parametrize("inplace", [True, False])
def test_restore_onnx_graph_initializers(self, inplace):
model = MultiplePReluModel()
dummy_input = torch.randn(4, 3, 28, 28)

with tempfile.TemporaryDirectory() as tmp_dir:
original_model_path = f"{tmp_dir}/multiple_p_relu_model.onnx"
torch.onnx.export(model, dummy_input, original_model_path)

original_model = onnx.load(original_model_path)
restored_model = restore_onnx_graph_initializers(original_model, inplace=inplace)

# Both inplace=True and inplace=False, restored model should have separate initializers
restored_model_initializers = {x.name for x in restored_model.graph.initializer}
restored_p_relu_slopes = [x.input[1] for x in restored_model.graph.node if x.op_type == "PRelu"]
assert all([slope in restored_model_initializers for slope in restored_p_relu_slopes])

if inplace:
assert id(original_model) == id(restored_model)
else:
# Original model shouldn't be modified if inplace=False
identity_node_outputs = {x.output[0] for x in original_model.graph.node if x.op_type == "Identity"}
p_relu_slopes = [x.input[1] for x in original_model.graph.node if x.op_type == "PRelu"]
assert any([slope in identity_node_outputs for slope in p_relu_slopes])
assert id(original_model) != id(restored_model)

0 comments on commit dcd4bf1

Please sign in to comment.