From c18aade7716782d5beaab3bb205eb8ed3eb4e840 Mon Sep 17 00:00:00 2001 From: Kyunggeun Lee Date: Mon, 8 Jan 2024 15:47:50 -0800 Subject: [PATCH] Fix indentation and resolve pylint warnings (#2638) Signed-off-by: Kyunggeun Lee --- .../v2/quantization/encoding_analyzer.py | 32 +++-- .../experimental/v2/test_encoding_analyzer.py | 121 +++++++++--------- 2 files changed, 78 insertions(+), 75 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py index b7a5a8af6fb..a71d818d24f 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py @@ -35,9 +35,7 @@ # @@-COPYRIGHT-END-@@ # ============================================================================= # pylint: disable=redefined-builtin -# pylint: disable=arguments-differ # pylint: disable=missing-docstring -# pylint: disable=no-member """ Computes statistics and encodings """ @@ -99,20 +97,20 @@ def collect_stats(self, input_tensor: torch.Tensor) -> _MinMaxRange: return _MinMaxRange(new_min, new_max) @torch.no_grad() - def merge_stats(self, new_stats: _MinMaxRange): + def merge_stats(self, stats: _MinMaxRange): updated_min = self.stats.min - if new_stats.min is not None: + if stats.min is not None: if updated_min is None: - updated_min = new_stats.min.clone() + updated_min = stats.min.clone() else: - updated_min = torch.minimum(updated_min, new_stats.min) + updated_min = torch.minimum(updated_min, stats.min) updated_max = self.stats.max - if new_stats.max is not None: + if stats.max is not None: if updated_max is None: - updated_max = new_stats.max.clone() + updated_max = stats.max.clone() else: - updated_max = torch.maximum(updated_max, new_stats.max) + updated_max = torch.maximum(updated_max, stats.max) self.stats = _MinMaxRange(updated_min, updated_max) @@ -136,7 +134,7 @@ def collect_stats(self, input_tensor: torch.Tensor) -> _Histogram: raise NotImplementedError @torch.no_grad() - def merge_stats(self, new_stats: _Histogram): + def merge_stats(self, stats: _Histogram): # TODO raise NotImplementedError @@ -171,6 +169,8 @@ def get_encoding_analyzer_cls(calibration_method: CalibrationMethod, min_max_sha 'minmax, sqnr, mse, percentile') class _EncodingAnalyzer(Generic[_Statistics], ABC): + def __init__(self, observer: _Observer): + self.observer = observer @torch.no_grad() def update_stats(self, input_tensor: torch.Tensor) -> _Statistics: @@ -199,7 +199,8 @@ class MinMaxEncodingAnalyzer(_EncodingAnalyzer[_MinMaxRange]): Encoding Analyzer for Min-Max calibration technique """ def __init__(self, shape): - self.observer = _MinMaxObserver(shape) + observer = _MinMaxObserver(shape) + super().__init__(observer) @torch.no_grad() def compute_encodings_from_stats(self, stats: _MinMaxRange, bitwidth: int, is_symmetric: bool)\ @@ -238,7 +239,8 @@ class PercentileEncodingAnalyzer(_EncodingAnalyzer[_Histogram]): Encoding Analyzer for Percentile calibration technique """ def __init__(self, shape): - self.observer = _HistogramObserver(shape) + observer = _HistogramObserver(shape) + super().__init__(observer) @torch.no_grad() def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\ @@ -251,7 +253,8 @@ class SqnrEncodingAnalyzer(_EncodingAnalyzer[_Histogram]): Encoding Analyzer for SQNR Calibration technique """ def __init__(self, shape): - self.observer = _HistogramObserver(shape) + observer = _HistogramObserver(shape) + super().__init__(observer) @torch.no_grad() def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\ @@ -264,7 +267,8 @@ class MseEncodingAnalyzer(_EncodingAnalyzer[_Histogram]): Encoding Analyzer for Mean Square Error (MSE) Calibration technique """ def __init__(self, shape): - self.observer = _HistogramObserver(shape) + observer = _HistogramObserver(shape) + super().__init__(observer) @torch.no_grad() def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\ diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py index 7159c707028..7d99ebbad73 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py @@ -37,55 +37,54 @@ import torch import pytest from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls, CalibrationMethod -from aimet_torch.experimental.v2.quantization.backends.default import quantize, quantize_dequantize class TestEncodingAnalyzer(): - @pytest.mark.parametrize('symmetric', [True, False]) - def test_overflow(self, symmetric): - encoding_shape = (1,) - float_input = (torch.arange(10) * torch.finfo(torch.float).tiny) - - encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape) - encoding_analyzer.update_stats(float_input) - min, max = encoding_analyzer.compute_encodings(bitwidth=8, is_symmetric=symmetric) - - scale = (max - min) / 255 - # Scale should be at least as large as torch.tiny - assert torch.all(torch.isfinite(scale)) - assert torch.allclose(scale, torch.tensor(torch.finfo(scale.dtype).tiny)) - - @pytest.mark.parametrize('dtype', [torch.float, torch.half]) - @pytest.mark.parametrize('symmetric', [True, False]) - def test_continuity(self, symmetric, dtype): - encoding_shape = (1,) - normal_range = torch.arange(-128, 128).to(dtype) / 256 - encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape) - eps = torch.finfo(dtype).eps - - min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps), - bitwidth=8, is_symmetric=symmetric) - min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range, - bitwidth=8, is_symmetric=symmetric) - min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps), - bitwidth=8, is_symmetric=symmetric) - - assert min_3 <= min_2 <= min_1 <= max_1 <= max_2 <= max_3 - assert torch.allclose(max_1, max_2, atol=eps) - assert torch.allclose(min_1, min_2, atol=eps) - assert torch.allclose(max_2, max_3, atol=eps) - assert torch.allclose(min_2, min_3, atol=eps) + @pytest.mark.parametrize('symmetric', [True, False]) + def test_overflow(self, symmetric): + encoding_shape = (1,) + float_input = (torch.arange(10) * torch.finfo(torch.float).tiny) + + encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape) + encoding_analyzer.update_stats(float_input) + min, max = encoding_analyzer.compute_encodings(bitwidth=8, is_symmetric=symmetric) + + scale = (max - min) / 255 + # Scale should be at least as large as torch.tiny + assert torch.all(torch.isfinite(scale)) + assert torch.allclose(scale, torch.tensor(torch.finfo(scale.dtype).tiny)) + + @pytest.mark.parametrize('dtype', [torch.float, torch.half]) + @pytest.mark.parametrize('symmetric', [True, False]) + def test_continuity(self, symmetric, dtype): + encoding_shape = (1,) + normal_range = torch.arange(-128, 128).to(dtype) / 256 + encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_shape) + eps = torch.finfo(dtype).eps + + min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps), + bitwidth=8, is_symmetric=symmetric) + min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range, + bitwidth=8, is_symmetric=symmetric) + min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps), + bitwidth=8, is_symmetric=symmetric) + + assert min_3 <= min_2 <= min_1 <= max_1 <= max_2 <= max_3 + assert torch.allclose(max_1, max_2, atol=eps) + assert torch.allclose(min_1, min_2, atol=eps) + assert torch.allclose(max_2, max_3, atol=eps) + assert torch.allclose(min_2, min_3, atol=eps) class TestMinMaxEncodingAnalyzer(): - def test_compute_encodings_with_negative_bitwidth(self): + def test_compute_encodings_with_negative_bitwidth(self): encoding_min_max = torch.randn(3, 4) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) encoding_analyzer.update_stats(torch.randn(3, 4)) with pytest.raises(ValueError): - encoding_analyzer.compute_encodings(bitwidth = 0, is_symmetric = False) - - def test_compute_encodings_asymmetric(self): + encoding_analyzer.compute_encodings(bitwidth = 0, is_symmetric = False) + + def test_compute_encodings_asymmetric(self): encoding_min_max = torch.randn(1) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) input_tensor = torch.arange(start=0, end=26, step=0.5, dtype=torch.float) @@ -94,8 +93,8 @@ def test_compute_encodings_asymmetric(self): asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) assert torch.all(torch.isclose(asymmetric_min, torch.zeros(tuple(encoding_analyzer.observer.shape)))) assert torch.all(torch.isclose(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 25.5))) - - def test_compute_encodings_signed_symmetric(self): + + def test_compute_encodings_signed_symmetric(self): encoding_min_max = torch.randn(1) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) input_tensor = torch.arange(start=0, end=26, step=0.5, dtype=torch.float) @@ -105,7 +104,7 @@ def test_compute_encodings_signed_symmetric(self): assert torch.all(torch.isclose(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -25.5))) assert torch.all(torch.isclose(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 25.5))) - def test_reset_stats(self): + def test_reset_stats(self): encoding_min_max = torch.randn(3, 4) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) encoding_analyzer.update_stats(torch.randn(3, 4)) @@ -114,31 +113,31 @@ def test_reset_stats(self): encoding_analyzer.reset_stats() assert not encoding_analyzer.observer.stats.min assert not encoding_analyzer.observer.stats.max - - def test_compute_encodings_with_no_stats(self): + + def test_compute_encodings_with_no_stats(self): encoding_min_max = torch.randn(3, 4) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) with pytest.raises(RuntimeError): - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) + encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) - def test_compute_encodings_with_only_zero_tensor(self): + def test_compute_encodings_with_only_zero_tensor(self): encoding_min_max = torch.randn(3, 4) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) encoding_analyzer.update_stats(torch.zeros(tuple(encoding_analyzer.observer.shape))) - + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) updated_min = torch.finfo(asymmetric_min.dtype).tiny * (2 ** (8 - 1)) updated_max = torch.finfo(asymmetric_min.dtype).tiny * ((2 **(8 - 1)) - 1) assert torch.all(torch.eq(asymmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -updated_min))) assert torch.all(torch.eq(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), updated_max))) - + symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) updated_symmetric_min = min(-updated_min, -updated_max) updated_symmetric_max = max(updated_min, updated_max) assert torch.all(torch.eq(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), updated_symmetric_min))) assert torch.all(torch.eq(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), updated_symmetric_max))) - - def test_compute_encodings_with_same_nonzero_tensor(self): + + def test_compute_encodings_with_same_nonzero_tensor(self): encoding_min_max = torch.randn(3, 4) encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, encoding_min_max.shape) encoding_analyzer.update_stats(torch.full((3, 4), 3.0)) @@ -150,16 +149,16 @@ def test_compute_encodings_with_same_nonzero_tensor(self): symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) assert torch.allclose(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -3.0), atol = torch.finfo().tiny) assert torch.allclose(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 3.0), atol = torch.finfo().tiny) - - @pytest.mark.parametrize("min_max_size", [[3,4], [2, 3, 1], [4], [1]]) - def test_update_stats_with_different_dimensions(self,min_max_size): - for i in range(4): - encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, min_max_size) - encoding_analyzer.update_stats(torch.randn(2, 3, 4)) - assert list(encoding_analyzer.observer.stats.min.shape) == min_max_size - assert list(encoding_analyzer.observer.stats.max.shape) == min_max_size - - def test_update_stats_incompatible_dimension(self): + + @pytest.mark.parametrize("min_max_size", [[3,4], [2, 3, 1], [4], [1]]) + def test_update_stats_with_different_dimensions(self,min_max_size): + for _ in range(4): + encoding_analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, min_max_size) + encoding_analyzer.update_stats(torch.randn(2, 3, 4)) + assert list(encoding_analyzer.observer.stats.min.shape) == min_max_size + assert list(encoding_analyzer.observer.stats.max.shape) == min_max_size + + def test_update_stats_incompatible_dimension(self): encoding_analyzer_1 = get_encoding_analyzer_cls(CalibrationMethod.MinMax, [3, 4]) with pytest.raises(RuntimeError): - encoding_analyzer_1.update_stats(torch.randn(2, 3, 5)) + encoding_analyzer_1.update_stats(torch.randn(2, 3, 5))