diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/default.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/default.py
index cec0a9bc304..76a64be8ba9 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/default.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/default.py
@@ -38,23 +38,15 @@ from typing import Union
 
 import torch
 
-def _is_expandable(a: torch.Tensor, b: torch.Tensor) -> bool:
-    """
-    Returns true if tensor a is expandable to shape of tensor b
-    """
-    if len(a.shape) > len(b.shape):
-        return False
-    for dim_a, dim_b in zip(a.shape[::-1], b.shape[::-1]):
-        if dim_a not in (1, dim_b):
-            return False
-    return True
+from aimet_torch.experimental.v2.utils import _is_expandable
+
 
 def _validate_arguments(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int] = None):
     if not tensor.dtype == scale.dtype == offset.dtype:
         raise RuntimeError("Data type of tensor, scale, and offset should be the same")
     if bitwidth and torch.finfo(tensor.dtype).bits <= bitwidth:
         raise RuntimeError(f"Dtype {tensor.dtype} has insufficient bitwidth to perform {bitwidth} quantization")
-    if not _is_expandable(scale, tensor):
+    if not _is_expandable(scale.shape, tensor.shape):
         raise RuntimeError(f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}")
 
 def quantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int]) -> torch.Tensor:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/utils.py
index fe61ff0ee90..ed62c1ecd02 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/utils.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/backends/utils.py
@@ -35,12 +35,60 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 # pylint: disable=all
+from typing import Protocol
 
-def set_backend():
-    ...
+import torch
 
-def get_backend():
-    ...
+from aimet_torch.experimental.v2.utils import _ContextManager
 
-__all__ = ['set_backend', 'get_backend']
+
+class _QuantizationBackendProtocol(Protocol):
+    def quantize(self, input: torch.Tensor) -> torch.Tensor:
+        ...
+
+    def dequantize(self,
+                   input: torch.Tensor,
+                   scale: torch.Tensor,
+                   offset: torch.Tensor) -> torch.Tensor:
+        ...
+
+    def quantize_dequantize(self, input: torch.Tensor) -> torch.Tensor:
+        ...
+
+
+_CURRENT_BACKEND = 'default'
+
+_SUPPORTED_BACKENDS = {
+    'default': None,
+}
+
+
+def set_global_backend(name: str):
+    global _CURRENT_BACKEND
+    _CURRENT_BACKEND = name
+
+
+def set_backend(name: str) -> _ContextManager:
+    if name not in _SUPPORTED_BACKENDS:
+        supported_backend_names = ", ".join(_SUPPORTED_BACKENDS.keys())
+        raise RuntimeError(f"Backend '{name}' is not supported. "
+                           f"Please choose one of: {supported_backend_names}")
+
+    old_backend = _CURRENT_BACKEND
+    action = lambda: set_global_backend(name)
+    cleanup = lambda: set_global_backend(old_backend)
+    return _ContextManager(action=action, cleanup=cleanup)
+
+
+def get_backend() -> _QuantizationBackendProtocol:
+    if _SUPPORTED_BACKENDS[_CURRENT_BACKEND] is None:
+        # Lazy import
+        import importlib
+        module_name = f'aimet_torch.experimental.v2.quantization.backends.{_CURRENT_BACKEND}'
+        _SUPPORTED_BACKENDS[_CURRENT_BACKEND] = importlib.import_module(module_name)
+
+    return _SUPPORTED_BACKENDS[_CURRENT_BACKEND]
+
+
+__all__ = ['set_global_backend', 'set_backend', 'get_backend']
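
A minimal sketch of how the backend-selection API above might be used. It assumes `set_backend` and `get_backend` are re-exported from the `backends` package (`get_backend` is imported that way in the tests below); until more entries are added to `_SUPPORTED_BACKENDS`, 'default' is the only valid name:

    import torch
    from aimet_torch.experimental.v2.quantization.backends import get_backend, set_backend

    x = torch.randn(10)
    scale = torch.tensor(0.1)
    offset = torch.tensor(-128.)

    # get_backend() lazily imports the module registered for the current backend
    x_q = get_backend().quantize(x, scale, offset, bitwidth=8)

    # set_backend() returns a _ContextManager, so the override is scoped:
    # on exit, the cleanup hook restores the previous backend.
    with set_backend('default'):
        x_dq = get_backend().dequantize(x_q, scale, offset)

Because `_ContextManager` also implements `__call__`, the object returned by `set_backend` can double as a function decorator.
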
" + f"Please choose one of: {supported_backend_names}") + + old_backend = _CURRENT_BACKEND + action = lambda: set_global_backend(name) + cleanup = lambda: set_global_backend(old_backend) + return _ContextManager(action=action, cleanup=cleanup) + + + +def get_backend() -> _QuantizationBackendProtocol: + if _SUPPORTED_BACKENDS[_CURRENT_BACKEND] is None: + # Lazy import + import importlib + module_name = f'aimet_torch.experimental.v2.quantization.backends.{_CURRENT_BACKEND}' + _SUPPORTED_BACKENDS[_CURRENT_BACKEND] = importlib.import_module(module_name) + + return _SUPPORTED_BACKENDS[_CURRENT_BACKEND] + + +__all__ = ['set_global_backend', 'set_backend', 'get_backend'] diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py new file mode 100644 index 00000000000..279f8b012ae --- /dev/null +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py @@ -0,0 +1,224 @@ +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# @@-COPYRIGHT-END-@@ +# ============================================================================= +# pylint: disable=all +from typing import TypeVar, Generic, Tuple, Type, Optional +import abc +from dataclasses import dataclass + +import torch + +from aimet_torch.experimental.v2.utils import reduce + + +@dataclass(frozen=True) +class _MinMaxRange: + min: Optional[torch.Tensor] = None + max: Optional[torch.Tensor] = None + + +class _Histogram: + # TODO + ... 
+
+
+_Statistics = TypeVar('_Statistics', _MinMaxRange, _Histogram)
+
+
+class _Observer(Generic[_Statistics], abc.ABC):
+    def __init__(self, shape):
+        self.shape = shape
+
+    @abc.abstractmethod
+    def collect_stats(self, x: torch.Tensor) -> _Statistics:
+        ...
+
+    @abc.abstractmethod
+    def merge_stats(self, stats: _Statistics):
+        ...
+
+    @abc.abstractmethod
+    def reset_stats(self):
+        ...
+
+    @abc.abstractmethod
+    def get_stats(self) -> _Statistics:
+        ...
+
+
+class _MinMaxObserver(_Observer[_MinMaxRange]):
+    def __init__(self, shape):
+        super().__init__(shape)
+        self.stats = _MinMaxRange()
+
+    @torch.no_grad()
+    def collect_stats(self, x: torch.Tensor) -> _MinMaxRange:
+        min = reduce(x, shape=self.shape, reduce_op=torch.min).values
+        max = reduce(x, shape=self.shape, reduce_op=torch.max).values
+        return _MinMaxRange(min, max)
+
+    @torch.no_grad()
+    def merge_stats(self, new_stats: _MinMaxRange):
+        min = self.stats.min
+        if new_stats.min is not None:
+            if min is None:
+                min = new_stats.min.clone()
+            else:
+                min = torch.minimum(min, new_stats.min)
+
+        max = self.stats.max
+        if new_stats.max is not None:
+            if max is None:
+                max = new_stats.max.clone()
+            else:
+                max = torch.maximum(max, new_stats.max)
+
+        self.stats = _MinMaxRange(min, max)
+
+    def reset_stats(self):
+        self.stats = _MinMaxRange()
+
+    def get_stats(self) -> _MinMaxRange:
+        return self.stats
+
+
+class _HistogramObserver(_Observer[_Histogram]):
+    def __init__(self, shape):
+        # TODO
+        raise NotImplementedError
+
+    @torch.no_grad()
+    def collect_stats(self, x: torch.Tensor) -> _Histogram:
+        # TODO
+        raise NotImplementedError
+
+    @torch.no_grad()
+    def merge_stats(self, new_stats: _Histogram):
+        # TODO
+        raise NotImplementedError
+
+    def reset_stats(self):
+        # TODO
+        raise NotImplementedError
+
+    def get_stats(self) -> _Histogram:
+        # TODO
+        raise NotImplementedError
+
+
+class _EncodingAnalyzer(Generic[_Statistics], abc.ABC):
+    observer_cls: Type[_Observer[_Statistics]]
+
+    def __init__(self, shape):
+        self.observer = self.observer_cls(shape)
+
+    @torch.no_grad()
+    def update_stats(self, x: torch.Tensor) -> _Statistics:
+        new_stats = self.observer.collect_stats(x)
+        self.observer.merge_stats(new_stats)
+        return new_stats
+
+    def reset_stats(self) -> None:
+        self.observer.reset_stats()
+
+    def compute_encodings(self, symmetric: bool, bitwidth: int)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        return self.compute_encodings_from_stats(self.observer.get_stats(), symmetric, bitwidth)
+
+    def compute_dynamic_encodings(self, x: torch.Tensor, symmetric: bool, bitwidth: int)\
+            -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.compute_encodings_from_stats(self.observer.collect_stats(x), symmetric, bitwidth)
+
+    @abc.abstractmethod
+    def compute_encodings_from_stats(self, stats: _Statistics, symmetric: bool, bitwidth: int)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        ...
+
+
+class MinMaxEncodingAnalyzer(_EncodingAnalyzer[_MinMaxRange]):
+    observer_cls = _MinMaxObserver
+
+    @torch.no_grad()
+    def compute_encodings_from_stats(self, stats: _MinMaxRange, symmetric: bool, bitwidth: int)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if stats.min is None or stats.max is None:
+            return None, None
+
+        if symmetric:
+            min = torch.minimum(stats.min, -stats.max)
+            max = torch.maximum(-stats.min, stats.max)
+        else:
+            min = stats.min
+            max = stats.max
+
+        return min, max
+
+
+class PercentileEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
+    observer_cls = _HistogramObserver
+
+    @torch.no_grad()
+    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        # TODO
+        raise NotImplementedError
+
+
+class SqnrEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
+    observer_cls = _HistogramObserver
+
+    @torch.no_grad()
+    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        # TODO
+        raise NotImplementedError
+
+
+class MseEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
+    observer_cls = _HistogramObserver
+
+    @torch.no_grad()
+    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        # TODO
+        raise NotImplementedError
+
+
+def get_encoding_analyzer_cls(qscheme):
+    if qscheme == 'minmax':
+        return MinMaxEncodingAnalyzer
+
+    raise ValueError(f"Unsupported qscheme: {qscheme}")
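
A sketch of the intended calibration flow for the min-max analyzer above (names match the new module; the per-channel shape is illustrative). `update_stats` folds each batch into the observer's running min/max, and `compute_encodings` turns the aggregate into (min, max) tensors, mirrored around zero in the symmetric case:

    import torch
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls

    # Per-channel statistics over (4, 100) activations, reduced to shape (4, 1)
    analyzer = get_encoding_analyzer_cls('minmax')(shape=(4, 1))

    for _ in range(8):
        analyzer.update_stats(torch.randn(4, 100))   # merges running min/max

    min_, max_ = analyzer.compute_encodings(symmetric=True, bitwidth=8)
    assert torch.equal(min_, -max_)   # symmetric encodings are mirrored
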
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py
index c323f2ef8bf..12b8379a804 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/modules/quantize.py
@@ -34,43 +34,208 @@
 #
 # @@-COPYRIGHT-END-@@
 # =============================================================================
-# pylint: disable=all
+# pylint: disable=redefined-builtin
+""" nn.Modules for quantization operators """
+
+from typing import Optional
+import contextlib
+import functools
 
 import torch
 
+from aimet_torch.experimental.v2.utils import patch_attr, patch_param
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls
+from aimet_torch.experimental.v2.quantization.backends import get_backend
+from aimet_torch.experimental.v2.utils import ste_round
+
 __all__ = ['Quantize', 'QuantizeDequantize', 'Dequantize']
 
-class _QuantizerBase(torch.nn.Module):
-    def __init__(self, shape, bitwidth, symmetric, qscheme):
+class _QuantizerBase(torch.nn.Module): # pylint: disable=abstract-method
+    """
+    Base class for quantization modules.
+
+    :param shape: Shape of the quantization parameters.
+    :param bitwidth: Quantization bitwidth.
+    :param symmetric: If True, performs symmetric quantization;
+                      otherwise, performs asymmetric quantization.
+    :param qscheme: Quantization scheme
+    """
+
+    min: torch.nn.Parameter
+    max: torch.nn.Parameter
+
+    def __init__(self, shape, bitwidth: int, symmetric: bool, qscheme):
         super().__init__()
+        self.shape = shape
+        self.bitwidth = bitwidth
+        self.symmetric = symmetric
+        self.qscheme = qscheme
+        self.encoding_analyzer = get_encoding_analyzer_cls(qscheme)(shape)
+
+        # Raw quantization parameters
         self.register_parameter("min", None)
         self.register_parameter("max", None)
-        self.bitwidth = bitwidth
-        ...
 
-    def get_min(self) -> torch.Tensor:
-        ...
+    def is_initialized(self) -> bool:
+        """
+        Returns true if the quantization parameters are initialized.
+        """
+        return self.min is not None or self.max is not None
+
+    def get_min(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization min to be used for forward pass based on raw parameters.
+
+        :return: Quantization min
+        """
+        if not self.is_initialized():
+            return None
+        return self.get_scale() * self.get_offset()
+
+    def get_max(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization max to be used for forward pass based on raw parameters.
+
+        :return: Quantization max
+        """
+        if not self.is_initialized():
+            return None
+        return self.get_scale() * (self.get_offset() + 2 ** self.bitwidth - 1)
+
+    def get_scale(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization scale to be used for forward pass based on raw parameters.
+
+        :return: Quantization scale
+        """
+        if not self.is_initialized():
+            return None
 
-    def get_max(self) -> torch.Tensor:
-        ...
+        num_bins = 2 ** self.bitwidth - 1
 
-    def get_scale(self) -> torch.Tensor:
-        ...
+        if self.symmetric:
+            positive_bins = num_bins // 2
+            negative_bins = positive_bins + 1
+            scale = torch.maximum(-self.min / negative_bins, self.max / positive_bins)
+        else:
+            scale = (self.max - self.min) / num_bins
 
-    def get_offset(self) -> torch.Tensor:
-        ...
+        return scale
 
+    def get_offset(self) -> Optional[torch.Tensor]:
+        """
+        Compute quantization offset to be used for forward pass based on raw parameters.
+
+        :return: Quantization offset
+        """
+        if not self.is_initialized():
+            return None
+
+        if self.symmetric:
+            with torch.no_grad():
+                offset = -torch.ones_like(self.min) * 2 ** (self.bitwidth - 1)
+        else:
+            offset = ste_round(self.min / self.get_scale())
+
+        return offset
+
+    @contextlib.contextmanager
     def compute_encodings(self):
-        ...
+        """
+        Observe inputs and update quantization parameters based on the input statistics.
+        While ``compute_encodings`` is enabled, the quantizer forward pass performs
+        dynamic quantization using the batch statistics.
+ """ + original_forward = self.forward + + @functools.wraps(original_forward) + def forward_wrapper(input): + batch_statistics = self.encoding_analyzer.update_stats(input) + dynamic_min, dynamic_max =\ + self.encoding_analyzer.compute_encodings_from_stats(batch_statistics, + self.symmetric, + self.bitwidth) + with patch_param(self, 'min', dynamic_min),\ + patch_param(self, 'max', dynamic_max): + return original_forward(input) + + try: + with patch_attr(self, 'forward', forward_wrapper): + yield + except: # pylint: disable=try-except-raise + raise + else: + min, max = self.encoding_analyzer.compute_encodings(self.symmetric, self.bitwidth) + + if min is None or max is None: + return + + if not self.is_initialized(): + self.min = torch.nn.Parameter(torch.empty(self.shape)) + self.max = torch.nn.Parameter(torch.empty(self.shape)) + + with torch.no_grad(): + self.min.copy_(min) + self.max.copy_(max) + finally: + self.encoding_analyzer.reset_stats() class Quantize(_QuantizerBase): - ... + """ + Applies quantization to the input + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + :param input: Input to quantize + :return: Quantized output + """ + if not self.is_initialized(): + raise RuntimeError( + 'Failed to run Quantize since quantization parameters are not initialized.' + ' Please initialize the quantization parameters using `compute_encodings()`.' + ) + + scale = self.get_scale() + offset = self.get_offset() + return get_backend().quantize(input, scale, offset, self.bitwidth) + class QuantizeDequantize(_QuantizerBase): - ... + """ + Applies quantization followed by dequantization to the input + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + :param input: Input to quantize and dequantize + :return: Quantize-dequantized output + """ + if not self.is_initialized(): + raise RuntimeError( + 'Failed to run QuantizeDequantize since quantization parameters are not initialized.' + ' Please initialize the quantization parameters using `compute_encodings()`.' + ) + + scale = self.get_scale() + offset = self.get_offset() + return get_backend().quantize_dequantize(input, scale, offset, self.bitwidth) + class Dequantize(torch.nn.Module): - ... + """ + Applies dequantization to the input + """ + def forward(self, + input: torch.Tensor, + scale: torch.Tensor, + offset: torch.Tensor) -> torch.Tensor: + # pylint: disable=no-self-use + """ + :param input: Input to dequantize + :param scale: Quantization scale + :param offset: Quantization offset + :return: Dequantized output + """ + return get_backend().dequantize(input, scale, offset) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py new file mode 100644 index 00000000000..721d2f91ed2 --- /dev/null +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py @@ -0,0 +1,173 @@ +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. 
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py
new file mode 100644
index 00000000000..721d2f91ed2
--- /dev/null
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/utils.py
@@ -0,0 +1,173 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+# pylint: disable=redefined-builtin
+""" Common utility functions """
+from typing import Callable, Tuple
+import functools
+import itertools
+
+import torch
+
+
+def _is_expandable(src_shape: Tuple[int, ...],
+                   target_shape: Tuple[int, ...]) -> bool:
+    """
+    Returns true if source shape can be expanded as target shape
+    """
+    if len(src_shape) > len(target_shape):
+        return False
+
+    for src_dim, dst_dim in zip(src_shape[::-1], target_shape[::-1]):
+        if src_dim not in (1, dst_dim):
+            return False
+
+    return True
+
+
+def _is_reducible(src_shape: Tuple[int, ...],
+                  target_shape: Tuple[int, ...]) -> bool:
+    """
+    Returns true if source shape can be reduced as target shape
+    """
+    return _is_expandable(target_shape, src_shape)
+
+
+def reduce(input: torch.Tensor, shape: Tuple[int, ...], reduce_op: Callable):
+    """
+    Reduce input into given shape.
+
+    :param input: Input to reduce
+    :param shape: Shape of the reduced output
+    :param reduce_op: Reduce operation
+    """
+    if not _is_reducible(input.shape, shape):
+        raise RuntimeError(
+            f"Input of shape {list(input.shape)} can't be reduced to shape {list(shape)}"
+        )
+
+    padded_shape = (
+        *itertools.repeat(1, len(input.shape) - len(shape)),
+        *shape
+    )
+    reduce_dims = tuple(axis for axis, dim in enumerate(padded_shape) if dim == 1)
+    other_dims = tuple(axis for axis, dim in enumerate(padded_shape) if dim > 1)
+    permute_dims = reduce_dims + other_dims
+
+    return reduce_op(input.permute(permute_dims).reshape(-1, *shape), dim=0, keepdim=False)
+
+
+class _ContextManager:
+    def __init__(self, action: Callable[[], None], cleanup: Callable[[], None]):
+        self._action = action
+        self._cleanup = cleanup
+
+    def __enter__(self):
+        self._action()
+        return self
+
+    def __exit__(self, *_):
+        self._cleanup()
+
+    def __call__(self, fn: Callable):
+        @functools.wraps(fn)
+        def wrapper(*args, **kwargs):
+            with self:
+                return fn(*args, **kwargs)
+        return wrapper
+
+
+def patch_attr(obj, attr_name, new_attr) -> _ContextManager:
+    """
+    Temporarily overwrite object attribute
+    """
+    old_attr = getattr(obj, attr_name)
+    action = lambda: setattr(obj, attr_name, new_attr)
+    cleanup = lambda: setattr(obj, attr_name, old_attr)
+    return _ContextManager(action, cleanup)
+
+
+def patch_param(module: torch.nn.Module, param_name: str, new_param: torch.Tensor) -> _ContextManager:
+    """
+    Temporarily substitute the reference to a parameter with the quantized parameter.
+    Under the scope of this function, ``getattr(module, param_name)`` will return
+    ``new_param`` instead of the original parameter.
+
+    :param module: Module that owns the parameter
+    :param param_name: Name of the parameter
+    :param new_param: New parameter to replace the original parameter
+    """
+    original_param = getattr(module, param_name)
+    if original_param is not None:
+        assert original_param.shape == new_param.shape
+
+    # Modify module.__dict__.
+    # module.__dict__ is the primary lookup table which has higher priority than __getattr__ method.
+    # Once we overwrite module.__dict__[param_name] with quantized_params,
+    # getattr(module, param_name) will return module.__dict__[param_name] directly
+    # without falling back to torch.nn.Module's __getattr__ method which returns
+    # the original parameter stored in module._parameters.
+    action = lambda: module.__dict__.update({param_name: new_param})
+
+    if param_name in module.__dict__:
+        # Some non-standard modules (e.g. replicas of torch.nn.DataParallel) store their parameters
+        # directly in module.__dict__. In that case, the cleanup function should restore the dict
+        # so that module.__dict__[param_name] points back to the original parameter again.
+        assert module.__dict__[param_name] is original_param
+        cleanup = lambda: module.__dict__.update({param_name: original_param})
+    else:
+        assert module._parameters[param_name] is original_param # pylint: disable=protected-access
+        cleanup = lambda: module.__dict__.pop(param_name)
+
+    return _ContextManager(action, cleanup)
+
+
+class _StraightThroughEstimator(torch.autograd.Function): # pylint: disable=abstract-method
+    @staticmethod
+    def forward(ctx, op, *args, **kwargs): # pylint: disable=arguments-differ
+        return op(*args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, *grad):
+        return (None, *grad)
+
+
+def ste_round(*args, **kwargs):
+    """
+    Applies straight-through rounding
+    """
+    return _StraightThroughEstimator.apply(torch.round, *args, **kwargs)
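
Two quick behavioral sketches for the helpers above, since they underpin both the observers and the quantizer modules. `reduce` generalizes dim-wise reduction to any target shape the input can be reduced to, and `ste_round` rounds in the forward pass while letting gradients pass through unchanged:

    import torch
    from aimet_torch.experimental.v2.utils import reduce, ste_round

    # Channel-wise max of a (3, 4, 5) tensor, reduced to target shape (4, 1)
    x = torch.arange(60, dtype=torch.float32).view(3, 4, 5)
    assert reduce(x, shape=(4, 1), reduce_op=torch.max).values.shape == (4, 1)

    # Straight-through rounding: forward rounds, backward is identity
    t = torch.full((3,), 1.7, requires_grad=True)
    ste_round(t).sum().backward()
    assert torch.equal(t.grad, torch.ones(3))
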
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_backend.py b/TrainingExtensions/torch/test/python/experimental/v2/test_backend.py
index 5230217fb71..9d508441cfb 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_backend.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_backend.py
@@ -39,6 +39,7 @@ import pytest
 from collections import namedtuple
 
 from aimet_torch.experimental.v2.quantization.backends import default as default_backend
+from aimet_torch.experimental.v2.utils import ste_round
 
 VectorSetForTest = namedtuple("VectorSetForTest", ["tensor", "tensor_q", "tensor_qdq", "mask", "delta", "offset", "bitwidth"])
 
@@ -174,15 +175,6 @@
     bitwidth=8
 )
 
-class STE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, *x):
-        return torch.round(*x)
-
-    @staticmethod
-    def backward(ctx, *output_grad):
-        return output_grad
-
 class AutogradQuantizationModule(torch.nn.Module):
     def __init__(self, scale, offset, bitwidth):
         super().__init__()
@@ -192,7 +184,7 @@ def __init__(self, scale, offset, bitwidth):
 
     def forward(self, x):
         return torch.clamp(
-            STE.apply(x / self.scale) - STE.apply(self.offset),
+            ste_round(x / self.scale) - ste_round(self.offset),
             0, 2 ** self.bitwidth - 1
         )
 
@@ -204,7 +196,7 @@ def __init__(self, scale, offset, bitwidth):
         self.offset = torch.nn.Parameter(offset.clone())
 
     def forward(self, x):
-        return (x + STE.apply(self.offset)) * self.scale
+        return (x + ste_round(self.offset)) * self.scale
 
 class AutogradQuantDequantModule(torch.nn.Module):
     def __init__(self, scale, offset, bitwidth):
@@ -215,10 +207,10 @@ def __init__(self, scale, offset, bitwidth):
 
     def forward(self, x):
         x_q = torch.clamp(
-            STE.apply(x / self.scale) - STE.apply(self.offset),
+            ste_round(x / self.scale) - ste_round(self.offset),
             0, 2 ** self.bitwidth - 1
         )
-        x_dq = (x_q + STE.apply(self.offset)) * self.scale
+        x_dq = (x_q + ste_round(self.offset)) * self.scale
         return x_dq
 
 def copy_test_set(test_set: namedtuple, device: torch.device = torch.device("cpu"),
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py b/TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py
index 6e18ded9d68..4ac10b553b1 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_quantizer_.py
@@ -44,9 +44,6 @@
 
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 
-pytestmark = pytest.mark.skip("not implemented")
-
-
 _PARAMETER_SHAPE = (100,)
 
 def _initialize(q, symmetric):
@@ -127,7 +124,7 @@ def test_compute_encodings(q: Union[Quantize, QuantizeDequantize],
     1. forward() returns dynamic quantization output
     2. self.get_min(), self.get_max() == self.encoding_analyzer.compute_encodings()
     """
-    dynamic_min, dynamic_max = q.encoding_analyzer.compute_dynamic_encodings(x)
+    dynamic_min, dynamic_max = q.encoding_analyzer.compute_dynamic_encodings(x, q.symmetric, q.bitwidth)
 
     if q.symmetric:
         dynamic_scale = torch.maximum(dynamic_max/127, -dynamic_min/128)
@@ -548,18 +545,3 @@ def test_asymmetric_learning(q, x, optim_cls):
     assert not torch.equal(q.get_max(), original_max)
     assert not torch.equal(q.get_scale(), original_scale)
     assert not torch.equal(q.get_offset(), original_offset)
-
-
-@pytest.mark.parametrize('q', [
-    quantize(symmetric=False, initialized=False),
-    quantize(symmetric=True, initialized=False),
-    quantize(symmetric=False, initialized=False),
-    quantize(symmetric=True, initialized=True),
-    quantize_dequantize(symmetric=False, initialized=False),
-    quantize_dequantize(symmetric=True, initialized=False),
-    quantize_dequantize(symmetric=False, initialized=False),
-    quantize_dequantize(symmetric=True, initialized=True),
-])
-def test_change_symmetry_flag_in_runtime(q):
-    with pytest.raises(RuntimeError):
-        q.symmetric = not q.symmetric
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_utils_.py b/TrainingExtensions/torch/test/python/experimental/v2/test_utils_.py
new file mode 100644
index 00000000000..97fee2655a5
--- /dev/null
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_utils_.py
@@ -0,0 +1,80 @@
+# -*- mode: python -*-
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+#    may be used to endorse or promote products derived from this software
+#    without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+import pytest
+import torch
+
+from aimet_torch.experimental.v2.utils import reduce
+
+@pytest.mark.parametrize('reduce_dim, target_shape', [
+    # | reduce dim | target shape |
+    # |------------|--------------|
+    (  [0,1,2,3],    []           ),
+
+    (  [0,1,2],      [6]          ),
+    (  [0,1,2],      [1,6]        ),
+    (  [0,1,2],      [1,1,6]      ),
+    (  [0,1,2],      [1,1,1,6]    ),
+    (  [0,1,3],      [5,1]        ),
+    (  [0,1,3],      [1,5,1]      ),
+    (  [0,1,3],      [1,1,5,1]    ),
+    (  [0,2,3],      [4,1,1]      ),
+    (  [0,2,3],      [1,4,1,1]    ),
+    (  [1,2,3],      [3,1,1,1]    ),
+
+    (  [0,1],        [5,6]        ),
+    (  [0,1],        [1,5,6]      ),
+    (  [0,1],        [1,1,5,6]    ),
+    (  [0,2],        [4,1,6]      ),
+    (  [0,2],        [1,4,1,6]    ),
+    (  [1,2],        [3,1,1,6]    ),
+    (  [0,3],        [4,5,1]      ),
+    (  [0,3],        [1,4,5,1]    ),
+    (  [1,3],        [3,1,5,1]    ),
+    (  [2,3],        [3,4,1,1]    ),
+
+    (  [0],          [4,5,6]      ),
+    (  [0],          [1,4,5,6]    ),
+    (  [1],          [3,1,5,6]    ),
+    (  [2],          [3,4,1,6]    ),
+    (  [3],          [3,4,5,1]    ),
+])
+def test_reduce(reduce_dim, target_shape):
+    x = torch.arange(start=0, end=3*4*5*6).view(3,4,5,6)
+    out = reduce(x, target_shape, torch.sum)
+    expected = torch.sum(x, dim=reduce_dim, keepdim=True)
+    assert list(out.shape) == list(target_shape)
+    assert torch.allclose(out, expected)
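
For context on the 127/128 constants used in `test_compute_encodings` above and in `_QuantizerBase.get_scale`, a worked numeric example of the signed-symmetric 8-bit grid (a sketch, not part of the change):

    import torch

    bitwidth = 8
    min_, max_ = torch.tensor(-1.0), torch.tensor(0.5)

    # get_scale(): 2**8 - 1 = 255 bins, split into 128 negative and 127 positive
    num_bins = 2 ** bitwidth - 1
    positive_bins = num_bins // 2                                       # 127
    negative_bins = positive_bins + 1                                   # 128
    scale = torch.maximum(-min_ / negative_bins, max_ / positive_bins)  # 1/128

    # get_offset(): fixed at -2**(bitwidth - 1) = -128 in symmetric mode
    offset = torch.tensor(-128.0)

    # Same arithmetic as the Autograd*Module references in test_backend.py
    x = torch.tensor(0.3)
    x_q = torch.clamp(torch.round(x / scale) - offset, 0, num_bins)     # 166.0
    x_dq = (x_q + offset) * scale                                       # 0.296875
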