Fix a bug in GatherNd causing device mismatch when tracing #2765

Merged
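The root cause: inside `GatherNd.forward`, the `torch.jit.is_tracing()` shortcut returned a placeholder built with `torch.zeros(*output_shape)`, which allocates on the default (CPU) device even when `data` lives on a GPU, so tracing a CUDA model produced a device mismatch downstream. A minimal sketch of the failure mode (illustrative shapes; assumes a CUDA device is available):

```python
import torch

# Inputs on the GPU, as they would be during a CUDA trace.
data = torch.tensor([[0, 1], [2, 3]], device="cuda")
output_shape = [2]

# Pre-fix behaviour: torch.zeros defaults to the CPU, so the tracing
# placeholder lands on a different device than the input tensors.
placeholder_cpu = torch.zeros(*output_shape)
# Post-fix behaviour: pin the placeholder to the input's device.
placeholder_gpu = torch.zeros(*output_shape, device=data.device)

print(placeholder_cpu.device)  # cpu   -> mismatches data.device
print(placeholder_gpu.device)  # cuda:0
```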
Changes to the aimet_torch elementwise ops module (the file defining `GatherNd`):

```diff
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2021-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2021-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -37,12 +37,12 @@

 """ Custom modules for functional operations defined under torch and torch.nn.functional packages """

-from typing import Callable, Any, Tuple, Union, List
 import itertools
-import torchvision
+from typing import Callable, Any, Tuple, Union, List

 import torch
 import torch.nn

+import torchvision
+

 def forward_function_wrapper(functional: Callable) -> Any:
@@ -436,7 +436,7 @@ def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
                         else batch_dims_shape + list(indices.shape)[self.batch_dims:-1] + list(data.shape)[self.batch_dims + indices.shape[-1]:])

         if torch.jit.is_tracing():
-            return torch.zeros(*output_shape)
+            return torch.zeros(*output_shape, device=data.device)

         output_data_buffer = []

@@ -450,7 +450,7 @@ def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
                 output_data_buffer.append(reshaped_data[(batch_dim, *gather_index)])

         if output_data_buffer[0].dim() == 0:
-            return torch.tensor(output_data_buffer).reshape(output_shape)
+            return torch.tensor(output_data_buffer, device=data.device).reshape(output_shape)
         return torch.cat(output_data_buffer).reshape(output_shape)
```
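The fix can also be exercised end to end with a real trace rather than a stubbed `torch.jit.is_tracing` (a sketch, assuming `aimet_torch` is installed; on a CPU-only host the devices trivially agree):

```python
import torch
from aimet_torch import elementwise_ops

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gather_nd = elementwise_ops.GatherNd(batch_dim=0)
data = torch.tensor([[0, 1], [2, 3]], device=device)
indices = torch.tensor([[0, 0], [1, 1]], device=device)

# torch.jit.trace invokes forward() with torch.jit.is_tracing() == True,
# so this goes through the zeros-placeholder branch patched above.
traced = torch.jit.trace(gather_nd, (data, indices))
assert traced(data, indices).device == data.device
```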
TrainingExtensions/torch/test/python/test_elementwise_ops.py (22 changes: 19 additions & 3 deletions):

```diff
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2021-2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2021-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -37,10 +37,13 @@

 import json
 import unittest.mock
+from unittest import mock
+
+import aimet_common.libpymo as libpymo
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
-import aimet_common.libpymo as libpymo

 from aimet_common.defs import QuantScheme
 from aimet_torch import elementwise_ops
 from aimet_torch.quantsim import QuantizationSimModel
@@ -583,3 +586,16 @@ def test_custom_scatternd_op(self):

         custom_module_out = model(inputs, indices, updates)
         self.assertTrue(np.allclose(custom_module_out, original_module_out))
+
+    def test_gather_nd_jit_trace(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        gather_nd = elementwise_ops.GatherNd(batch_dim=0)
+        data = torch.tensor([[0, 1], [2, 3]], device=device)
+        indices = torch.tensor([[0, 0], [1, 1]], device=device)
+
+        # Patch torch.jit.is_tracing() as True
+        with mock.patch("torch.jit.is_tracing", lambda: True), torch.inference_mode():
+            outputs = gather_nd(data, indices)
+
+        assert data.device == outputs.device
```
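Note that the test stubs `torch.jit.is_tracing` with `mock.patch` instead of building a real traced graph: that pins down exactly the branch changed in this PR, runs quickly, and still passes on CPU-only machines (both tensors then simply live on the CPU), while `torch.inference_mode()` keeps autograd bookkeeping out of the forward pass.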