Skip to content

Commit

Permalink
Implement a method to support multiple PReLU model export
Browse files Browse the repository at this point in the history
Signed-off-by: Geunho Lee <quic_geunlee@quicinc.com>
  • Loading branch information
quic-geunlee committed Oct 20, 2023
1 parent 3e2494c commit 0d29fef
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 9 deletions.
93 changes: 92 additions & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# /usr/bin/env python3.5
#!/usr/bin/env python3
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
Expand Down Expand Up @@ -68,6 +68,10 @@
# this flag in such circumstances.
EXPORT_TO_ONNX_DIRECT = False

# Flag to enable restoring pruned initializers in the ONNX graph.
# Disabled by default because restoring initializers is rarely needed.
RESTORE_ONNX_MODEL_INITIALIZERS = False

# Runs the second pass of markers for non-leaf torch modules and updates the names of ONNX ops
# belonging to non-leaf PyTorch modules
update_all_onnx_nodes_name = True
Expand Down Expand Up @@ -217,6 +221,16 @@ def kwargs(self):
'input_names': self.input_names,
'output_names': self.output_names}


@dataclass(frozen=True)
class PrunedInitializerInfo:
    """
    Immutable pairing of an initializer that is to be added to an ONNX graph with the
    Identity node that is to be removed from it
    """
    # Initializer to be added to the graph
    initializer: onnx.TensorProto
    # Identity node to be removed from the graph
    identity_node: onnx.NodeProto


class MarkerAttr(IntEnum):
""" Enumeration for the custom marker attribute to index into the onnx node """
NAME = 0
Expand Down Expand Up @@ -1064,6 +1078,8 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
'temp_onnx_model_with_all_markers.onnx')

cls._export_model_to_onnx(model, dummy_input, temp_file, is_conditional, onnx_export_args)
if RESTORE_ONNX_MODEL_INITIALIZERS:
restore_onnx_graph_initializers(temp_file, temp_file)
return cls.load_simply_onnx_model(temp_file)

