Remove spconv from package dependency #3488

Merged · 1 commit · Nov 15, 2024
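This PR drops spconv as a hard install requirement: both touched files wrap the spconv import in try/except and bind the module's public names to None when spconv is absent, so importing the utilities never fails. A minimal, generic sketch of that optional-import pattern (all names here are illustrative, not from the PR):

    try:
        import some_optional_lib                 # stand-in for spconv.pytorch
    except ImportError:
        process_special = None                   # placeholder keeps this module importable
    else:
        def process_special(x):
            """Defined only when the optional dependency is available."""
            return some_optional_lib.transform(x)

    # Callers probe for availability instead of importing the optional library:
    if process_special is not None:
        print(process_special(42))
    else:
        print("optional dependency not installed; feature disabled")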
First file (utils for handling custom modules):

@@ -243,53 +243,59 @@
# =============================================================================
""" Utils for handling custom modules """

-import spconv.pytorch as spconv
-import aimet_torch
+try:
+    import spconv.pytorch as spconv
+except ImportError as e:
+    is_spconv_module = None
+    QuantizableSparseSequential = None
+    create_quantizable_sparse_sequential = None
+else:
+    import aimet_torch


The remaining definitions move one indentation level into the else: branch, so they are defined only when spconv imports successfully:

    def is_spconv_module(module):
        """
        Modified version of is_spconv_module from spconv.pytorch.modules.is_spconv_module.
        If module is a QcQuantizeWrapper, it checks _module_to_wrap instead.
        :param module: Module to check
        :return: True if module or _module_to_wrap is an spconv module
        """
        # pylint: disable=protected-access
        spconv_modules = (spconv.SparseModule, )
        if isinstance(module, aimet_torch.v1.qc_quantize_op.QcQuantizeWrapper):
            return isinstance(module._module_to_wrap, spconv_modules)
        return isinstance(module, spconv_modules)


    class QuantizableSparseSequential(spconv.SparseSequential):
        """
        Quantizable version of SparseSequential.
        forward is modified to use the custom is_spconv_module above.
        """
        # pylint: disable=arguments-differ
        def forward(self, x):
            for module in self._modules.values():
                if is_spconv_module(module):  # use SparseConvTensor as input
                    if isinstance(x, list):
                        x = module(x)
                    else:
                        # assert isinstance(input, spconv.SparseConvTensor)
                        # self._sparity_dict[k] = input.sparity
                        x = module(x)
                else:
                    if isinstance(x, spconv.SparseConvTensor):
                        if x.indices.shape[0] != 0:
                            x = x.replace_feature(module(x.features))
                    else:
                        x = module(x)
            return x


    def create_quantizable_sparse_sequential(module: spconv.SparseSequential) -> QuantizableSparseSequential:
        """
        Create a QuantizableSparseSequential from an existing SparseSequential module.
        :param module: Existing SparseSequential module
        :return: Newly created QuantizableSparseSequential
        """
        # pylint: disable=protected-access
        return QuantizableSparseSequential(module._modules)
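A hedged usage sketch for the wrapper above, assuming spconv 2.x is installed; the import path of these helpers is not shown in the diff, and the layer configuration is illustrative:

    import torch
    import spconv.pytorch as spconv

    # Build an ordinary SparseSequential, then rebuild it as the quantizable variant
    # (same submodules, but forward() routes dense layers through replace_feature()).
    seq = spconv.SparseSequential(
        spconv.SubMConv3d(16, 32, kernel_size=3, padding=1),
        torch.nn.ReLU(),
    )
    qseq = create_quantizable_sparse_sequential(seq)
    assert isinstance(qseq, QuantizableSparseSequential)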
Second file (utils for handling custom tensor types):

@@ -36,43 +36,48 @@
# =============================================================================
""" Utils for handling custom tensor types """

-from typing import List, Union, Tuple
-import spconv.pytorch as spconv
-import torch
+try:
+    import spconv.pytorch as spconv
+except ImportError as e:
+    to_torch_tensor = None
+    to_custom_tensor = None
+else:
+    from typing import List, Union, Tuple
+    import torch


As in the first file, the converters themselves only move into the else: branch:

    def to_torch_tensor(original: Union[List, Tuple]) -> List[torch.Tensor]:
        """
        Convert custom tensors to torch tensors.
        :param original: List of original tensors
        :return: List of tensors as torch tensors
        """

        outputs = []

        for tensor in original:
            if isinstance(tensor, spconv.SparseConvTensor):
                tensor = tensor.features
            outputs.append(tensor)

        return outputs


    def to_custom_tensor(original: Union[List, Tuple], torch_tensors: List[torch.Tensor]) -> List:
        """
        Convert torch tensors back to the original custom tensor types.
        :param original: List of original tensors
        :param torch_tensors: List of torch tensors
        :return: List of tensors in their original types
        """

        outputs = []

        for orig, torch_tensor in zip(original, torch_tensors):
            tensor = torch_tensor
            if isinstance(orig, spconv.SparseConvTensor):
                tensor = orig.replace_feature(torch_tensor)

            outputs.append(tensor)

        return outputs
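A hedged round-trip sketch for the two converters, assuming spconv 2.x is installed (shapes and the SparseConvTensor construction are illustrative):

    import torch
    import spconv.pytorch as spconv

    features = torch.randn(4, 16)                       # 4 active voxels, 16 channels
    indices = torch.zeros(4, 4, dtype=torch.int32)      # rows are (batch, z, y, x)
    indices[:, 3] = torch.arange(4, dtype=torch.int32)  # give each voxel a distinct x
    sp = spconv.SparseConvTensor(features, indices, spatial_shape=[8, 8, 8], batch_size=1)
    dense = torch.randn(2, 3)

    flat = to_torch_tensor([sp, dense])                 # -> [sp.features, dense]
    restored = to_custom_tensor([sp, dense], flat)      # sparse wrapper re-attached

    assert torch.equal(flat[0], sp.features)
    assert isinstance(restored[0], spconv.SparseConvTensor)
    assert torch.equal(restored[1], dense)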