From 0e926480ec58f11e8e50695ea15169ceee2cf0ac Mon Sep 17 00:00:00 2001 From: eladc-git Date: Wed, 22 Nov 2023 13:55:03 +0200 Subject: [PATCH] fix bug in substitute reshape of weights node --- .../substitutions/reshape_with_static_shapes.py | 4 ++++ .../model_tests/feature_models/dynamic_size_inputs_test.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py index 8901dec4b..93ea4b9d7 100644 --- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py @@ -50,6 +50,10 @@ def substitute(self, Returns: Graph after applying the substitution. """ + # skip substitution if the reshape is applied to weights (not activations). In this case, the node is the first source and doesn't have an input shape. + if len(node.input_shape) == 0: + return graph + # we want the batch size value to infer from the length of the array and remaining dimensions if len(node.output_shape) == 1: node.output_shape[0][0] = BATCH_DIM_VALUE diff --git a/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py b/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py index c2f28e3fd..56acc07a2 100644 --- a/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/dynamic_size_inputs_test.py @@ -34,6 +34,7 @@ class ReshapeNet(torch.nn.Module): def __init__(self): super(ReshapeNet, self).__init__() self.conv1 = torch.nn.Conv2d(3, 4, kernel_size=1, stride=1) + self.scale_weight = torch.nn.Parameter(torch.ones(32)) def forward(self, x): x = self.conv1(x) @@ -47,6 +48,7 @@ def forward(self, x): height = height + batch height = height - batch x = torch.transpose(x, 1, 2) + x = torch.reshape(self.scale_weight, (1, -1, 1, 1)) * x return x.reshape(-1, channels, height, width) @@ -78,7 +80,8 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info reshape_nodes = [n for n in v.graph.nodes if (n.type == torch.reshape or n.type == torch.Tensor.view)] for r in reshape_nodes: for o in r.op_call_args: - assert isinstance(o, list) + if len(r.input_shape) > 0: + assert isinstance(o, list) ###################################################### # check the all other comparisons: