From 10f1f0d860c8fa7ff1d208239c745c38d126342d Mon Sep 17 00:00:00 2001 From: Priyanka Dangi Date: Tue, 20 Feb 2024 15:24:12 -0800 Subject: [PATCH] Removing percentile argument from compute encodings (#2768) Signed-off-by: Priyanka Dangi --- .../v2/quantization/encoding_analyzer.py | 26 ++-- .../experimental/v2/test_encoding_analyzer.py | 113 ++++++------------ 2 files changed, 46 insertions(+), 93 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py index 0f4649c5998..e02f1fcdd9c 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py @@ -341,11 +341,16 @@ class PercentileEncodingAnalyzer(EncodingAnalyzer[_Histogram]): """ Encoding Analyzer for Percentile calibration technique """ - def __init__(self, shape: tuple, num_bins: int = 2048): + def __init__(self, shape: tuple, num_bins: int = 2048, percentile: float = 100): if num_bins <= 0: raise ValueError('Number of bins cannot be less than or equal to 0.') + + if percentile < 50 or percentile > 100: + raise ValueError('Percentile value must be within 50-100 range') + observer = _HistogramObserver(shape=shape, num_bins=num_bins) super().__init__(observer) + self.percentile = percentile @torch.no_grad() def update_stats(self, input_tensor: torch.Tensor) -> _Statistics: @@ -353,25 +358,14 @@ def update_stats(self, input_tensor: torch.Tensor) -> _Statistics: self.observer.merge_stats(new_stats, input_tensor) return new_stats - def compute_dynamic_encodings(self, input_tensor: torch.Tensor, bitwidth: int,\ - is_symmetric: bool, percentile: float)-> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - return self.compute_encodings_from_stats( - self.observer.collect_stats(input_tensor), bitwidth, is_symmetric, percentile) - - def compute_encodings(self, bitwidth: int, is_symmetric: bool, percentile: float) -> torch.Tensor: - return self.compute_encodings_from_stats(self.observer.get_stats(), bitwidth, is_symmetric, percentile) - # pylint: disable=too-many-locals @torch.no_grad() - def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, is_symmetric: bool, percentile: float)\ + def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, is_symmetric: bool)\ -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if bitwidth <= 0: raise ValueError('Bitwidth cannot be less than or equal to 0.') - if percentile < 50 or percentile > 100: - raise ValueError('Percentile value must be within 50-100 range') - if stats[0].histogram is None: raise StatisticsNotFoundError('No statistics present to compute encodings.') @@ -381,10 +375,10 @@ def compute_encodings_from_stats(self, stats: List[_Histogram], bitwidth: int, i for list_elem in stats: cum_sum = torch.cumsum(list_elem.histogram, dim=0) # trim percentile value from min and max - max_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, percentile/100)) - min_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, 1 - percentile/100)) + max_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, self.percentile/100)) + min_index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, 1 - self.percentile/100)) - if percentile == 100: + if self.percentile == 100: min_index = 0 max_index = -1 curr_min = list_elem.bin_edges[min_index] diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py index aad6fd0a29a..af28eecb8a4 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py @@ -49,7 +49,7 @@ class TestEncodingAnalyzer(): @pytest.fixture def encoding_analyzers(self): min_max_encoding_analyzer = MinMaxEncodingAnalyzer((1,)) - percentile_encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) + percentile_encoding_analyzer = PercentileEncodingAnalyzer((1,), num_bins=3, percentile=99) sqnr_encoding_analyzer = SqnrEncodingAnalyzer((1, )) encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer] yield encoding_analyzer_list @@ -58,9 +58,6 @@ def test_compute_encodings_with_negative_bitwidth(self, encoding_analyzers): for encoding_analyzer in encoding_analyzers: encoding_analyzer.update_stats(torch.randn(3, 4)) with pytest.raises(ValueError): - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - encoding_analyzer.compute_encodings(bitwidth = 0, is_symmetric = False, percentile=99) - else: encoding_analyzer.compute_encodings(bitwidth = 0, is_symmetric = False) def test_reset_stats(self, encoding_analyzers): @@ -84,10 +81,7 @@ def test_reset_stats(self, encoding_analyzers): def test_compute_encodings_with_no_stats(self, encoding_analyzers): for encoding_analyzer in encoding_analyzers: with pytest.raises(RuntimeError): - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile = 99) - else: - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) + encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) @pytest.mark.parametrize('dtype', [torch.float, torch.half]) @pytest.mark.parametrize('symmetric', [True, False]) @@ -96,24 +90,12 @@ def test_continuity(self, symmetric, dtype, encoding_analyzers): normal_range = torch.arange(-128, 128).to(dtype) / 256 eps = torch.finfo(dtype).eps - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - percentile_val = 99 - min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps), - bitwidth=8, is_symmetric=symmetric, - percentile=percentile_val) - min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range, - bitwidth=8, is_symmetric=symmetric, - percentile=percentile_val) - min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps), - bitwidth=8, is_symmetric=symmetric, - percentile=percentile_val) - else: - min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps), - bitwidth=8, is_symmetric=symmetric) - min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range, - bitwidth=8, is_symmetric=symmetric) - min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps), - bitwidth=8, is_symmetric=symmetric) + min_1, max_1 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 - eps), + bitwidth=8, is_symmetric=symmetric) + min_2, max_2 = encoding_analyzer.compute_dynamic_encodings(normal_range, + bitwidth=8, is_symmetric=symmetric) + min_3, max_3 = encoding_analyzer.compute_dynamic_encodings(normal_range * (1 + eps), + bitwidth=8, is_symmetric=symmetric) assert min_3 <= min_2 <= min_1 <= max_1 <= max_2 <= max_3 assert torch.allclose(max_1, max_2, atol=eps) @@ -205,7 +187,7 @@ def histogram_based_encoding_analyzers(self, request): min_max_shape = request.param[0] num_bins = request.param[1] - percentile_encoding_analyzer = PercentileEncodingAnalyzer(min_max_shape, num_bins) + percentile_encoding_analyzer = PercentileEncodingAnalyzer(min_max_shape, num_bins = num_bins, percentile = 99) #sqnr_encoding_analyzer = SqnrEncodingAnalyzer(min_max_shape, num_bins) encoding_analyzer_list = [percentile_encoding_analyzer] yield encoding_analyzer_list @@ -219,20 +201,13 @@ def test_compute_encodings_with_same_nonzero_tensor(self, histogram_based_encodi cum_sum = torch.cumsum(stats.histogram, dim=0) index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, 99/100)) max_val = stats.bin_edges[index] - - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=99) - else: - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) - + + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) assert torch.allclose(asymmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), 0.0)) assert torch.allclose(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), max_val)) - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile = 99) - else: - symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) + symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) assert torch.allclose(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), -1 * max_val)) assert torch.allclose(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), max_val)) @@ -245,20 +220,12 @@ def test_compute_encodings_with_only_zero_tensor(self, histogram_based_encoding_ cum_sum = torch.cumsum(stats.histogram, dim=0) index = torch.searchsorted(cum_sum, torch.quantile(cum_sum, 99/100)) min_value = stats.bin_edges[index] - - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=99) - else: - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) - + + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) assert torch.allclose(asymmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), min_value)) assert torch.all(torch.eq(asymmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), 0))) - - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile = 99) - else: - symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) + symmetric_min , symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) assert torch.allclose(symmetric_min, torch.full(tuple(encoding_analyzer.observer.shape), min_value)) assert torch.allclose(symmetric_max, torch.full(tuple(encoding_analyzer.observer.shape), -1 * min_value)) @@ -267,7 +234,7 @@ def test_invalid_bin_value(self, num_bins): min_max_shape = (1,) with pytest.raises(ValueError): - PercentileEncodingAnalyzer(num_bins = num_bins, shape=min_max_shape) + PercentileEncodingAnalyzer(num_bins = num_bins, shape=min_max_shape, percentile = 99) with pytest.raises(ValueError): SqnrEncodingAnalyzer(min_max_shape, num_bins) @@ -279,19 +246,12 @@ def test_cuda_inputs(self, histogram_based_encoding_analyzers, symmetric): for encoding_analyzer in histogram_based_encoding_analyzers: input_tensor = torch.tensor([2.0, 3.5, 4.2, 5.0]) encoding_analyzer.update_stats(input_tensor.cuda()) - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = symmetric, percentile=99) - else: - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = symmetric) + encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = symmetric) encoding_analyzer.update_stats(input_tensor.cuda()) input_tensor_2 = input_tensor * 1.1 - 0.1 encoding_analyzer.update_stats(input_tensor_2.cuda()) - - if isinstance(encoding_analyzer, PercentileEncodingAnalyzer): - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = symmetric, percentile=99) - else: - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = symmetric) + encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = symmetric) assert encoding_analyzer.observer.stats[0].histogram.is_cuda assert encoding_analyzer.observer.stats[0].bin_edges.is_cuda @@ -412,9 +372,9 @@ def test_merge_stats_without_resizing(self, histogram_based_encoding_analyzers): @pytest.mark.parametrize('shape', [(4,), (3, 1), (2, 1, 1), (3, 4), (2, 3, 1), (2, 1, 4), (2, 3, 4)]) def test_compute_encodings_multidimensional(self, symmetric, device, shape): x = torch.arange(24, dtype=torch.float).view(2, 3, 4).to(device) - encoding_analyzer = PercentileEncodingAnalyzer(shape=shape) + encoding_analyzer = PercentileEncodingAnalyzer(shape=shape, percentile = 99) encoding_analyzer.update_stats(x) - encoding_min, encoding_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric=symmetric, percentile=99) + encoding_min, encoding_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric=symmetric) assert encoding_min.shape == shape assert encoding_max.shape == shape if device == 'cuda': @@ -521,75 +481,74 @@ class TestPercentileEncodingAnalyzer(): @pytest.mark.parametrize("percentile_value", [-1, 49, 5, 101]) def test_invalid_percentile_value(self, percentile_value): with pytest.raises(ValueError): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) - encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=percentile_value) + PercentileEncodingAnalyzer((1,), percentile = percentile_value, num_bins = 3) def test_compute_encodings_asymmetric_normalized(self): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) + encoding_analyzer = PercentileEncodingAnalyzer((1,), percentile = 99, num_bins = 3) mean = std_dev = 2 input_tensor = np.random.normal(mean, std_dev, size=(100000)) encoding_analyzer.update_stats(torch.from_numpy(input_tensor)) - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=99) + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) # 99% of the population is within 2 1/2 standard deviations of the mean assert asymmetric_min > mean - std_dev * 2.5 assert asymmetric_max < mean + std_dev * 2.5 def test_compute_encodings_asymmetric_sequential(self): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 500) + encoding_analyzer = PercentileEncodingAnalyzer((1,), percentile = 99, num_bins = 500) input_tensor = torch.arange(start=0, end=1001, step=1, dtype=torch.float) encoding_analyzer.update_stats(input_tensor) - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=99) + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) # encoding max is the histogram bin edge which contains 99% percentile (990.02) assert asymmetric_min == 0 assert asymmetric_max == 990.0 def test_compute_encodings_signed_symmetric_normalized(self): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) + encoding_analyzer = PercentileEncodingAnalyzer((1,), percentile = 99, num_bins = 3) mean = std_dev = 2 input_tensor = np.random.normal(mean, std_dev, size=(10000)) encoding_analyzer.update_stats(torch.from_numpy(input_tensor)) - symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile=99) + symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) largest_absolute_value = max(abs(element) for element in input_tensor) assert symmetric_min > -largest_absolute_value assert symmetric_max < largest_absolute_value def test_compute_encodings_signed_symmetric_sequential(self): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 500) + encoding_analyzer = PercentileEncodingAnalyzer((1,), percentile = 99, num_bins = 500) input_tensor = torch.arange(start=0, end=1001, step=1, dtype=torch.float) encoding_analyzer.update_stats(input_tensor) - symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile=99) + symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) assert symmetric_min == -990.0 assert symmetric_max == 990.0 def test_compute_encodings_100_percentile(self): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) + encoding_analyzer = PercentileEncodingAnalyzer((1,), percentile = 100, num_bins = 3) mean = std_dev = 2 input_tensor = np.random.normal(mean, std_dev, size=(10000)) encoding_analyzer.update_stats(torch.from_numpy(input_tensor)) - symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile=100) + symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True) largest_absolute_value = max(abs(element) for element in input_tensor) assert abs(symmetric_min) <= largest_absolute_value assert symmetric_max <= largest_absolute_value - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=100) + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False) assert np.allclose(asymmetric_min.item(), min(input_tensor)) assert np.allclose(asymmetric_max.item(), max(input_tensor)) def test_compute_encodings_50_percentile(self): - encoding_analyzer = PercentileEncodingAnalyzer((1,), 3) + encoding_analyzer = PercentileEncodingAnalyzer((1,), percentile = 50, num_bins = 3) input_tensor = torch.arange(start=0, end=1001, step=1, dtype=torch.float) encoding_analyzer.update_stats(input_tensor) bw = 8 - symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = bw, is_symmetric = True, percentile=50) + symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = bw, is_symmetric = True) stats = encoding_analyzer.observer.stats[0] cum_sum = torch.cumsum(stats.histogram, dim=0) @@ -599,7 +558,7 @@ def test_compute_encodings_50_percentile(self): assert symmetric_min == -1 * mid_value assert symmetric_max == mid_value - asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = bw, is_symmetric = False, percentile=50) + asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = bw, is_symmetric = False) assert asymmetric_min == 0 assert asymmetric_max == mid_value @@ -810,4 +769,4 @@ def test_select_best_delta_offset(self): # Find the best delta/offsets based on the candidates best_delta, best_offset = encoding_analyzer._select_best_candidates(deltas, offsets, histograms, 255) assert torch.equal(best_delta, torch.Tensor([1/255.0, 2/255.0]).view(2, 1)) - assert torch.equal(best_offset, torch.Tensor([-128, 0]).view(2, 1)) + assert torch.equal(best_offset, torch.Tensor([-128, 0]).view(2, 1)) \ No newline at end of file