Skip to content

Commit

Permalink
fix output tensors size calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Jan 6, 2025
1 parent 9817e53 commit e009eac
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,11 @@ def get_total_output_params(self) -> float:
Returns: Output size.
"""
# multiple output shapes are not necessarily lists, e.g. tf nms uses custom named tuple.
# shape can be tuple or list, and multiple shapes can be packed in list or tuple
if self.output_shape and isinstance(self.output_shape[0], (tuple, list)):
output_shapes = list(self.output_shape)
output_shapes = self.output_shape
else:
output_shapes = self.output_shape if isinstance(self.output_shape, list) else [self.output_shape]
output_shapes = [self.output_shape]

# remove batch size (first element) from output shape
output_shapes = [s[1:] for s in output_shapes]
Expand Down

0 comments on commit e009eac

Please sign in to comment.