Implement Quantize, Dequantize, and QuantizeDequantize (#2588)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Dec 2, 2023
1 parent ee6a080 commit 34188d2
Showing 8 changed files with 721 additions and 65 deletions.
@@ -38,23 +38,15 @@
from typing import Union
import torch

def _is_expandable(a: torch.Tensor, b: torch.Tensor) -> bool:
    """
    Returns true if tensor a is expandable to shape of tensor b
    """
    if len(a.shape) > len(b.shape):
        return False
    for dim_a, dim_b in zip(a.shape[::-1], b.shape[::-1]):
        if dim_a not in (1, dim_b):
            return False
    return True
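The removed helper encodes PyTorch's broadcasting rule: a may not have more dimensions than b, and each of a's trailing dimensions must equal the corresponding dimension of b or be 1 (the replacement imported below applies the same rule to shapes rather than tensors). A quick illustration against the removed tensor-based signature:

import torch

_is_expandable(torch.ones(1, 3, 1), torch.ones(8, 3, 16))   # True: every trailing pair matches or is 1
_is_expandable(torch.ones(2, 16), torch.ones(8, 3, 16))     # False: 2 neither matches 3 nor is 1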
from aimet_torch.experimental.v2.utils import _is_expandable


def _validate_arguments(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int] = None):
    if not tensor.dtype == scale.dtype == offset.dtype:
        raise RuntimeError("Data type of tensor, scale, and offset should be the same")
    if bitwidth and torch.finfo(tensor.dtype).bits <= bitwidth:
        raise RuntimeError(f"Dtype {tensor.dtype} has insufficient bitwidth to perform {bitwidth}-bit quantization")
    if not _is_expandable(scale, tensor):
    if not _is_expandable(scale.shape, tensor.shape):
        raise RuntimeError(f"Scale of shape {scale.shape} cannot be expanded like input tensor of shape {tensor.shape}")

def quantize(tensor: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor, bitwidth: Union[torch.Tensor, int]) -> torch.Tensor:
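The body of quantize is collapsed in this view. As a rough sketch of the affine scheme these functions implement (the rounding, clamping, and offset sign convention below are assumptions for illustration, not the committed code):

import torch

def _quantize_sketch(tensor, scale, offset, bitwidth):
    # Map real values onto the integer grid [0, 2**bitwidth - 1].
    return torch.clamp(torch.round(tensor / scale) - offset, 0, 2 ** bitwidth - 1)

def _dequantize_sketch(tensor_q, scale, offset):
    # Approximate inverse mapping back to real values.
    return (tensor_q + offset) * scale

quantize_dequantize composes the two, which is the fake-quantization typically simulated during quantization-aware training.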
@@ -35,12 +35,60 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=all
from typing import Protocol

def set_backend():
    ...
import torch

def get_backend():
    ...
from aimet_torch.experimental.v2.utils import _ContextManager


__all__ = ['set_backend', 'get_backend']
class _QuantizationBackendProtocol(Protocol):
    def quantize(self, input: torch.Tensor) -> torch.Tensor:
        ...

    def dequantize(self,
                   input: torch.Tensor,
                   scale: torch.Tensor,
                   offset: torch.Tensor) -> torch.Tensor:
        ...

    def quantize_dequantize(self, input: torch.Tensor) -> torch.Tensor:
        ...


_CURRENT_BACKEND = 'default'

_SUPPORTED_BACKENDS = {
    'default': None,
}


def set_global_backend(name: str):
    global _CURRENT_BACKEND
    _CURRENT_BACKEND = name


def set_backend(name: str) -> _ContextManager:
    if name not in _SUPPORTED_BACKENDS:
        supported_backend_names = ", ".join(_SUPPORTED_BACKENDS.keys())
        raise RuntimeError(f"Backend '{name}' is not supported. "
                           f"Please choose one of: {supported_backend_names}")

    old_backend = _CURRENT_BACKEND
    action = lambda: set_global_backend(name)
    cleanup = lambda: set_global_backend(old_backend)
    # _ContextManager pairs `action` (switch to the new backend) with `cleanup`
    # (restore the previous one), so the backend change can be scoped.
    return _ContextManager(action=action, cleanup=cleanup)



def get_backend() -> _QuantizationBackendProtocol:
    if _SUPPORTED_BACKENDS[_CURRENT_BACKEND] is None:
        # Lazy import
        import importlib
        module_name = f'aimet_torch.experimental.v2.quantization.backends.{_CURRENT_BACKEND}'
        _SUPPORTED_BACKENDS[_CURRENT_BACKEND] = importlib.import_module(module_name)

    return _SUPPORTED_BACKENDS[_CURRENT_BACKEND]


__all__ = ['set_global_backend', 'set_backend', 'get_backend']
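A usage sketch of the backend selection API added in this file (the package path follows the lazy import above; the backend function signatures and the context-manager behavior of set_backend are assumptions based on the rest of this commit):

import torch
from aimet_torch.experimental.v2.quantization import backends

x = torch.randn(4, 8)
scale = torch.ones(())
offset = torch.zeros(())

# get_backend() lazily imports the module registered under the active name ('default').
x_q = backends.get_backend().quantize(x, scale, offset, bitwidth=8)
x_dq = backends.get_backend().dequantize(x_q, scale, offset)

# set_backend() scopes a temporary override; the previous backend is restored on exit.
with backends.set_backend('default'):
    y = backends.get_backend().quantize_dequantize(x, scale, offset, bitwidth=8)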
@@ -0,0 +1,224 @@
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=all
from typing import TypeVar, Generic, Tuple, Type, Optional
import abc
from dataclasses import dataclass

import torch

from aimet_torch.experimental.v2.utils import reduce


@dataclass(frozen=True)
class _MinMaxRange:
    min: Optional[torch.Tensor] = None
    max: Optional[torch.Tensor] = None


class _Histogram:
    # TODO
    ...


_Statistics = TypeVar('_Statistics', _MinMaxRange, _Histogram)


class _Observer(Generic[_Statistics], abc.ABC):
    def __init__(self, shape):
        self.shape = shape

    @abc.abstractmethod
    def collect_stats(self, x: torch.Tensor) -> _Statistics:
        ...

    @abc.abstractmethod
    def merge_stats(self, stats: _Statistics):
        ...

    @abc.abstractmethod
    def reset_stats(self):
        ...

    @abc.abstractmethod
    def get_stats(self) -> _Statistics:
        ...


class _MinMaxObserver(_Observer[_MinMaxRange]):
    def __init__(self, shape):
        super().__init__(shape)
        self.stats = _MinMaxRange()

    @torch.no_grad()
    def collect_stats(self, x: torch.Tensor) -> _MinMaxRange:
        min = reduce(x, shape=self.shape, reduce_op=torch.min).values
        max = reduce(x, shape=self.shape, reduce_op=torch.max).values
        return _MinMaxRange(min, max)

    @torch.no_grad()
    def merge_stats(self, new_stats: _MinMaxRange):
        min = self.stats.min
        if new_stats.min is not None:
            if min is None:
                min = new_stats.min.clone()
            else:
                min = torch.minimum(min, new_stats.min)

        max = self.stats.max
        if new_stats.max is not None:
            if max is None:
                max = new_stats.max.clone()
            else:
                max = torch.maximum(max, new_stats.max)

        self.stats = _MinMaxRange(min, max)

    def reset_stats(self):
        self.stats = _MinMaxRange()

    def get_stats(self) -> _MinMaxRange:
        return self.stats
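Here reduce comes from aimet_torch.experimental.v2.utils and is assumed to collapse the input down to the observer's target shape (e.g. per-channel min/max). A small illustrative calibration loop with made-up shapes:

import torch

observer = _MinMaxObserver(shape=(1, 3, 1, 1))            # per-channel statistics over dim 1
for _ in range(4):                                         # hypothetical calibration batches
    batch = torch.randn(8, 3, 16, 16)
    observer.merge_stats(observer.collect_stats(batch))    # running min/max across batches
stats = observer.get_stats()                               # _MinMaxRange of per-channel min/max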


class _HistogramObserver(_Observer[_Histogram]):
    def __init__(self, shape):
        # TODO
        raise NotImplementedError

    @torch.no_grad()
    def collect_stats(self, x: torch.Tensor) -> _Histogram:
        # TODO
        raise NotImplementedError

    @torch.no_grad()
    def merge_stats(self, new_stats: _Histogram):
        # TODO
        raise NotImplementedError

    def reset_stats(self):
        # TODO
        raise NotImplementedError

    def get_stats(self) -> _Histogram:
        # TODO
        raise NotImplementedError


class _EncodingAnalyzer(Generic[_Statistics], abc.ABC):
    observer_cls: Type[_Observer[_Statistics]]

    def __init__(self, shape):
        self.observer = self.observer_cls(shape)

    @torch.no_grad()
    def update_stats(self, x: torch.Tensor) -> _Statistics:
        new_stats = self.observer.collect_stats(x)
        self.observer.merge_stats(new_stats)
        return new_stats

    def reset_stats(self) -> None:
        self.observer.reset_stats()

    def compute_encodings(self, symmetric: bool, bitwidth: int)\
            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        return self.compute_encodings_from_stats(self.observer.get_stats(), symmetric, bitwidth)

    def compute_dynamic_encodings(self, x: torch.Tensor, symmetric: bool, bitwidth: int)\
            -> Tuple[torch.Tensor, torch.Tensor]:
        return self.compute_encodings_from_stats(self.observer.collect_stats(x), symmetric, bitwidth)

    @abc.abstractmethod
    def compute_encodings_from_stats(self, stats: _Statistics, symmetric: bool, bitwidth: int)\
            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        ...


class MinMaxEncodingAnalyzer(_EncodingAnalyzer[_MinMaxRange]):
    observer_cls = _MinMaxObserver

    @torch.no_grad()
    def compute_encodings_from_stats(self, stats: _MinMaxRange, symmetric: bool, bitwidth: int)\
            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        if stats.min is None or stats.max is None:
            return None, None

        if symmetric:
            min = torch.minimum(stats.min, -stats.max)
            max = torch.maximum(-stats.min, stats.max)
        else:
            min = stats.min
            max = stats.max

        return min, max
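In the symmetric case the observed range is mirrored around zero so that the wider side sets the magnitude on both ends. A quick numeric check (illustrative only, not part of the commit):

import torch

stats = _MinMaxRange(min=torch.tensor(-0.5), max=torch.tensor(2.0))
analyzer = MinMaxEncodingAnalyzer(shape=())
enc_min, enc_max = analyzer.compute_encodings_from_stats(stats, symmetric=True, bitwidth=8)
# enc_min == -2.0 and enc_max == 2.0: the range is symmetric and still covers the observed [-0.5, 2.0]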


class PercentileEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
    observer_cls = _HistogramObserver

    @torch.no_grad()
    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # TODO
        raise NotImplementedError


class SqnrEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
    observer_cls = _HistogramObserver

    @torch.no_grad()
    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # TODO
        raise NotImplementedError


class MseEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
    observer_cls = _HistogramObserver

    @torch.no_grad()
    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # TODO
        raise NotImplementedError


def get_encoding_analyzer_cls(qscheme):
    if qscheme == 'minmax':
        return MinMaxEncodingAnalyzer

    raise ValueError(f"Unsupported qscheme: {qscheme}")
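Putting the new pieces together, the intended calibration flow looks roughly like this (the data loop and shape are hypothetical, and only the 'minmax' scheme is wired up in this commit):

import torch

analyzer_cls = get_encoding_analyzer_cls('minmax')        # -> MinMaxEncodingAnalyzer
analyzer = analyzer_cls(shape=(1,))                        # per-tensor statistics
for _ in range(10):                                        # hypothetical calibration data
    analyzer.update_stats(torch.randn(32, 1))
encoding_min, encoding_max = analyzer.compute_encodings(symmetric=False, bitwidth=8)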