Make quantizer take encoding analyzer directly as input
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Jan 29, 2024
1 parent c01ea3a commit 5a19b1b
Showing 8 changed files with 113 additions and 49 deletions.
================================================================
@@ -63,8 +63,10 @@ class _Observer(Generic[_Statistics], ABC):
     """
     Observes and gathers statistics
     """
-    def __init__(self, min_max_shape: tuple):
-        self.shape = min_max_shape
+    def __init__(self, shape: tuple):
+        if isinstance(shape, int):
+            shape = (shape,)
+        self.shape = shape
 
     @abstractmethod
     def collect_stats(self, input_tensor: torch.Tensor) -> _Statistics:
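The hunk above lets observers accept either an int or a tuple as their shape, normalizing an int to a 1-tuple. A standalone sketch of that normalization rule, using a hypothetical helper name rather than the library's internals:

    # Minimal sketch of the shape normalization introduced above (illustrative only).
    from typing import Tuple, Union

    def normalize_shape(shape: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
        # An int such as 5 is treated as the 1-D shape (5,); tuples pass through unchanged.
        if isinstance(shape, int):
            shape = (shape,)
        return shape

    assert normalize_shape(5) == (5,)
    assert normalize_shape((2, 1)) == (2, 1)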
@@ -87,8 +89,8 @@ class _MinMaxObserver(_Observer[_MinMaxRange]):
     """
     Observer for Min-Max calibration technique
     """
-    def __init__(self, min_max_shape: tuple):
-        super().__init__(min_max_shape)
+    def __init__(self, shape: tuple):
+        super().__init__(shape)
         self.stats = _MinMaxRange()
 
     @torch.no_grad()
@@ -125,8 +127,8 @@ class _HistogramObserver(_Observer[_Histogram]):
     """
     Observer for Histogram based calibration techniques (percentile, MSE)
     """
-    def __init__(self, min_max_shape: tuple, num_bins: int):
-        super().__init__(min_max_shape)
+    def __init__(self, shape: tuple, num_bins: int):
+        super().__init__(shape)
         self.stats = _Histogram()
         self.num_bins = num_bins
 
@@ -223,7 +225,7 @@ class PercentileEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
     """
     Encoding Analyzer for Percentile calibration technique
     """
-    def __init__(self, shape: torch.Tensor, num_bins: int = 2048):
+    def __init__(self, shape: tuple, num_bins: int = 2048):
         observer = _HistogramObserver(shape=shape, num_bins=num_bins)
         super().__init__(observer)
 
@@ -237,7 +239,7 @@ class SqnrEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
     """
     Encoding Analyzer for SQNR Calibration technique
     """
-    def __init__(self, shape: torch.Tensor, num_bins: int = 2048):
+    def __init__(self, shape: tuple, num_bins: int = 2048):
         observer = _HistogramObserver(shape=shape, num_bins=num_bins)
         super().__init__(observer)
 
@@ -251,7 +253,7 @@ class MseEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
     """
     Encoding Analyzer for Mean Square Error (MSE) Calibration technique
     """
-    def __init__(self, shape: torch.Tensor, num_bins: int = 2048):
+    def __init__(self, shape: tuple, num_bins: int = 2048):
         observer = _HistogramObserver(shape=shape, num_bins=num_bins)
         super().__init__(observer)
 
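With the annotations above corrected, the histogram-based analyzers take a plain shape tuple and an optional num_bins. A hedged usage sketch, assuming only the import path and constructor signatures shown in this diff:

    from aimet_torch.experimental.v2.quantization.encoding_analyzer import (
        MinMaxEncodingAnalyzer,
        PercentileEncodingAnalyzer,
    )

    # Per-tensor analyzers: a single (min, max) pair; the histogram uses 2048 bins by default.
    minmax_analyzer = MinMaxEncodingAnalyzer((1,))
    percentile_analyzer = PercentileEncodingAnalyzer(shape=(1,), num_bins=2048)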
================================================================
@@ -47,8 +47,8 @@
 import torch
 from torch import nn
 
-from aimet_torch.experimental.v2.utils import patch_attr, patch_param, StatisticsNotFoundError
-from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer
+from aimet_torch.experimental.v2.utils import patch_attr, patch_param, _is_expandable, StatisticsNotFoundError
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import EncodingAnalyzer, MinMaxEncodingAnalyzer
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 from aimet_torch.experimental.v2.utils import ste_round
 
@@ -64,18 +64,25 @@ class _QuantizerBase(torch.nn.Module): # pylint: disable=abstract-method
     :param bitwidth: Quantization bitwidth.
     :param symmetric: If True, performs symmetric quantization;
                       otherwise, performs asymmetric quantization.
-    :param encoding_analyzer: Encoding Analyzer
+    :param encoding_analyzer: Encoding analyzer for calibrating quantization encodings.
+                              (default: absolute min-max encoding analyzer)
     """
 
     min: torch.nn.Parameter
     max: torch.nn.Parameter
 
-    def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer):
+    def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None):
         super().__init__()
+        if isinstance(shape, int):
+            shape = (shape,)
         self.shape = shape
         self.bitwidth = bitwidth
         self.symmetric = symmetric
-        self.encoding_analyzer = encoding_analyzer
+        self.encoding_analyzer = encoding_analyzer or MinMaxEncodingAnalyzer(shape)
+
+        if not _is_expandable(self.encoding_analyzer.observer.shape, self.shape):
+            raise RuntimeError(f'Encoding analyzer of shape {self.encoding_analyzer.observer.shape} '
+                               f'is incompatible with quantizer of shape {self.shape}.')
 
         # param_name -> (weakref of initial parameter, version info of the initial parameter)
         # This info will be used for judging whether the current parameter has ever been
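The quantizer change above is the heart of this commit: encoding_analyzer becomes an optional constructor argument, a missing analyzer falls back to a MinMaxEncodingAnalyzer over the quantizer's own shape, and an explicit analyzer must have an observer shape expandable to the quantizer's shape. A hedged sketch of both call patterns, using the QuantizeDequantize keyword arguments as they appear in the tests later in this diff:

    from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer

    # No analyzer given: the quantizer builds MinMaxEncodingAnalyzer((10,)) internally.
    qdq_default = QuantizeDequantize((10,), bitwidth=8, symmetric=False)

    # Explicit analyzer: its shape must be expandable to (10,), otherwise RuntimeError is raised.
    qdq_percentile = QuantizeDequantize((10,), bitwidth=8, symmetric=False,
                                        encoding_analyzer=PercentileEncodingAnalyzer((10,)))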
================================================================
@@ -48,6 +48,7 @@
 from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerChannelQuantizer
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizationMixin
 from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer, PercentileEncodingAnalyzer
 
 
 logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
@@ -267,6 +268,17 @@ def _validate_quantizer_properties(self):
 
         assert self.data_type == QuantizationDataType.int, "Only int quantization is supported in quantsim v1.5"
 
+    def _get_v2_encoding_analyzer(self, shape):
+        """
+        Converts v1 quant scheme into v2 quant scheme.
+        :return: corresponding v2 quant scheme
+        """
+        if self.quant_scheme in (QuantScheme.post_training_tf, QuantScheme.training_range_learning_with_tf_init):
+            return MinMaxEncodingAnalyzer(shape)
+        if self.quant_scheme == QuantScheme.post_training_percentile:
+            return PercentileEncodingAnalyzer(shape)
+        raise NotImplementedError(f"Quant scheme {self.quant_scheme} in old quantsim is not supported yet in quantsim v1.5")
+
     @staticmethod
     def _get_param_shape() -> List[int]:
@@ -284,8 +296,17 @@ def realize(self) -> Optional[QuantizeDequantize]:
         :return: spec for v2 quantizer initialization
         """
-        raise NotImplementedError
+        if not self.enabled:
+            return None
+
+        self._validate_quantizer_properties()
+
+        quantizer_param_shape = self._get_param_shape()
+
+        encoding_analyzer = self._get_v2_encoding_analyzer(quantizer_param_shape)
+
+        return QuantizeDequantize(quantizer_param_shape, self.bitwidth,
+                                  self.use_symmetric_encodings, encoding_analyzer)
 
     def _set_internal_quantizer_properties(self, quantizer: TensorQuantizer):
         """
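Taken together, the two hunks above wire the v1 quantizer wrapper into the v2 API: realize() validates the v1 quantizer, picks an encoding analyzer from the v1 quant scheme (min-max for post_training_tf and training_range_learning_with_tf_init, percentile for post_training_percentile), and returns a QuantizeDequantize built from the v1 settings. A hedged illustration of the object the added body constructs; the shape value is chosen only for the example:

    from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer

    # For a hypothetical per-channel weight quantizer with param shape [64, 1, 1, 1],
    # bitwidth 8, symmetric encodings, and quant_scheme == QuantScheme.post_training_tf,
    # the added realize() body amounts to:
    qdq = QuantizeDequantize([64, 1, 1, 1], 8, True,
                             MinMaxEncodingAnalyzer([64, 1, 1, 1]))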
================================================================
@@ -132,7 +132,8 @@ def patch_param(module: torch.nn.Module, param_name: str, new_param: torch.Tensor):
     """
     original_param = getattr(module, param_name)
     if original_param is not None:
-        assert original_param.shape == new_param.shape
+        assert _is_expandable(new_param.shape, original_param.shape)
+        new_param = new_param.expand_as(original_param)
 
     # Modify module.__dict__.
     # module.__dict__ is the primary lookup table which has higher priority than __getattr__ method.
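patch_param now only requires the patched parameter's shape to be expandable to the original shape instead of matching it exactly, and expands it before patching. The implementation of _is_expandable is not part of this diff; presumably it checks broadcast-style expandability, roughly as in the sketch below (the helper name and logic are an assumption, not the library's code):

    def is_expandable_sketch(src_shape, dst_shape) -> bool:
        # True if a tensor of src_shape can be expanded to dst_shape under
        # PyTorch broadcasting rules: trailing dimensions must match or be 1.
        if len(src_shape) > len(dst_shape):
            return False
        return all(src == dst or src == 1
                   for src, dst in zip(reversed(src_shape), reversed(dst_shape)))

    assert is_expandable_sketch((1,), (10,))        # a (1,) parameter can be expanded
    assert not is_expandable_sketch((3,), (10,))    # a mismatched non-1 dimension cannot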
================================================================
@@ -41,7 +41,7 @@
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizedSoftmax
-
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
 
 
 @pytest.fixture
@@ -70,7 +70,6 @@ def test_no_spec(self, input):
         expected_output = F.softmax(input, quant_softmax.dim)
         assert torch.equal(quant_softmax(input), expected_output)
 
-    @pytest.mark.skip('Skipping due to changes in EncodingAnalyzer instantiation')
     def test_input_qtzn(self, input):
         """
         Given: Instantiate a fake-quantized module with input quantizer spec specified
@@ -79,7 +78,7 @@
         quant_softmax.input_quantizers[0] = QuantizeDequantize((1,),
                                                                bitwidth=8,
                                                                symmetric=False,
-                                                               qscheme='MinMax')
+                                                               encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
 
         """
         When: Inspect `input_quantizer` attribute.
@@ -113,7 +112,6 @@ def test_input_qtzn(self, input):
         expected_output = F.softmax(input_qdq, quant_softmax.dim)
         assert torch.equal(quant_output, expected_output)
 
-    @pytest.mark.skip('Skipping due to changes in EncodingAnalyzer instantiation')
     def test_output_qtzn(self, input):
         """
         Given: Instantiate a fake-quantized module with output quantizer spec specified
@@ -122,7 +120,7 @@
         quant_softmax.output_quantizers[0] = QuantizeDequantize((1,),
                                                                 bitwidth=8,
                                                                 symmetric=False,
-                                                                qscheme='MinMax')
+                                                                encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
 
         """
         When: Inspect `output_quantizer` attribute.
@@ -157,4 +155,4 @@ def test_output_qtzn(self, input):
                                                             scale,
                                                             offset,
                                                             bitwidth)
-        assert torch.equal(quant_output, expected_output)
\ No newline at end of file
+        assert torch.equal(quant_output, expected_output)
================================================================
@@ -42,14 +42,13 @@
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 from aimet_torch.experimental.v2.quantization.modules.quantize import QuantizeDequantize
 from aimet_torch.experimental.v2.nn.fake_quant import FakeQuantizedLinear, FakeQuantizationMixin
-
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
 
 
 @pytest.fixture
 def input():
     return torch.arange(-5, 5) / 10
 
-@pytest.mark.skip('Skipping due to changes in EncodingAnalyzer instantiation')
 class TestFakeQuantizedLinear:
     def test_no_spec(self, input):
         quant_linear = FakeQuantizedLinear(10, 10)
@@ -69,7 +68,7 @@ def test_input_qtzn(self, input):
         quant_linear.input_quantizers[0] = QuantizeDequantize((1,),
                                                               bitwidth=8,
                                                               symmetric=False,
-                                                              qscheme='MinMax')
+                                                              encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
         """
         When: Inspect `input_quantizer` attribute.
         Then: `input_quantizer` is set to `QuantizeDequantize` as a submodule
@@ -111,7 +110,7 @@ def test_output_qtzn(self, input):
         quant_linear.output_quantizers[0] = QuantizeDequantize((1,),
                                                                bitwidth=8,
                                                                symmetric=False,
-                                                               qscheme='MinMax')
+                                                               encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
 
         """
         When: Inspect `output_quantizer` attribute.
@@ -158,7 +157,7 @@ def test_param_qtzn(self, input, bias):
         quant_linear.param_quantizers['weight'] = QuantizeDequantize((10,),
                                                                      bitwidth=4,
                                                                      symmetric=True,
-                                                                     qscheme='MinMax')
+                                                                     encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
 
         """
         When: Inspect `weight_quantizer` attribute.
@@ -196,15 +195,15 @@ def test_from_module(self, input):
         quant_linear.input_quantizers[0] = QuantizeDequantize((1,),
                                                               bitwidth=8,
                                                               symmetric=False,
-                                                              qscheme='MinMax')
+                                                              encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
         quant_linear.output_quantizers[0] = QuantizeDequantize((1,),
                                                                bitwidth=8,
                                                                symmetric=False,
-                                                               qscheme='MinMax')
+                                                               encoding_analyzer=MinMaxEncodingAnalyzer((1,)))
         quant_linear.param_quantizers['weight'] = QuantizeDequantize((10,),
                                                                      bitwidth=4,
                                                                      symmetric=True,
-                                                                     qscheme='MinMax')
+                                                                     encoding_analyzer=MinMaxEncodingAnalyzer((10,)))
         with quant_linear.compute_encodings():
             _ = quant_linear(input)
 
@@ -263,4 +262,4 @@ def test_from_module(self, input):
         NOTE: Analogous to shallow copies, reassigning a new attribute to one of them shouldn't affect the other.
         """
         fp_linear.weight = nn.Parameter(torch.zeros(10, 10))
-        assert not torch.any(fp_linear.weight == quant_linear.weight)
\ No newline at end of file
+        assert not torch.any(fp_linear.weight == quant_linear.weight)
(Remaining changed files not shown.)
