
Commit

Fix indentation and resolve pylint warnings (#2638)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Jan 8, 2024
1 parent: 6511f2c · commit: c18aade
Showing 2 changed files with 78 additions and 75 deletions.
@@ -35,9 +35,7 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================
 # pylint: disable=redefined-builtin
-# pylint: disable=arguments-differ
 # pylint: disable=missing-docstring
-# pylint: disable=no-member

 """ Computes statistics and encodings """

@@ -99,20 +97,20 @@ def collect_stats(self, input_tensor: torch.Tensor) -> _MinMaxRange:
         return _MinMaxRange(new_min, new_max)

     @torch.no_grad()
-    def merge_stats(self, new_stats: _MinMaxRange):
+    def merge_stats(self, stats: _MinMaxRange):
         updated_min = self.stats.min
-        if new_stats.min is not None:
+        if stats.min is not None:
             if updated_min is None:
-                updated_min = new_stats.min.clone()
+                updated_min = stats.min.clone()
             else:
-                updated_min = torch.minimum(updated_min, new_stats.min)
+                updated_min = torch.minimum(updated_min, stats.min)

         updated_max = self.stats.max
-        if new_stats.max is not None:
+        if stats.max is not None:
             if updated_max is None:
-                updated_max = new_stats.max.clone()
+                updated_max = stats.max.clone()
             else:
-                updated_max = torch.maximum(updated_max, new_stats.max)
+                updated_max = torch.maximum(updated_max, stats.max)

         self.stats = _MinMaxRange(updated_min, updated_max)

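The hunk above only renames the parameter; the merge itself keeps a running element-wise minimum and maximum across batches. A minimal, self-contained sketch of that behaviour, using a namedtuple stand-in rather than aimet's _MinMaxRange class:

# Stand-in for aimet's _MinMaxRange, for illustration only.
from collections import namedtuple

import torch

_MinMaxRange = namedtuple("_MinMaxRange", ["min", "max"])

def merge(current: _MinMaxRange, stats: _MinMaxRange) -> _MinMaxRange:
    # Keep the element-wise running minimum and maximum seen so far.
    updated_min = current.min
    if stats.min is not None:
        updated_min = (stats.min.clone() if updated_min is None
                       else torch.minimum(updated_min, stats.min))

    updated_max = current.max
    if stats.max is not None:
        updated_max = (stats.max.clone() if updated_max is None
                       else torch.maximum(updated_max, stats.max))

    return _MinMaxRange(updated_min, updated_max)

running = _MinMaxRange(None, None)
for batch in (torch.tensor([1.0, 5.0]), torch.tensor([-2.0, 3.0])):
    running = merge(running, _MinMaxRange(batch.min(), batch.max()))
print(running.min.item(), running.max.item())  # -2.0 5.0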
@@ -136,7 +134,7 @@ def collect_stats(self, input_tensor: torch.Tensor) -> _Histogram:
         raise NotImplementedError

     @torch.no_grad()
-    def merge_stats(self, new_stats: _Histogram):
+    def merge_stats(self, stats: _Histogram):
         # TODO
         raise NotImplementedError

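Both merge_stats overrides are renamed from new_stats to stats, presumably so their signatures match the method they override; a mismatched parameter name is what pylint's arguments-differ check (arguments-renamed in newer pylint releases) reports. A minimal illustration with made-up class names, not the aimet observers:

# Illustration of the parameter-name warning the rename addresses.
from abc import ABC, abstractmethod

class ObserverBase(ABC):
    @abstractmethod
    def merge_stats(self, stats):
        ...

class RenamedParamObserver(ObserverBase):
    # pylint flags this override (arguments-differ, or arguments-renamed in
    # newer pylint releases) because the parameter name no longer matches.
    def merge_stats(self, new_stats):
        return new_stats

class MatchingParamObserver(ObserverBase):
    # Same name as the abstract method: no warning, and keyword calls such as
    # observer.merge_stats(stats=...) work the same way for every subclass.
    def merge_stats(self, stats):
        return stats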
@@ -171,6 +169,8 @@ def get_encoding_analyzer_cls(calibration_method: CalibrationMethod, min_max_sha
                         'minmax, sqnr, mse, percentile')

 class _EncodingAnalyzer(Generic[_Statistics], ABC):
+    def __init__(self, observer: _Observer):
+        self.observer = observer

     @torch.no_grad()
     def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
@@ -199,7 +199,8 @@ class MinMaxEncodingAnalyzer(_EncodingAnalyzer[_MinMaxRange]):
     Encoding Analyzer for Min-Max calibration technique
     """
     def __init__(self, shape):
-        self.observer = _MinMaxObserver(shape)
+        observer = _MinMaxObserver(shape)
+        super().__init__(observer)

     @torch.no_grad()
     def compute_encodings_from_stats(self, stats: _MinMaxRange, bitwidth: int, is_symmetric: bool)\
@@ -238,7 +239,8 @@ class PercentileEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
     Encoding Analyzer for Percentile calibration technique
     """
     def __init__(self, shape):
-        self.observer = _HistogramObserver(shape)
+        observer = _HistogramObserver(shape)
+        super().__init__(observer)

     @torch.no_grad()
     def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
@@ -251,7 +253,8 @@ class SqnrEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
     Encoding Analyzer for SQNR Calibration technique
     """
     def __init__(self, shape):
-        self.observer = _HistogramObserver(shape)
+        observer = _HistogramObserver(shape)
+        super().__init__(observer)

     @torch.no_grad()
     def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
@@ -264,7 +267,8 @@ class MseEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
     Encoding Analyzer for Mean Square Error (MSE) Calibration technique
     """
     def __init__(self, shape):
-        self.observer = _HistogramObserver(shape)
+        observer = _HistogramObserver(shape)
+        super().__init__(observer)

     @torch.no_grad()
     def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
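The other recurring change in this file moves observer construction into the base class: _EncodingAnalyzer now receives the observer in __init__ and assigns self.observer itself, and each subclass builds its observer and hands it up through super().__init__(observer). A simplified, runnable sketch of the pattern (the observer class below is a stand-in, not the aimet implementation):

# Sketch of the refactor pattern under simplified, illustrative classes.
class _MinMaxObserver:
    def __init__(self, shape):
        self.shape = shape

    def collect_stats(self, values):
        return min(values), max(values)

class _EncodingAnalyzer:
    def __init__(self, observer):
        # self.observer is assigned in the base class, where the shared
        # methods that use it are defined.
        self.observer = observer

    def update_stats(self, values):
        return self.observer.collect_stats(values)

class MinMaxEncodingAnalyzer(_EncodingAnalyzer):
    def __init__(self, shape):
        observer = _MinMaxObserver(shape)
        super().__init__(observer)   # instead of assigning self.observer directly

analyzer = MinMaxEncodingAnalyzer(shape=(1,))
print(analyzer.update_stats([0.5, -1.0, 2.0]))  # (-1.0, 2.0)

The remaining hunks come from the second changed file, the unit tests, where the edits are indentation and blank-line cleanups plus an unused loop-variable fix.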
@@ -37,55 +37,54 @@
 import torch
 import pytest
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls, CalibrationMethod
-from aimet_torch.experimental.v2.quantization.backends.default import quantize, quantize_dequantize


 class TestEncodingAnalyzer():
-@pytest.mark.parametrize('symmetric', [True, False])
-def test_overflow(self, symmetric):
-encoding_shape = (1,)
-float_input = (torch.arange(10) * torch.finfo(torch.float).tiny)
-
-encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape)
-encoding_analyzer.update_stats(float_input)
-min, max = encoding_analyzer.compute_encodings(bitwidth=8, is_symmetric=symmetric)
-
-scale = (max - min) / 255
-# Scale should be at least as large as torch.tiny
-assert torch.all(torch.isfinite(scale))
-assert torch.allclose(scale, torch.tensor(torch.finfo(scale.dtype).tiny))
-
-@pytest.mark.parametrize('dtype', [torch.float, torch.half])
-@pytest.mark.parametrize('symmetric', [True, False])
-def test_continuity(self, symmetric, dtype):
-encoding_shape = (1,)
-normal_range = torch.arange(-128, 128).to(dtype) / 256
-encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape)
-eps = torch.finfo(dtype).eps
-
-min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps),
-bitwidth=8, is_symmetric=symmetric)
-min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range,
-bitwidth=8, is_symmetric=symmetric)
-min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps),
-bitwidth=8, is_symmetric=symmetric)
-
-assert min_3 <= min_2 <= min_1 <= max_1 <= max_2 <= max_3
-assert torch.allclose(max_1, max_2, atol=eps)
-assert torch.allclose(min_1, min_2, atol=eps)
-assert torch.allclose(max_2, max_3, atol=eps)
-assert torch.allclose(min_2, min_3, atol=eps)
+@pytest.mark.parametrize('symmetric', [True, False])
+def test_overflow(self, symmetric):
+encoding_shape = (1,)
+float_input = (torch.arange(10) * torch.finfo(torch.float).tiny)
+
+encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape)
+encoding_analyzer.update_stats(float_input)
+min, max = encoding_analyzer.compute_encodings(bitwidth=8, is_symmetric=symmetric)
+
+scale = (max - min) / 255
+# Scale should be at least as large as torch.tiny
+assert torch.all(torch.isfinite(scale))
+assert torch.allclose(scale, torch.tensor(torch.finfo(scale.dtype).tiny))
+
+@pytest.mark.parametrize('dtype', [torch.float, torch.half])
+@pytest.mark.parametrize('symmetric', [True, False])
+def test_continuity(self, symmetric, dtype):
+encoding_shape = (1,)
+normal_range = torch.arange(-128, 128).to(dtype) / 256
+encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape)
+eps = torch.finfo(dtype).eps
+
+min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps),
+bitwidth=8, is_symmetric=symmetric)
+min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range,
+bitwidth=8, is_symmetric=symmetric)
+min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps),
+bitwidth=8, is_symmetric=symmetric)
+
+assert min_3 <= min_2 <= min_1 <= max_1 <= max_2 <= max_3
+assert torch.allclose(max_1, max_2, atol=eps)
+assert torch.allclose(min_1, min_2, atol=eps)
+assert torch.allclose(max_2, max_3, atol=eps)
+assert torch.allclose(min_2, min_3, atol=eps)


 class TestMinMaxEncodingAnalyzer():
-def test_compute_encodings_with_negative_bitwidth(self):
+def test_compute_encodings_with_negative_bitwidth(self):
 encoding_min_max = torch.randn(3, 4)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 encoding_analyzer.update_stats(torch.randn(3, 4))
 with pytest.raises(ValueError):
-encoding_analyzer.compute_encodings(bitwidth = 0, is_symmetric = False)
-def test_compute_encodings_asymmetric(self):
+encoding_analyzer.compute_encodings(bitwidth = 0, is_symmetric = False)

+def test_compute_encodings_asymmetric(self):
 encoding_min_max = torch.randn(1)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 input_tensor = torch.arange(start=0, end=26, step=0.5, dtype=torch.float)
@@ -94,8 +93,8 @@ def test_compute_encodings_asymmetric(self):
 asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False)
 assert torch.all(torch.isclose(asymmetric_min, torch.zeros(tuple(encoding_analyzer.observer.shape))))
 assert torch.all(torch.isclose(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 25.5)))
-def test_compute_encodings_signed_symmetric(self):

+def test_compute_encodings_signed_symmetric(self):
 encoding_min_max = torch.randn(1)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 input_tensor = torch.arange(start=0, end=26, step=0.5, dtype=torch.float)
@@ -105,7 +104,7 @@ def test_compute_encodings_signed_symmetric(self):
 assert torch.all(torch.isclose(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -25.5)))
 assert torch.all(torch.isclose(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 25.5)))

-def test_reset_stats(self):
+def test_reset_stats(self):
 encoding_min_max = torch.randn(3, 4)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 encoding_analyzer.update_stats(torch.randn(3, 4))
@@ -114,31 +113,31 @@ def test_reset_stats(self):
 encoding_analyzer.reset_stats()
 assert not encoding_analyzer.observer.stats.min
 assert not encoding_analyzer.observer.stats.max
-def test_compute_encodings_with_no_stats(self):

+def test_compute_encodings_with_no_stats(self):
 encoding_min_max = torch.randn(3, 4)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 with pytest.raises(RuntimeError):
-encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False)
+encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False)

-def test_compute_encodings_with_only_zero_tensor(self):
+def test_compute_encodings_with_only_zero_tensor(self):
 encoding_min_max = torch.randn(3, 4)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 encoding_analyzer.update_stats(torch.zeros(tuple(encoding_analyzer.observer.shape)))

 asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False)
 updated_min = torch.finfo(asymmetric_min.dtype).tiny * (2 ** (8 - 1))
 updated_max = torch.finfo(asymmetric_min.dtype).tiny * ((2 **(8 - 1)) - 1)
 assert torch.all(torch.eq(asymmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -updated_min)))
 assert torch.all(torch.eq(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), updated_max)))

 symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True)
 updated_symmetric_min = min(-updated_min, -updated_max)
 updated_symmetric_max = max(updated_min, updated_max)
 assert torch.all(torch.eq(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), updated_symmetric_min)))
 assert torch.all(torch.eq(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), updated_symmetric_max)))
-def test_compute_encodings_with_same_nonzero_tensor(self):

+def test_compute_encodings_with_same_nonzero_tensor(self):
 encoding_min_max = torch.randn(3, 4)
 encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape)
 encoding_analyzer.update_stats(torch.full((3, 4), 3.0))
@@ -150,16 +149,16 @@ def test_compute_encodings_with_same_nonzero_tensor(self):
 symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True)
 assert torch.allclose(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -3.0), atol = torch.finfo().tiny)
 assert torch.allclose(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 3.0), atol = torch.finfo().tiny)
-@pytest.mark.parametrize("min_max_size", [[3,4], [2, 3, 1], [4], [1]])
-def test_update_stats_with_different_dimensions(self,min_max_size):
-for i in range(4):
-encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, min_max_size)
-encoding_analyzer.update_stats(torch.randn(2, 3, 4))
-assert list(encoding_analyzer.observer.stats.min.shape) == min_max_size
-assert list(encoding_analyzer.observer.stats.max.shape) == min_max_size
-
-def test_update_stats_incompatible_dimension(self):
-
+@pytest.mark.parametrize("min_max_size", [[3,4], [2, 3, 1], [4], [1]])
+def test_update_stats_with_different_dimensions(self,min_max_size):
+for _ in range(4):
+encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, min_max_size)
+encoding_analyzer.update_stats(torch.randn(2, 3, 4))
+assert list(encoding_analyzer.observer.stats.min.shape) == min_max_size
+assert list(encoding_analyzer.observer.stats.max.shape) == min_max_size
+
+def test_update_stats_incompatible_dimension(self):
 encoding_analyzer_1 = get_encoding_analyzer_cls(CalibrationMethod.MinMax, [3, 4])
 with pytest.raises(RuntimeError):
-encoding_analyzer_1.update_stats(torch.randn(2, 3, 5))
+encoding_analyzer_1.update_stats(torch.randn(2, 3, 5))
+
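One of the few non-whitespace edits in the tests replaces the unused loop counter i with _ in test_update_stats_with_different_dimensions, presumably to clear pylint's unused-variable warning (W0612). A minimal illustration, where run_check is a placeholder function:

# Minimal illustration of the unused loop-variable cleanup.
def run_check():
    return True

# Before: pylint reports unused-variable (W0612) because 'i' is never read.
for i in range(4):
    run_check()

# After: '_' is the conventional name for a deliberately unused loop variable.
for _ in range(4):
    run_check()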
