Skip to content

Commit

Permalink
fix bug in substitute reshape of weights node
Browse files Browse the repository at this point in the history
  • Loading branch information
eladc-git committed Nov 22, 2023
1 parent e921060 commit 0e92648
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def substitute(self,
Returns:
Graph after applying the substitution.
"""
# skip substitution if the reshape is applied to weights (not activations). In this case, the node is the first source and doesn't have an input shape.
if len(node.input_shape) == 0:
return graph

# we want the batch size value to infer from the length of the array and remaining dimensions
if len(node.output_shape) == 1:
node.output_shape[0][0] = BATCH_DIM_VALUE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ReshapeNet(torch.nn.Module):
def __init__(self):
super(ReshapeNet, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 4, kernel_size=1, stride=1)
self.scale_weight = torch.nn.Parameter(torch.ones(32))

def forward(self, x):
x = self.conv1(x)
Expand All @@ -47,6 +48,7 @@ def forward(self, x):
height = height + batch
height = height - batch
x = torch.transpose(x, 1, 2)
x = torch.reshape(self.scale_weight, (1, -1, 1, 1)) * x
return x.reshape(-1, channels, height, width)


Expand Down Expand Up @@ -78,7 +80,8 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
reshape_nodes = [n for n in v.graph.nodes if (n.type == torch.reshape or n.type == torch.Tensor.view)]
for r in reshape_nodes:
for o in r.op_call_args:
assert isinstance(o, list)
if len(r.input_shape) > 0:
assert isinstance(o, list)

######################################################
# check the all other comparisons:
Expand Down

0 comments on commit 0e92648

Please sign in to comment.