@classmethod
Expand Down Expand Up @@ -1456,3 +1472,78 @@ def _export_model_to_onnx(model: Union[torch.nn.Module, torch.jit.ScriptModule,
torch.onnx.export(model, dummy_input, temp_file, **kwargs)
except torch.onnx.CheckerError:
_logger.warning("ONNX Checker has failed but ONNX graph is still generated.")


def restore_onnx_graph_initializers(original_model_path: str, restored_model_path: str):
    """
    Restore pruned initializers in ONNX graph

    Loads the model at original_model_path and looks up Identity nodes whose input is an
    initializer. For every output of such an Identity node that is consumed as a node input,
    a copy of the initializer is added under that output's name and the Identity node is
    removed. The resulting model is saved to restored_model_path.

    :param original_model_path: Path where the original ONNX artifact was stored
    :param restored_model_path: Path to store restored ONNX artifact
    """
    # pylint: disable=protected-access, no-member
    model = onnx.load(original_model_path)
    onnx_graph = model.graph

    initializers = OnnxSaver._get_all_initializers(onnx_graph)
    initializer_names = [initializer.name for initializer in initializers]
    pruned_initializer_map = _get_pruned_initializer_map(
        onnx_graph, initializers, initializer_names
    )

    # Restoring removes Identity nodes from onnx_graph.node, so the consumed tensor names are
    # snapshotted before any mutation instead of iterating the node list while it shrinks
    # (which can skip nodes). dict.fromkeys de-duplicates while keeping first-seen order, so a
    # tensor consumed by several nodes is restored exactly once.
    consumed_tensors = list(dict.fromkeys(
        input_tensor for node in onnx_graph.node for input_tensor in node.input
    ))
    for input_tensor in consumed_tensors:
        _restore_pruned_initializer(
            onnx_graph, input_tensor, pruned_initializer_map
        )

    onnx.save(model, restored_model_path)


def _get_pruned_initializer_map(onnx_graph: onnx.GraphProto,
                                initializers: List[onnx.TensorProto],
                                initializer_names: List[str]) -> Dict[str, PrunedInitializerInfo]:
    """
    Find pruned ONNX initializers by iterating Identity nodes

    An Identity node whose first input is the name of an initializer is recorded together
    with a deep copy of that initializer.

    :param onnx_graph: ONNX graph
    :param initializers: List of ONNX initializers
    :param initializer_names: List of model initializer names, parallel to initializers
    :return: Dictionary with output of identity node as key and PrunedInitializerInfo as value
    """
    # Build the name -> position lookup once so each Identity node costs a single dict access
    # rather than a linear scan of initializer_names. setdefault keeps the first position of
    # a repeated name, which is what list.index() would return.
    first_index_by_name = {}
    for index, name in enumerate(initializer_names):
        first_index_by_name.setdefault(name, index)

    pruned_initializer_map = {}
    for node in onnx_graph.node:
        if node.op_type != "Identity":
            continue

        index = first_index_by_name.get(node.input[0])
        if index is None:
            # Input of this Identity node is not an initializer
            continue

        # Deep copy so the stored initializer is independent of the one in `initializers`
        pruned_initializer_map[node.output[0]] = PrunedInitializerInfo(
            copy.deepcopy(initializers[index]), node
        )

    return pruned_initializer_map


def _restore_pruned_initializer(onnx_graph: onnx.GraphProto,
                                input_tensor: str,
                                pruned_initializer_map: Dict[str, PrunedInitializerInfo],
                                new_initializer_name: Optional[str] = None):
    """
    Add an initializer to the graph and drop the matching Identity node, if input_tensor
    has an entry in pruned_initializer_map; otherwise do nothing

    :param onnx_graph: ONNX graph
    :param input_tensor: Input tensor name
    :param pruned_initializer_map: Dictionary with output of identity node as key and PrunedInitializerInfo as value
    :param new_initializer_name: Name for new initializer. Falls back to input_tensor when not given
    """
    pruned_info = pruned_initializer_map.get(input_tensor)
    if not pruned_info:
        # No entry for this tensor in the map
        return

    restored = pruned_info.initializer
    identity = pruned_info.identity_node

    # The initializer takes the requested name, or the tensor name the consumer already uses
    restored.name = new_initializer_name or input_tensor
    onnx_graph.initializer.append(restored)
    onnx_graph.node.remove(identity)
    _logger.info(
        "Added new Initializer `%s` and removed existing Identity node `%s`",
        restored.name,
        identity.name,
    )
20 changes: 20 additions & 0 deletions TrainingExtensions/torch/test/python/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,3 +1130,23 @@ def forward(self, input_ids):
x = self.embedding(input_ids)
x = self.linear(x)
return self.softmax(x)


class MultiplePReluModel(nn.Module):
    """ Model with three Conv2d layers, each followed by its own PReLU activation """

    def __init__(self, num_parameters: int = 1):
        """
        :param num_parameters: Value forwarded as num_parameters to every PReLU layer
        """
        super().__init__()
        # Stage 1: 3 -> 8 channels, 1x1 kernel
        self.conv1 = nn.Conv2d(3, 8, kernel_size=1)
        self.act1 = nn.PReLU(num_parameters=num_parameters)
        # Stage 2: 8 -> 8 channels, 3x3 kernel
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3)
        self.act2 = nn.PReLU(num_parameters=num_parameters)
        # Stage 3: 8 -> 8 channels, 3x3 kernel
        self.conv3 = nn.Conv2d(8, 8, kernel_size=3)
        self.act3 = nn.PReLU(num_parameters=num_parameters)

    def forward(self, *inputs):
        """
        Apply the three conv + PReLU stages in order to the first positional input

        :param inputs: Positional inputs; only inputs[0] is used
        :return: Output of the last PReLU
        """
        stage1 = self.act1(self.conv1(inputs[0]))
        stage2 = self.act2(self.conv2(stage1))
        return self.act3(self.conv3(stage2))
65 changes: 57 additions & 8 deletions TrainingExtensions/torch/test/python/test_onnx_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# /usr/bin/env python3.5
#!/usr/bin/env python3
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019-2020, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2019-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -37,21 +37,21 @@
# =============================================================================
import contextlib
import copy
import os
import logging
import os
import tempfile
from collections import defaultdict
import pytest
from packaging.version import Version

import onnx
import pytest
import torch
from torchvision import models

import aimet_torch.elementwise_ops
from aimet_common.utils import AimetLogger
from aimet_torch import onnx_utils
import onnx

from models.test_models import RoiModel, InputOutputDictModel
from aimet_torch.onnx_utils import restore_onnx_graph_initializers
from models.test_models import RoiModel, InputOutputDictModel, MultiplePReluModel


class OutOfOrderModel(torch.nn.Module):
Expand Down Expand Up @@ -887,3 +887,52 @@ def test_naming_for_model_with_deep_graph(self):

if os.path.exists(onnx_path):
os.remove(onnx_path)

@pytest.mark.parametrize("num_parameters", [1, 8])
def test_set_node_name_for_multiple_p_relu_model(self, num_parameters):
    """
    With RESTORE_ONNX_MODEL_INITIALIZERS enabled, every PReLU weight must appear as a named
    initializer in the ONNX model written by OnnxSaver.set_node_names
    """
    model = MultiplePReluModel(num_parameters)
    dummy_input = torch.randn(4, 3, 28, 28)

    onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = True
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            onnx_path = f"{tmp_dir}/multiple_p_relu_model.onnx"

            onnx_utils.OnnxSaver.set_node_names(onnx_path, model, dummy_input)
            onnx_model = onnx.load(onnx_path)

            expected_initializer_names = ["act1.weight", "act2.weight", "act3.weight"]
            actual_initializer_names = {x.name for x in onnx_model.graph.initializer}
            for name in expected_initializer_names:
                assert name in actual_initializer_names

            _, valid_param_set = onnx_utils.OnnxSaver.get_onnx_node_to_io_tensor_names_map(onnx_model)
            for name in expected_initializer_names:
                assert name in valid_param_set

            self.check_onnx_node_name_uniqueness(onnx_model)
    finally:
        # Reset the module-level flag even when an assertion fails, so that a failure here
        # cannot change the behavior of tests that run afterwards
        onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = False

def test_restore_onnx_graph_initializers(self):
    """
    After restore_onnx_graph_initializers, every PRelu slope input must be an initializer
    """
    model = MultiplePReluModel()
    dummy_input = torch.randn(4, 3, 28, 28)

    with tempfile.TemporaryDirectory() as tmp_dir:
        original_model_path = f"{tmp_dir}/multiple_p_relu_model.onnx"
        torch.onnx.export(model, dummy_input, original_model_path)

        original_model = onnx.load(original_model_path)
        original_nodes = original_model.graph.node
        identity_node_outputs = {node.output[0] for node in original_nodes if node.op_type == "Identity"}
        p_relu_slopes = [node.input[1] for node in original_nodes if node.op_type == "PRelu"]

        # At least one slope comes from an Identity node in the original ONNX model
        assert not identity_node_outputs.isdisjoint(p_relu_slopes)

        restored_model_path = f"{tmp_dir}/restored_multiple_p_relu_model.onnx"
        restore_onnx_graph_initializers(original_model_path, restored_model_path)
        restored_model = onnx.load(restored_model_path)
        restored_graph = restored_model.graph

        # All slopes should be initializers in the restored ONNX model
        restored_model_initializers = {tensor.name for tensor in restored_graph.initializer}
        restored_p_relu_slopes = [node.input[1] for node in restored_graph.node if node.op_type == "PRelu"]

        assert set(restored_p_relu_slopes) <= restored_model_initializers
28 changes: 28 additions & 0 deletions TrainingExtensions/torch/test/python/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2753,6 +2753,34 @@ def test_export_to_onnx_direct(self):
onnxsaver_act_names = onnxsaver_encodings['activation_encodings'].keys()
assert direct_onnx_act_names != onnxsaver_act_names

@pytest.mark.parametrize("num_parameters", [1, 8])
@pytest.mark.parametrize("config_file", [None, get_path_for_per_channel_config()])
def test_export_to_onnx_for_multiple_p_relu_model(self, num_parameters, config_file):
    """
    With RESTORE_ONNX_MODEL_INITIALIZERS enabled, the exported encodings file must contain a
    param encoding entry for every PReLU weight, with the expected number of encodings
    """
    model = test_models.MultiplePReluModel(num_parameters)
    dummy_input = torch.randn(4, 3, 28, 28)

    onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = True
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            sim = QuantizationSimModel(model, dummy_input, QuantScheme.post_training_tf,
                                       config_file=config_file)
            sim.compute_encodings(lambda m, _: m(dummy_input), None)
            filename_prefix = "multiple_p_relu_model"

            sim.export(tmp_dir, filename_prefix, dummy_input)
            with open(f"{tmp_dir}/{filename_prefix}.encodings") as encodings_file:
                encodings = json.load(encodings_file)

            param_encodings = encodings["param_encodings"]
            expected_param_names = ["act1.weight", "act2.weight", "act3.weight"]
            for param_name in expected_param_names:
                assert param_name in param_encodings

                if config_file:  # Per-channel
                    assert len(param_encodings[param_name]) == num_parameters
                else:  # Per-tensor
                    assert len(param_encodings[param_name]) == 1
    finally:
        # Reset the module-level flag even when an assertion fails, so that a failure here
        # cannot change the behavior of tests that run afterwards
        onnx_utils.RESTORE_ONNX_MODEL_INITIALIZERS = False

def test_save_encodings_to_json(self):
model = ModelWithTwoInputsOneToAdd()
dummy_input = (torch.rand(32, 1, 100, 100), torch.rand(32, 10, 22, 22))
Expand Down

0 comments on commit 0d29fef

Please sign in to comment.