diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py
index a66e60efa77..4937db60e30 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/nn/fake_quant.py
@@ -45,7 +45,7 @@ from torch.utils._pytree import tree_map
 
 from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
-from aimet_torch.experimental.v2.quantization.modules.quantize import _QuantizerBase
+from aimet_torch.experimental.v2.quantization.quantizers import QuantizerBase
 import aimet_torch.elementwise_ops as aimet_ops
 
 
@@ -77,7 +77,7 @@ def export_input_encodings(self) -> List[List[Dict]]:
         Returns a list of input encodings, each represented as a List of Dicts
         """
         return [
-            quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
+            quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
             for quantizer in _flatten_nn_module_list(self.input_quantizers)
         ]
 
@@ -86,7 +86,7 @@ def export_output_encodings(self) -> List[List[Dict]]:
         Returns a list of output encodings, each represented as a List of Dicts
         """
         return [
-            quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
+            quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
             for quantizer in _flatten_nn_module_list(self.output_quantizers)
         ]
 
@@ -95,7 +95,7 @@ def export_param_encodings(self) -> Dict[str, List[Dict]]:
         Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts
         """
         return {
-            param_name: quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
+            param_name: quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
             for param_name, quantizer in self.param_quantizers.items()
         }
 
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/__init__.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/__init__.py
index 2d6a46a2230..745788400d7 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/__init__.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/__init__.py
@@ -37,4 +37,4 @@
 
 # pylint: disable=all
 
-from .modules import *
+from .quantizers import *
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/__init__.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/__init__.py
similarity index 97%
rename from TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/__init__.py
rename to TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/__init__.py
index 91697e18f69..1cb1f305685 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/__init__.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/__init__.py
@@ -36,4 +36,5 @@
 # =============================================================================
 # pylint: disable=all
 
-from .quantize import *
+from .base import *
+from .affine import *
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/affine.py
similarity index 70%
rename from TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py
rename to TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/affine.py
index 02484f9424c..b23485cb85b 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/affine.py
@@ -2,7 +2,7 @@
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -35,30 +35,29 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 # pylint: disable=redefined-builtin
-""" nn.Modules for quantization operators """
+""" Affine quantizers """
 
-import copy
+import abc
 from typing import Optional, Tuple, List, Dict
 import contextlib
-from collections import OrderedDict
 import functools
-import weakref
 
 import torch
 from torch import nn
 
 from aimet_torch.experimental.v2.utils import patch_attr, patch_param, _is_expandable, StatisticsNotFoundError
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer, MinMaxEncodingAnalyzer
+from aimet_torch.experimental.v2.quantization.quantizers.base import QuantizerBase
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 from aimet_torch.experimental.v2.utils import ste_round
 
 
-__all__ = ['Quantize', 'QuantizeDequantize', 'Dequantize']
+__all__ = ['AffineQuantizerBase', 'MinMaxQuantizer', 'Quantize', 'QuantizeDequantize', 'Dequantize']
 
 
-class _QuantizerBase(torch.nn.Module): # pylint: disable=abstract-method
+class AffineQuantizerBase(QuantizerBase):
     """
-    Base class for quantization modules.
+    Base class for linear quantization modules.
 
     :param shape: Shape of the quantization parameters.
     :param bitwidth: Quantization bitwidth.
@@ -67,10 +66,6 @@ class _QuantizerBase(torch.nn.Module): # pylint: disable=abstract-method
     :param encoding_analyzer: Encoding analyzer for calibrating quantization encodings.
                               (default: absolute min-max encoding analyzer)
     """
-
-    min: torch.nn.Parameter
-    max: torch.nn.Parameter
-
     def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None):
         super().__init__()
         if isinstance(shape, int):
@@ -84,147 +79,47 @@ def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: Enc
             raise RuntimeError(f'Encoding analyzer of shape {self.encoding_analyzer.observer.shape} '
                                f'is incompatible with quantizer of shape {self.shape}.')
 
-        # param_name -> (weakref of initial parameter, version info of the initial parameter)
-        # This info will be used for judging whether the current parameter has ever been
-        # initialized after it was instantiated.
-        self._initial_parameters = OrderedDict()
-
-        # Raw quantization parameters
-        self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
-        self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
-
-    @torch.no_grad()
-    def __deepcopy__(self, memo):
-        self_copy = self.__new__(type(self))
-        self_copy.__dict__ = copy.deepcopy(self.__dict__, memo)
-
-        for name, param in self_copy.named_parameters():
-            # Register parameters to the copied quantizer
-            self_copy.register_quantization_parameter(name, param)
-
-            # If the parameter has been already initialized,
-            # artificially increment the parameter version to mark as initialized
-            if self._is_initialized(name):
-                param.mul_(1.)
-
-        return self_copy
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        state.pop('_initial_parameters')
-        state['initialized_parameters'] = [param_name for param_name, _ in self.named_parameters()
-                                           if self._is_initialized(param_name)]
-        return state
-
-    @torch.no_grad()
-    def __setstate__(self, state):
-        initialized_parameters = state.pop('initialized_parameters')
-        self.__dict__.update(state)
-
-        self._initial_parameters = OrderedDict()
-        for param_name, param in self.named_parameters():
-            # Register parameters to the loaded quantizer
-            self.register_quantization_parameter(param_name, param)
-
-            # If the parameter has been already initialized,
-            # artificially increment the parameter version to mark as initialized
-            if param_name in initialized_parameters:
-                param.mul_(1.)
-
-    def register_quantization_parameter(self, name: str, param: nn.Parameter):
+    @abc.abstractmethod
+    def get_min(self) -> torch.Tensor:
         """
-        Register quantization parameter.
-        """
-        # pylint: disable=protected-access
-
-        self.register_parameter(name, param)
-        param = getattr(self, name)
-        self._initial_parameters[name] = (weakref.ref(param), param._version)
-
-    def _is_initialized(self, param_name) -> bool:
-        # pylint: disable=protected-access
-
-        initial_param_weakref, initial_param_version = self._initial_parameters[param_name]
-        initial_param = initial_param_weakref()
-
-        if initial_param is None:
-            # The initial parameter object doesn't exist in memory space anymore.
-            return True
-
-        current_param = getattr(self, param_name)
-
-        if current_param is initial_param and current_param._version == initial_param_version:
-            # 1. Current parameter is the identical object as the initial parameter
-            # 2. The version nubmer of the current parameter never changed
-            return False
-
-        return True
-
-    def is_initialized(self) -> bool:
-        """
-        Returns true if the quantization parameters are initialized.
-        """
-        for param_name, _ in self.named_parameters():
-            if not self._is_initialized(param_name):
-                return False
-        return True
-
-    def get_min(self) -> Optional[torch.Tensor]:
-        """
-        Compute quantization min to be used for forward pass based on raw parameters.
+        Compute quantization min to be used for forward pass.
+        Return None if the quantizer is not initialized yet.
 
         :return: Quantization min
         """
-        if not self.is_initialized():
-            return None
-        return self.get_scale() * self.get_offset()
 
-    def get_max(self) -> Optional[torch.Tensor]:
+    @abc.abstractmethod
+    def get_max(self) -> torch.Tensor:
         """
-        Compute quantization max to be used for forward pass based on raw parameters.
+        Compute quantization max to be used for forward pass.
+        Return None if the quantizer is not initialized yet.
 
         :return: Quantization max
         """
-        if not self.is_initialized():
-            return None
-        return self.get_scale() * (self.get_offset() + 2 ** self.bitwidth - 1)
 
-    def get_scale(self) -> Optional[torch.Tensor]:
+    @abc.abstractmethod
+    def get_scale(self) -> torch.Tensor:
         """
-        Compute quantization scale to be used for forward pass based on raw parameters.
+        Compute quantization scale to be used for forward pass.
+        Return None if the quantizer is not initialized yet.
 
         :return: Quantization scale
         """
-        if not self.is_initialized():
-            return None
-
-        num_bins = 2 ** self.bitwidth - 1
-
-        if self.symmetric:
-            positive_bins = num_bins // 2
-            negative_bins = positive_bins + 1
-            scale = torch.maximum(-self.min / negative_bins, self.max / positive_bins)
-        else:
-            scale = (self.max - self.min) / num_bins
-
-        return scale
 
-    def get_offset(self) -> Optional[torch.Tensor]:
+    @abc.abstractmethod
+    def get_offset(self) -> torch.Tensor:
         """
-        Compute quantization offset to be used for forward pass based on raw parameters.
+        Compute quantization offset to be used for forward pass.
+        Return None if the quantizer is not initialized yet.
 
         :return: Quantization offset
         """
-        if not self.is_initialized():
-            return None
-
-        if self.symmetric:
-            with torch.no_grad():
-                offset = -torch.ones_like(self.min) * 2 ** (self.bitwidth - 1)
-        else:
-            offset = ste_round(self.min / self.get_scale())
-
-        return offset
+    @abc.abstractmethod
+    def set_range(self, min: torch.Tensor, max: torch.Tensor):
+        """
+        Set quantization parameters to the given min-max range
+        """
 
     @torch.no_grad()
     def get_encodings(self) -> Optional[List[Dict]]:
@@ -251,6 +146,24 @@ def get_encodings(self) -> Optional[List[Dict]]:
             for min_, max_, scale_, offset_ in zip(min, max, scale, offset)
         ]
 
+    def extra_repr(self) -> str:
+        return f'shape={self.shape}, bitwidth={self.bitwidth}, symmetric={self.symmetric}'
+
+
+class MinMaxQuantizer(AffineQuantizerBase): # pylint: disable=abstract-method
+    """
+    Affine quantizer with min-max as trainable parameters
+    """
+
+    min: torch.nn.Parameter
+    max: torch.nn.Parameter
+
+    def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None):
+        super().__init__(shape, bitwidth, symmetric, encoding_analyzer)
+
+        self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
+        self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
+
     @contextlib.contextmanager
     def compute_encodings(self):
         """
@@ -285,17 +198,84 @@ def forward_wrapper(input):
 
             if min is None or max is None:
                 return
 
-            with torch.no_grad():
-                self.min.copy_(min)
-                self.max.copy_(max)
+            self.set_range(min, max)
+
         finally:
             self.encoding_analyzer.reset_stats()
 
-    def extra_repr(self) -> str:
-        return f'shape={self.shape}, bitwidth={self.bitwidth}, symmetric={self.symmetric}'
+    def get_min(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization min to be used for forward pass.
+
+        NOTE: self.min may not be equal to self.get_min().
+              self.get_min() returns slightly recalibrated version of self.min.
+
+        :return: Quantization min
+        """
+        if not self.is_initialized():
+            return None
+        return self.get_scale() * self.get_offset()
+
+    def get_max(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization max to be used for forward pass.
+
+        NOTE: self.max may not be equal to self.get_max()
+              self.get_max() returns slightly recalibrated version of self.max.
+
+        :return: Quantization max
+        """
+        if not self.is_initialized():
+            return None
+        return self.get_scale() * (self.get_offset() + 2 ** self.bitwidth - 1)
+
+    def get_scale(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization scale to be used for forward pass.
+
+        :return: Quantization scale
+        """
+        if not self.is_initialized():
+            return None
+
+        num_bins = 2 ** self.bitwidth - 1
+
+        if self.symmetric:
+            positive_bins = num_bins // 2
+            negative_bins = positive_bins + 1
+            scale = torch.maximum(-self.min / negative_bins, self.max / positive_bins)
+        else:
+            scale = (self.max - self.min) / num_bins
+
+        return scale
+
+    def get_offset(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization offset to be used for forward pass.
+
+        :return: Quantization offset
+        """
+        if not self.is_initialized():
+            return None
+
+        if self.symmetric:
+            with torch.no_grad():
+                offset = -torch.ones_like(self.min) * 2 ** (self.bitwidth - 1)
+        else:
+            offset = ste_round(self.min / self.get_scale())
+
+        return offset
+
+    def set_range(self, min: torch.Tensor, max: torch.Tensor):
+        """
+        Set quantization parameters to the given min-max range
+        """
+        with torch.no_grad():
+            self.min.copy_(min)
+            self.max.copy_(max)
 
 
-class Quantize(_QuantizerBase):
+class Quantize(MinMaxQuantizer):
     """
     Applies quantization to the input
     """
@@ -316,7 +296,7 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc
         return input_q, scale, offset
 
 
-class QuantizeDequantize(_QuantizerBase):
+class QuantizeDequantize(MinMaxQuantizer):
     """
     Applies quantization followed by dequantization to the input
     """
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/base.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/base.py
new file mode 100644
index 00000000000..69c6e1b8b6d
--- /dev/null
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/base.py
@@ -0,0 +1,152 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+""" Quantizer base class """
+
+import abc
+import copy
+from collections import OrderedDict
+import contextlib
+import weakref
+from typing import Optional, List, Dict
+
+import torch
+from torch import nn
+
+
+__all__ = ['QuantizerBase']
+
+
+class QuantizerBase(abc.ABC, torch.nn.Module):
+    """
+    Quantizer base class
+    """
+    def __init__(self):
+        super().__init__()
+
+        # param_name -> (weakref of initial parameter, version info of the initial parameter)
+        # This info will be used for judging whether the current parameter has ever been
+        # initialized after it was instantiated.
+        self._initial_parameters = OrderedDict()
+
+    @abc.abstractmethod
+    @contextlib.contextmanager
+    def compute_encodings(self):
+        """
+        Observe inputs and update quantization parameters based on the input statistics.
+        """
+
+    @abc.abstractmethod
+    def get_encodings(self) -> Optional[List[Dict]]:
+        """
+        Returns a list of encodings, each represented as a List of Dicts
+        """
+
+    def register_quantization_parameter(self, name: str, param: nn.Parameter):
+        """
+        Register quantization parameter.
+        """
+        # pylint: disable=protected-access
+
+        self.register_parameter(name, param)
+        param = getattr(self, name)
+        self._initial_parameters[name] = (weakref.ref(param), param._version)
+
+    def is_initialized(self) -> bool:
+        """
+        Returns true if the quantization parameters are initialized.
+        """
+        for param_name, _ in self.named_parameters():
+            if not self._is_initialized(param_name):
+                return False
+        return True
+
+    def _is_initialized(self, param_name) -> bool:
+        # pylint: disable=protected-access
+
+        initial_param_weakref, initial_param_version = self._initial_parameters[param_name]
+        initial_param = initial_param_weakref()
+
+        if initial_param is None:
+            # The initial parameter object doesn't exist in memory space anymore.
+            return True
+
+        current_param = getattr(self, param_name)
+
+        if current_param is initial_param and current_param._version == initial_param_version:
+            # 1. Current parameter is the identical object as the initial parameter
+            # 2. The version number of the current parameter never changed
+            return False
+
+        return True
+
+    @torch.no_grad()
+    def __deepcopy__(self, memo):
+        self_copy = self.__new__(type(self))
+        self_copy.__dict__ = copy.deepcopy(self.__dict__, memo)
+
+        for name, param in self_copy.named_parameters():
+            # Register parameters to the copied quantizer
+            self_copy.register_quantization_parameter(name, param)
+
+            # If the parameter has been already initialized,
+            # artificially increment the parameter version to mark as initialized
+            if self._is_initialized(name):
+                param.mul_(1.)
+
+        return self_copy
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop('_initial_parameters')
+        state['initialized_parameters'] = [param_name for param_name, _ in self.named_parameters()
+                                           if self._is_initialized(param_name)]
+        return state
+
+    @torch.no_grad()
+    def __setstate__(self, state):
+        initialized_parameters = state.pop('initialized_parameters')
+        self.__dict__.update(state)
+
+        self._initial_parameters = OrderedDict()
+        for param_name, param in self.named_parameters():
+            # Register parameters to the loaded quantizer
+            self.register_quantization_parameter(param_name, param)
+
+            # If the parameter has been already initialized,
+            # artificially increment the parameter version to mark as initialized
+            if param_name in initialized_parameters:
+                param.mul_(1.)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/float.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/float.py
new file mode 100644
index 00000000000..031599c81d9
--- /dev/null
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/quantizers/float.py
@@ -0,0 +1,48 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+""" Float quantizers """
+
+from aimet_torch.experimental.v2.quantization.quantizers.base import QuantizerBase
+
+
+class FloatQuantizerBase(QuantizerBase):
+    """
+    Base class for float quantization modules.
+ """ + def __init__(self, *args, **kwargs): + # pylint: disable=super-init-not-called + raise NotImplementedError diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/wrappers/builder.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/wrappers/builder.py index 83fd8ce73eb..13705bdc5b6 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/wrappers/builder.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/wrappers/builder.py @@ -47,7 +47,7 @@ from aimet_torch.qc_quantize_op import QcQuantizeOpMode, QcQuantizeWrapper, StaticGridQuantWrapper, tensor_quantizer_factory from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerChannelQuantizer from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin -from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize +from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer, PercentileEncodingAnalyzer diff --git a/TrainingExtensions/torch/test/python/experimental/v2/nn/test_activation.py b/TrainingExtensions/torch/test/python/experimental/v2/nn/test_activation.py index 9af73df9120..572c5b0111f 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/nn/test_activation.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/nn/test_activation.py @@ -39,7 +39,7 @@ import torch import torch.nn.functional as F from aimet_torch.experimental.v2.quantization.backends import get_backend -from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize +from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizedSoftmax from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer diff --git a/TrainingExtensions/torch/test/python/experimental/v2/nn/test_linear.py b/TrainingExtensions/torch/test/python/experimental/v2/nn/test_linear.py index 7b1b815e32e..82b97a287f1 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/nn/test_linear.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/nn/test_linear.py @@ -40,7 +40,7 @@ from torch import nn import torch.nn.functional as F from aimet_torch.experimental.v2.quantization.backends import get_backend -from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize +from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizedLinear, FakeQuantizationMixin from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py b/TrainingExtensions/torch/test/python/experimental/v2/quantizers/test_affine_quantizer.py similarity index 97% rename from TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py rename to TrainingExtensions/torch/test/python/experimental/v2/quantizers/test_affine_quantizer.py index 0215e0eb908..158b9e9d7a0 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/quantizers/test_affine_quantizer.py @@ -2,7 +2,7 @@ # 
 # =============================================================================
 # @@-COPYRIGHT-START-@@
 #
-# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
@@ -42,8 +42,7 @@ from torch import nn
 from torch.optim import SGD, RMSprop, Adagrad, Adam, AdamW
 
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
-from aimet_torch.experimental.v2.quantization import Quantize, QuantizeDequantize
-from aimet_torch.experimental.v2.quantization.modules.quantize import _QuantizerBase
+from aimet_torch.experimental.v2.quantization.quantizers.affine import AffineQuantizerBase, Quantize, QuantizeDequantize
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 
 
@@ -208,7 +207,7 @@ def test_qdq_compute_encodings(quantize_dequantize: QuantizeDequantize, x: torch
     quantize_dequantize(symmetric=True, initialized=False),
     quantize_dequantize(symmetric=True, initialized=True),
 ])
-def test_compute_encodings_with_no_input(q: _QuantizerBase):
+def test_compute_encodings_with_no_input(q: AffineQuantizerBase):
     """
     :param q: Quantize or QuantizeDequantize module
 
@@ -247,7 +246,7 @@ def test_compute_encodings_with_no_input(q: _QuantizerBase):
     quantize(symmetric=False, initialized=True),
     quantize_dequantize(symmetric=False, initialized=True),
 ])
-def test_backward_during_compute_encodings(q: _QuantizerBase, x: torch.Tensor):
+def test_backward_during_compute_encodings(q: AffineQuantizerBase, x: torch.Tensor):
     """
     :param q: Quantize or QuantizeDequantize module
     :param x: Input tensor
@@ -276,7 +275,7 @@ def test_backward_during_compute_encodings(q: _QuantizerBase, x: torch.Tensor):
     quantize(symmetric=True, initialized=False),
     quantize_dequantize(symmetric=True, initialized=False),
 ])
-def test_compute_encodings_updates_parameters_upon_exit(q: _QuantizerBase, x: torch.Tensor):
+def test_compute_encodings_updates_parameters_upon_exit(q: AffineQuantizerBase, x: torch.Tensor):
     """
     :param q: Quantize or QuantizeDequantize module
     :param x: Input tensor
@@ -364,7 +363,7 @@ def test_qdq_forward(quantize_dequantize: QuantizeDequantize, x: torch.Tensor):
     quantize_dequantize(symmetric=True, initialized=True),
     quantize_dequantize(symmetric=False, initialized=True),
 ])
-def test_backward(q: _QuantizerBase, x: torch.Tensor):
+def test_backward(q: AffineQuantizerBase, x: torch.Tensor):
     """
     :param q: Quantize or QuantizeDequantize module
     :param x: Input tensor
@@ -421,7 +420,7 @@ def test_backward_with_no_grad(q, x: torch.Tensor):
     quantize(symmetric=True, initialized=False),
     quantize_dequantize(symmetric=True, initialized=False),
 ])
-def test_uninitialized_quantize(q: _QuantizerBase, x: torch.Tensor):
+def test_uninitialized_quantize(q: AffineQuantizerBase, x: torch.Tensor):
     """
     :param q: Quantize or QuantizeDequantize module
     :param x: Input tensor
@@ -575,7 +574,7 @@ def test_symmetric_learning(q, x, optim_cls):
     quantize(symmetric=False, initialized=False),
     quantize_dequantize(symmetric=False, initialized=False),
 ])
-def test_asymmetric_invariants(q: _QuantizerBase, x: torch.Tensor):
+def test_asymmetric_invariants(q: AffineQuantizerBase, x: torch.Tensor):
     """
     Given: Asymmetric quantizer
     When: Quantization parameters initialized with compute_encodings
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/quantizers/test_float_quantizer.py b/TrainingExtensions/torch/test/python/experimental/v2/quantizers/test_float_quantizer.py
new file mode 100644
index 00000000000..e2bd9499cbc
--- /dev/null
+++ b/TrainingExtensions/torch/test/python/experimental/v2/quantizers/test_float_quantizer.py
@@ -0,0 +1,36 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_config_.py b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_config_.py
index cfad081f2a6..e2b8a74fa56 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_config_.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_config_.py
@@ -54,7 +54,7 @@
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
 from aimet_torch.experimental.v2.quantization.quantsim import QuantizationSimModel
-from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
+from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize
 from models_.models_to_test import SingleResidual, QuantSimTinyModel, MultiInput, SingleResidualWithModuleAdd, \
     SingleResidualWithAvgPool, ModelWithBertCustomLayerNormGelu
 
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
index 10149b75c40..c41b1b8666f 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_quantsim_v1_export.py
@@ -44,7 +44,7 @@ # from aimet_torch.experimental.v2.quantization.wrappers.quantization_mixin import _QuantizationMixin
 import aimet_torch.experimental.v2.nn as aimet_nn
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
-from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
+from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
 from aimet_torch.elementwise_ops import Add
 from aimet_torch import onnx_utils
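
Reviewer note: the sketch below is not part of the patch. It is a minimal usage example of the relocated affine quantizer API, assuming the constructor signature and compute_encodings() semantics shown in the diff; the tensor shapes, bitwidth, and variable names are illustrative only.

import torch
from aimet_torch.experimental.v2.quantization.quantizers.affine import QuantizeDequantize

# Per-tensor 8-bit symmetric quantize-dequantize (shape and bitwidth are illustrative).
qdq = QuantizeDequantize(shape=(1,), bitwidth=8, symmetric=True)
assert not qdq.is_initialized()   # min/max have not been calibrated yet

# Calibrate min/max by observing representative data inside compute_encodings().
with qdq.compute_encodings():
    _ = qdq(torch.randn(16, 10))

assert qdq.is_initialized()
scale, offset = qdq.get_scale(), qdq.get_offset()   # derived from the calibrated min/max
encodings = qdq.get_encodings()                     # encodings as a list of dicts
output = qdq(torch.randn(16, 10))                   # quantize-dequantized ("fake quantized") output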