Refactor quantizer base class
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Jan 30, 2024
1 parent 0ae2b4a commit 59a4d89
Showing 13 changed files with 376 additions and 160 deletions.
@@ -45,7 +45,7 @@
from torch.utils._pytree import tree_map

from aimet_torch.experimental.v2.nn.quant_base import BaseQuantizationMixin
from aimet_torch.experimental.v2.quantization.modules.quantize import _QuantizerBase
from aimet_torch.experimental.v2.quantization.quantizers import QuantizerBase
import aimet_torch.elementwise_ops as aimet_ops


@@ -77,7 +77,7 @@ def export_input_encodings(self) -> List[List[Dict]]:
Returns a list of input encodings, each represented as a List of Dicts
"""
return [
quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
for quantizer in _flatten_nn_module_list(self.input_quantizers)
]

@@ -86,7 +86,7 @@ def export_output_encodings(self) -> List[List[Dict]]:
Returns a list of output encodings, each represented as a List of Dicts
"""
return [
quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
for quantizer in _flatten_nn_module_list(self.output_quantizers)
]

@@ -95,7 +95,7 @@ def export_param_encodings(self) -> Dict[str, List[Dict]]:
Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts
"""
return {
param_name: quantizer.get_encodings() if isinstance(quantizer, _QuantizerBase) else None
param_name: quantizer.get_encodings() if isinstance(quantizer, QuantizerBase) else None
for param_name, quantizer in self.param_quantizers.items()
}
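The isinstance target in these export helpers is now the generic QuantizerBase rather than the affine-specific _QuantizerBase, so any concrete quantizer subclass is exported while disabled slots (None entries) fall through. A minimal sketch of that check, assuming the import path and constructor signature shown elsewhere in this diff:

import torch
from aimet_torch.experimental.v2.quantization.quantizers import QuantizerBase, QuantizeDequantize

# Any concrete quantizer passes the check used by export_*_encodings;
# a disabled slot (None) is simply exported as None.
qtzr = QuantizeDequantize(shape=(1,), bitwidth=8, symmetric=False)
assert isinstance(qtzr, QuantizerBase)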

@@ -37,4 +37,4 @@

# pylint: disable=all

from .modules import *
from .quantizers import *
@@ -36,4 +36,5 @@
# =============================================================================
# pylint: disable=all

from .quantize import *
from .base import *
from .affine import *
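With these star re-exports in place, both the abstract base and the affine quantizers should be importable from the quantizers package directly. A hedged example, based on the __all__ list visible further down in this diff (QuantizerBase is assumed to be exported by the new base module):

from aimet_torch.experimental.v2.quantization.quantizers import (
    QuantizerBase,          # from .base
    AffineQuantizerBase,    # from .affine
    MinMaxQuantizer,
    Quantize,
    QuantizeDequantize,
    Dequantize,
)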
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -35,30 +35,29 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=redefined-builtin
""" nn.Modules for quantization operators """
""" Affine quantizers """

import copy
import abc
from typing import Optional, Tuple, List, Dict
import contextlib
from collections import OrderedDict
import functools
import weakref

import torch
from torch import nn

from aimet_torch.experimental.v2.utils import patch_attr, patch_param, _is_expandable, StatisticsNotFoundError
from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer, MinMaxEncodingAnalyzer
from aimet_torch.experimental.v2.quantization.quantizers.base import QuantizerBase
from aimet_torch.experimental.v2.quantization.backends import get_backend
from aimet_torch.experimental.v2.utils import ste_round


__all__ = ['Quantize', 'QuantizeDequantize', 'Dequantize']
__all__ = ['AffineQuantizerBase', 'MinMaxQuantizer', 'Quantize', 'QuantizeDequantize', 'Dequantize']


class _QuantizerBase(torch.nn.Module): # pylint: disable=abstract-method
class AffineQuantizerBase(QuantizerBase):
"""
Base class for quantization modules.
Base class for linear quantization modules.
:param shape: Shape of the quantization parameters.
:param bitwidth: Quantization bitwidth.
@@ -67,10 +66,6 @@ class _QuantizerBase(torch.nn.Module): # pylint: disable=abstract-method
:param encoding_analyzer: Encoding analyzer for calibrating quantization encodings.
(default: absolute min-max encoding analyzer)
"""

min: torch.nn.Parameter
max: torch.nn.Parameter

def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None):
super().__init__()
if isinstance(shape, int):
@@ -84,147 +79,47 @@ def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: Enc
raise RuntimeError(f'Encoding analyzer of shape {self.encoding_analyzer.observer.shape} '
f'is incompatible with quantizer of shape {self.shape}.')

# param_name -> (weakref of initial parameter, version info of the initial parameter)
# This info will be used for judging whether the current parameter has ever been
# initialized after it was instantiated.
self._initial_parameters = OrderedDict()

# Raw quantization parameters
self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))

@torch.no_grad()
def __deepcopy__(self, memo):
self_copy = self.__new__(type(self))
self_copy.__dict__ = copy.deepcopy(self.__dict__, memo)

for name, param in self_copy.named_parameters():
# Register parameters to the copied quantizer
self_copy.register_quantization_parameter(name, param)

# If the parameter has been already initialized,
# artificially increment the parameter version to mark as initialized
if self._is_initialized(name):
param.mul_(1.)

return self_copy

def __getstate__(self):
state = self.__dict__.copy()
state.pop('_initial_parameters')
state['initialized_parameters'] = [param_name for param_name, _ in self.named_parameters()
if self._is_initialized(param_name)]
return state

@torch.no_grad()
def __setstate__(self, state):
initialized_parameters = state.pop('initialized_parameters')
self.__dict__.update(state)

self._initial_parameters = OrderedDict()
for param_name, param in self.named_parameters():
# Register parameters to the loaded quantizer
self.register_quantization_parameter(param_name, param)

# If the parameter has been already initialized,
# artificially increment the parameter version to mark as initialized
if param_name in initialized_parameters:
param.mul_(1.)

def register_quantization_parameter(self, name: str, param: nn.Parameter):
@abc.abstractmethod
def get_min(self) -> torch.Tensor:
"""
Register quantization parameter.
"""
# pylint: disable=protected-access

self.register_parameter(name, param)
param = getattr(self, name)
self._initial_parameters[name] = (weakref.ref(param), param._version)

def _is_initialized(self, param_name) -> bool:
# pylint: disable=protected-access

initial_param_weakref, initial_param_version = self._initial_parameters[param_name]
initial_param = initial_param_weakref()

if initial_param is None:
# The initial parameter object doesn't exist in memory space anymore.
return True

current_param = getattr(self, param_name)

if current_param is initial_param and current_param._version == initial_param_version:
# 1. Current parameter is the identical object as the initial parameter
# 2. The version number of the current parameter never changed
return False

return True

def is_initialized(self) -> bool:
"""
Returns true if the quantization parameters are initialized.
"""
for param_name, _ in self.named_parameters():
if not self._is_initialized(param_name):
return False
return True

def get_min(self) -> Optional[torch.Tensor]:
"""
Compute quantization min to be used for forward pass based on raw parameters.
Compute quantization min to be used for forward pass.
Return None if the quantizer is not initialized yet.
:return: Quantization min
"""
if not self.is_initialized():
return None
return self.get_scale() * self.get_offset()

def get_max(self) -> Optional[torch.Tensor]:
@abc.abstractmethod
def get_max(self) -> torch.Tensor:
"""
Compute quantization max to be used for forward pass based on raw parameters.
Compute quantization max to be used for forward pass.
Return None if the quantizer is not initialized yet.
:return: Quantization max
"""
if not self.is_initialized():
return None
return self.get_scale() * (self.get_offset() + 2 ** self.bitwidth - 1)

def get_scale(self) -> Optional[torch.Tensor]:
@abc.abstractmethod
def get_scale(self) -> torch.Tensor:
"""
Compute quantization scale to be used for forward pass based on raw parameters.
Compute quantization scale to be used for forward pass.
Return None if the quantizer is not initialized yet.
:return: Quantization scale
"""
if not self.is_initialized():
return None

num_bins = 2 ** self.bitwidth - 1

if self.symmetric:
positive_bins = num_bins // 2
negative_bins = positive_bins + 1
scale = torch.maximum(-self.min / negative_bins, self.max / positive_bins)
else:
scale = (self.max - self.min) / num_bins

return scale

def get_offset(self) -> Optional[torch.Tensor]:
@abc.abstractmethod
def get_offset(self) -> torch.Tensor:
"""
Compute quantization offset to be used for forward pass based on raw parameters.
Compute quantization offset to be used for forward pass.
Return None if the quantizer is not initialized yet.
:return: Quantization offset
"""
if not self.is_initialized():
return None

if self.symmetric:
with torch.no_grad():
offset = -torch.ones_like(self.min) * 2 ** (self.bitwidth - 1)
else:
offset = ste_round(self.min / self.get_scale())

return offset
@abc.abstractmethod
def set_range(self, min: torch.Tensor, max: torch.Tensor):
"""
Set quantization parameters to the given min-max range
"""

@torch.no_grad()
def get_encodings(self) -> Optional[List[Dict]]:
@@ -251,6 +146,24 @@ def get_encodings(self) -> Optional[List[Dict]]:
for min_, max_, scale_, offset_ in zip(min, max, scale, offset)
]

def extra_repr(self) -> str:
return f'shape={self.shape}, bitwidth={self.bitwidth}, symmetric={self.symmetric}'
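AffineQuantizerBase now carries only the bitwidth/symmetry metadata, the encoding-analyzer plumbing, and get_encodings(); the actual parameterization is deferred to subclasses via the abstract get_min/get_max/get_scale/get_offset/set_range methods. A rough sketch of what an alternative parameterization would have to supply, purely illustrative and not part of this commit (it assumes register_quantization_parameter and is_initialized now live on the new QuantizerBase, as the MinMaxQuantizer below suggests, and it relies on the imports already at the top of this module: torch, nn, EncodingAnalyzer, ste_round):

class _ScaleOffsetQuantizer(AffineQuantizerBase):  # hypothetical subclass, illustration only
    """Affine quantizer parameterized directly by scale and offset."""

    def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None):
        super().__init__(shape, bitwidth, symmetric, encoding_analyzer)
        self.register_quantization_parameter('scale', nn.Parameter(torch.ones(self.shape)))
        self.register_quantization_parameter('offset', nn.Parameter(torch.zeros(self.shape)))

    def get_scale(self):
        return self.scale if self.is_initialized() else None

    def get_offset(self):
        return self.offset if self.is_initialized() else None

    def get_min(self):
        if not self.is_initialized():
            return None
        return self.get_scale() * self.get_offset()

    def get_max(self):
        if not self.is_initialized():
            return None
        return self.get_scale() * (self.get_offset() + 2 ** self.bitwidth - 1)

    def set_range(self, min, max):
        with torch.no_grad():
            scale = (max - min) / (2 ** self.bitwidth - 1)
            self.scale.copy_(scale)
            self.offset.copy_(ste_round(min / scale))

    # compute_encodings and forward are omitted; a real subclass would still need them.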


class MinMaxQuantizer(AffineQuantizerBase): # pylint: disable=abstract-method
"""
Affine quantizer with min-max as trainable parameters
"""

min: torch.nn.Parameter
max: torch.nn.Parameter

def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None):
super().__init__(shape, bitwidth, symmetric, encoding_analyzer)

self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))

@contextlib.contextmanager
def compute_encodings(self):
"""
@@ -285,17 +198,84 @@ def forward_wrapper(input):
if min is None or max is None:
return

with torch.no_grad():
self.min.copy_(min)
self.max.copy_(max)
self.set_range(min, max)

finally:
self.encoding_analyzer.reset_stats()

def extra_repr(self) -> str:
return f'shape={self.shape}, bitwidth={self.bitwidth}, symmetric={self.symmetric}'
def get_min(self) -> Optional[torch.Tensor]:
"""
Compute quantization min to be used for forward pass.
NOTE: self.min may not be equal to self.get_min().
self.get_min() returns a slightly recalibrated version of self.min.
:return: Quantization min
"""
if not self.is_initialized():
return None
return self.get_scale() * self.get_offset()

def get_max(self) -> Optional[torch.Tensor]:
"""
Compute quantization max to be used for forward pass.
NOTE: self.max may not be equal to self.get_max()
self.get_max() returns a slightly recalibrated version of self.max.
:return: Quantization max
"""
if not self.is_initialized():
return None
return self.get_scale() * (self.get_offset() + 2 ** self.bitwidth - 1)

def get_scale(self) -> Optional[torch.Tensor]:
"""
Compute quantization scale to be used for forward pass.
:return: Quantization scale
"""
if not self.is_initialized():
return None

num_bins = 2 ** self.bitwidth - 1

if self.symmetric:
positive_bins = num_bins // 2
negative_bins = positive_bins + 1
scale = torch.maximum(-self.min / negative_bins, self.max / positive_bins)
else:
scale = (self.max - self.min) / num_bins

return scale

def get_offset(self) -> Optional[torch.Tensor]:
"""
Compute quantization offset to be used for forward pass.
:return: Quantization offset
"""
if not self.is_initialized():
return None

if self.symmetric:
with torch.no_grad():
offset = -torch.ones_like(self.min) * 2 ** (self.bitwidth - 1)
else:
offset = ste_round(self.min / self.get_scale())

return offset

def set_range(self, min: torch.Tensor, max: torch.Tensor):
"""
Set quantization parameters to the given min-max range
"""
with torch.no_grad():
self.min.copy_(min)
self.max.copy_(max)
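Putting set_range together with the scale/offset formulas above, here is a worked numeric sketch for an 8-bit, per-tensor quantizer (numbers are illustrative only):

import torch
from aimet_torch.experimental.v2.quantization.quantizers import QuantizeDequantize

q = QuantizeDequantize(shape=(1,), bitwidth=8, symmetric=False)
q.set_range(torch.tensor([-1.0]), torch.tensor([1.0]))

# Asymmetric: num_bins = 2**8 - 1 = 255
print(q.get_scale())    # (1.0 - (-1.0)) / 255 ≈ 0.00784
print(q.get_offset())   # ste_round(-1.0 / scale) ≈ -128 (exact value depends on ste_round's tie-breaking at the half bin)

# Symmetric with the same range: scale = max(1.0 / 128, 1.0 / 127) = 1 / 127 ≈ 0.00787, offset = -2**(8 - 1) = -128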


class Quantize(_QuantizerBase):
class Quantize(MinMaxQuantizer):
"""
Applies quantization to the input
"""
@@ -316,7 +296,7 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc
return input_q, scale, offset


class QuantizeDequantize(_QuantizerBase):
class QuantizeDequantize(MinMaxQuantizer):
"""
Applies quantization followed by dequantization to the input
"""