diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index c70adce82b5..a9dde58a972 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -367,7 +367,12 @@ def forward(self, *args) -> torch.Tensor: res_per_class.append([index, class_index, val.detach()]) res_per_class = res_per_class[:self.max_output_boxes_per_class] res.extend(res_per_class) - return torch.Tensor(res).type(torch.int64) + + res = torch.Tensor(res).type(torch.int64) + out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64) + indices = torch.arange(0, len(res) * len(res[0]), dtype=torch.int64) + out.put_(indices, res) + return out def perform_nms_per_class(self, boxes: torch.Tensor, classes_score: torch.Tensor) -> torch.Tensor: """