diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index a9dde58a972..d5dc3f55b7b 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -370,7 +370,7 @@ def forward(self, *args) -> torch.Tensor: res = torch.Tensor(res).type(torch.int64) out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64) - indices = torch.arange(0, len(res) * len(res[0]), dtype=torch.int64) + indices = torch.arange(0, len(res) * 3, dtype=torch.int64) out.put_(indices, res) return out