Fix a bug when restoring pruned initializer (#2771)
Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>
quic-hitameht authored Feb 22, 2024
1 parent 0c9d93d commit 9894507
Showing 3 changed files with 65 additions and 10 deletions.
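
Context for the fix: when a PyTorch module is reused (as in the ModelWithReusedInitializers test model added below), the exported ONNX graph contains Identity nodes whose outputs alias a single initializer and feed more than one consumer, and restoring the pruned initializers then has to detach each such Identity node exactly once. The hand-built graph below only illustrates that shape; all names are made up, and it is not claimed to match what torch.onnx.export or AIMET emits verbatim.

import numpy as np
import onnx
from onnx import helper, numpy_helper

# One initializer "w", aliased by a single Identity node whose output "w_id"
# is consumed twice: the reused-initializer pattern exercised by this commit.
w = numpy_helper.from_array(np.ones((4, 4), dtype=np.float32), name="w")
nodes = [
    helper.make_node("Identity", ["w"], ["w_id"], name="identity_0"),
    helper.make_node("MatMul", ["x", "w_id"], ["h"], name="matmul_0"),
    helper.make_node("MatMul", ["h", "w_id"], ["y"], name="matmul_1"),
]
graph = helper.make_graph(
    nodes, "reused_initializer_example",
    inputs=[helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 4])],
    initializer=[w],
)
onnx.checker.check_model(helper.make_model(graph))
# Restoring this graph should hand each MatMul its own copy of "w" and remove
# "identity_0" exactly once; per-consumer removal is what the pre-fix code tripped over.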
18 changes: 9 additions & 9 deletions TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
@@ -1549,6 +1549,14 @@ def restore_onnx_graph_initializers(model: onnx.ModelProto,
onnx_graph, input_tensor, pruned_initializer_map
)

# Remove all the detached "Identity" type nodes
for pruned_initializer_info in pruned_initializer_map.values():
onnx_graph.node.remove(pruned_initializer_info.identity_node)
_logger.debug(
"Added new Initializer `%s` and removing existing Identity node `%s`",
pruned_initializer_info.initializer.name,
pruned_initializer_info.identity_node.name,
)
return model


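The hunk above is the core of the fix: detaching the "Identity" nodes is now deferred to a single pass that runs after every pruned initializer has been restored, instead of happening inside the per-tensor helper, so a node whose output is shared by several consumer inputs is removed exactly once. A minimal sketch of that two-pass shape, assuming a pruned_initializer_map of {tensor_name: PrunedInitializerInfo} as in this file (the function name is illustrative, and the per-consumer renaming done by the real helper is glossed over):

def restore_then_detach(onnx_graph, pruned_initializer_map):
    # Pass 1: put every restored initializer back into the graph first.
    for tensor_name, info in pruned_initializer_map.items():
        info.initializer.name = tensor_name
        onnx_graph.initializer.append(info.initializer)

    # Pass 2: only now detach the Identity nodes, one removal per node,
    # even when a node's output fed several consumers (the reused case).
    for info in pruned_initializer_map.values():
        onnx_graph.node.remove(info.identity_node)
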
@@ -1580,7 +1588,7 @@ def _restore_pruned_initializer(onnx_graph: onnx.GraphProto,
pruned_initializer_map: Dict[str, PrunedInitializerInfo],
new_initializer_name: Optional[str] = None):
"""
Create new Initializer and remove Identity node to restore pruned Initializer
Create new Initializer to restore pruned Initializer
:param onnx_graph: ONNX graph
:param input_tensor: Input tensor name
@@ -1589,13 +1597,5 @@ def _restore_pruned_initializer(onnx_graph: onnx.GraphProto,
"""
if result := pruned_initializer_map.get(input_tensor):
new_initializer = result.initializer
existing_identity_node = result.identity_node

new_initializer.name = new_initializer_name or input_tensor
onnx_graph.initializer.append(new_initializer)
onnx_graph.node.remove(existing_identity_node)
_logger.info(
"Added new Initializer `%s` and removed existing Identity node `%s`",
new_initializer.name,
existing_identity_node.name,
)
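
After this change the helper only creates and registers the restored initializer; node removal and logging now live in the caller shown in the first hunk. Pieced together from the unchanged lines above, its body reads roughly:

    if result := pruned_initializer_map.get(input_tensor):
        new_initializer = result.initializer
        new_initializer.name = new_initializer_name or input_tensor
        onnx_graph.initializer.append(new_initializer)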
23 changes: 23 additions & 0 deletions TrainingExtensions/torch/test/python/models/test_models.py
@@ -1228,3 +1228,26 @@ def __init__(self):
def forward(self, *inputs):
chunks = self.split_module(inputs[0])
return self.relu1(chunks[0]), self.relu2(chunks[1]), self.relu3(chunks[2])


class ModuleList(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(3))

def forward(self, x):
for i, layer in enumerate(self.layers):
x = torch.nn.functional.relu(layer(x)) if i < 3 - 1 else layer(x)
return x


class ModelWithReusedInitializers(torch.nn.Module):
def __init__(self, repetition):
super(ModelWithReusedInitializers, self).__init__()
self.modulelist = ModuleList()
self.repetition = repetition

def forward(self, x):
for i in range(self.repetition):
x = self.modulelist(x)
return x
34 changes: 33 additions & 1 deletion TrainingExtensions/torch/test/python/test_onnx_utils.py
@@ -49,12 +49,13 @@
import aimet_torch.elementwise_ops
from aimet_common.utils import AimetLogger
from aimet_torch import onnx_utils
from aimet_torch.model_preparer import prepare_model
from aimet_torch.onnx_utils import (
save_initializer_restored_onnx_graph,
restore_onnx_graph_initializers, get_pytorch_name_from_onnx_name
)
from models.test_models import RoiModel, InputOutputDictModel, MultiplePReluModel, NestedSeqModel,\
NestedModelWithOverlappingNames, ModelWithModuleList
NestedModelWithOverlappingNames, ModelWithModuleList, ModelWithReusedInitializers


class OutOfOrderModel(torch.nn.Module):
@@ -1005,3 +1006,34 @@ def test_get_pytorch_name_from_onnx_name(self):
pytorch_name = get_pytorch_name_from_onnx_name(node.name)
print(pytorch_name, "-->", node.name)
assert isinstance(model.get_submodule(pytorch_name), torch.nn.Module)

@pytest.mark.parametrize("inplace", [True])
def test_restore_onnx_graph_reused_initializers(self, inplace):
""" test to verify that Initializers are added and correponding Identity nodes are removed correctly """
repetition = 2
model = ModelWithReusedInitializers(repetition).eval()
dummy_input = torch.randn(1, 256)
prepared_model = prepare_model(model)
assert torch.equal(model(dummy_input), prepared_model(dummy_input))

with tempfile.TemporaryDirectory() as tmp_dir:
original_model_path = f"{tmp_dir}/reused_initializers.onnx"
torch.onnx.export(prepared_model, dummy_input, original_model_path)

original_model = onnx.load(original_model_path)
identity_nodes = [node for node in original_model.graph.node if node.op_type == "Identity"]
assert len(identity_nodes) == 6
initializers = [ini for ini in original_model.graph.initializer]
assert len(initializers) == 6

restored_model = restore_onnx_graph_initializers(original_model, inplace=inplace)
identity_nodes = [node for node in restored_model.graph.node if node.op_type == "Identity"]
# There shouldn't be any "Identity" type nodes in the restored model.
assert len(identity_nodes) == 0
initializers = [ini for ini in restored_model.graph.initializer]
# There will be 6 more initializers added for newly added modules.
assert len(initializers) == 6 * repetition

# Ensure that the graph is correct
self.check_onnx_node_name_uniqueness(restored_model)
onnx.checker.check_model(restored_model)
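
For completeness, a hedged usage sketch of the same restoration outside pytest, mirroring the assertions above (file paths are illustrative):

import onnx
from aimet_torch.onnx_utils import restore_onnx_graph_initializers

model = onnx.load("reused_initializers.onnx")
restored = restore_onnx_graph_initializers(model, inplace=True)
# After restoration no Identity aliases should remain in the graph.
assert not any(node.op_type == "Identity" for node in restored.graph.node)
onnx.save(restored, "reused_initializers_restored.onnx")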
