diff --git a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py
index d1f32d4fcdf..b9d91ae8131 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/experimental/v2/quantization/encoding_analyzer.py
@@ -41,9 +41,11 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TypeVar, Generic, Tuple, Optional
+from typing import TypeVar, Generic, Tuple, Optional, List
+import itertools
+import numpy as np
 import torch
-from aimet_torch.experimental.v2.utils import reduce, StatisticsNotFoundError
+from aimet_torch.experimental.v2.utils import reduce, StatisticsNotFoundError, _is_expandable
 
 @dataclass
@@ -51,6 +53,7 @@ class _MinMaxRange:
     min: Optional[torch.Tensor] = None
     max: Optional[torch.Tensor] = None
 
+@dataclass
 class _Histogram:
     histogram: torch.Tensor = None
     bin_edges: torch.Tensor = None
@@ -123,29 +126,108 @@ def reset_stats(self):
 
     def get_stats(self) -> _MinMaxRange:
         return self.stats
 
+
 class _HistogramObserver(_Observer[_Histogram]):
     """
     Observer for Histogram based calibration techniques (percentile, MSE)
     """
     def __init__(self, shape: tuple, num_bins: int):
         super().__init__(shape)
-        self.stats = _Histogram()
         self.num_bins = num_bins
+        self.num_histograms = np.prod(self.shape)
+        self.stats = []
+        for _ in range(self.num_histograms):
+            self.stats.append(_Histogram())
 
-    @torch.no_grad()
-    def collect_stats(self, input_tensor: torch.Tensor) -> _Histogram:
-        # TODO
-        raise NotImplementedError
 
     @torch.no_grad()
-    def merge_stats(self, stats: _Histogram):
-        # TODO
-        raise NotImplementedError
+    def collect_stats(self, input_tensor: torch.Tensor) -> List[_Histogram]:
+        if not _is_expandable(self.shape, input_tensor.shape):
+            raise RuntimeError(f"Shape {self.shape} is incompatible with input of shape {input_tensor.shape}")
+
+        hist_stats = []
+        input_shape = tuple(input_tensor.shape)
+        histogram_shape = self.shape
+
+        padded_histogram_shape = (
+            *itertools.repeat(1, len(input_shape) - len(histogram_shape)),
+            *histogram_shape
+        )
+
+        for hist_num in range(self.num_histograms):
+            hist_input = input_tensor
+
+            for axis, dim in enumerate(padded_histogram_shape):
+                if dim == 1:
+                    continue
+                # number of elements covered by one index step along the current axis,
+                # e.g. W*C, C, or 1 for input_shape [H, W, C]
+                numel = np.prod(padded_histogram_shape[axis+1:], dtype=int)
+                # index at which hist_input is sliced along the current axis
+                index = (hist_num // numel) % dim
+                hist_input = hist_input.select(axis, index).unsqueeze(axis)
+
+            histogram, bin_edges = torch.histogram(hist_input, self.num_bins)
+            hist_stats.append(_Histogram(histogram, bin_edges, hist_input.min(), hist_input.max()))
+
+        return hist_stats
+
+    def _get_bin_num(self, bin_width: float, curr_min, data):
+        if bin_width:
+            return min(int((data - curr_min) / bin_width), self.num_bins - 1)
+        return 0  # zero bin width (degenerate histogram): everything falls in bin 0
+
+    # pylint: disable=arguments-differ
+    # pylint: disable=too-many-locals
+    @torch.no_grad()
+    def merge_stats(self, new_stats_list: List[_Histogram], input_tensor: torch.Tensor):
+        if self.stats[0].histogram is None:
+            self.stats = new_stats_list
+            return
+
+        hist_inputs = torch.reshape(input_tensor, (len(new_stats_list), -1))
+
+        for index, new_stats in enumerate(new_stats_list):
+            curr_stats = self.stats[index]
+            curr_input = hist_inputs[index]
+
+            updated_min = min(new_stats.min, curr_stats.min)
+            updated_max = max(new_stats.max, curr_stats.max)
+
+            # if the current histogram can capture new_stats within its range
+            if updated_min == curr_stats.min and updated_max == curr_stats.max:
+                histogram_updates = curr_stats.histogram
+            else:
+                dest_bin_width = (updated_max - updated_min) / self.num_bins
+                src_bin_width = (curr_stats.max - curr_stats.min) / self.num_bins
+                histogram_updates = torch.zeros(self.num_bins)
+
+                for curr_bin in range(self.num_bins):
+                    curr_hist = curr_stats.histogram[curr_bin]
+                    if curr_hist:
+                        src_bin_start = curr_stats.min + src_bin_width * curr_bin
+                        bin_index = self._get_bin_num(dest_bin_width, updated_min, src_bin_start)
+                        dest_bin_end = updated_min + dest_bin_width * (bin_index + 1)
+
+                        # split curr_hist if the values in the source bin cannot neatly fold into the dest bin
+                        split_hist_value = torch.round(((dest_bin_end - src_bin_start) / src_bin_width) * curr_hist)
+                        dest_bin_updated = min(split_hist_value, curr_hist)
+                        # update the appropriate bin with either the full or the split curr_hist value
+                        histogram_updates[bin_index] += dest_bin_updated
+                        # if curr_hist was split, update the other bin that the remaining values fall into
+                        if dest_bin_updated < curr_hist:
+                            bin_index = self._get_bin_num(dest_bin_width, updated_min, src_bin_start + dest_bin_width)
+                            histogram_updates[bin_index] += curr_hist - dest_bin_updated
+            # create a histogram of the new input over the full merged range
+            expanded_histogram, expanded_bin_edges = torch.histogram(curr_input, self.num_bins, range=(updated_min.item(), updated_max.item()))
+            expanded_histogram += histogram_updates
+            self.stats[index] = _Histogram(expanded_histogram, expanded_bin_edges, updated_min, updated_max)
 
     def reset_stats(self):
-        self.stats = _Histogram()
+        self.stats = []
+        for _ in range(self.num_histograms):
+            self.stats.append(_Histogram())
 
-    def get_stats(self) -> _Histogram:
+    def get_stats(self) -> List[_Histogram]:
         return self.stats
 
 class EncodingAnalyzer(Generic[_Statistics], ABC):
@@ -178,7 +260,7 @@ class MinMaxEncodingAnalyzer(EncodingAnalyzer[_MinMaxRange]):
     """
     Encoding Analyzer for Min-Max calibration technique
     """
-    def __init__(self, shape):
+    def __init__(self, shape: tuple):
        observer = _MinMaxObserver(shape)
         super().__init__(observer)
 
@@ -226,11 +308,20 @@ class PercentileEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
     Encoding Analyzer for Percentile calibration technique
     """
     def __init__(self, shape: tuple, num_bins: int = 2048):
+        if num_bins <= 0:
+            raise ValueError('Number of bins cannot be less than or equal to 0.')
         observer = _HistogramObserver(shape=shape, num_bins=num_bins)
         super().__init__(observer)
 
     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
+    def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
+        new_stats = self.observer.collect_stats(input_tensor)
+        self.observer.merge_stats(new_stats, input_tensor)
+        return new_stats
+
+    # pylint: disable=arguments-differ
+    @torch.no_grad()
+    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool, percentile: float)\
         -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         # TODO
         raise NotImplementedError
@@ -240,22 +331,16 @@ class SqnrEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
     """
     Encoding Analyzer for SQNR Calibration technique
     """
     def __init__(self, shape: tuple, num_bins: int = 2048):
+        if num_bins <= 0:
+            raise ValueError('Number of bins cannot be less than or equal to 0.')
         observer = _HistogramObserver(shape=shape, num_bins=num_bins)
         super().__init__(observer)
 
     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
-        -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        # TODO
-        raise NotImplementedError
-
-class MseEncodingAnalyzer(EncodingAnalyzer[_Histogram]):
-    """
-    Encoding Analyzer for Mean Square Error (MSE) Calibration technique
-    """
-    def __init__(self, shape: tuple, num_bins: int = 2048):
-        observer = _HistogramObserver(shape=shape, num_bins=num_bins)
-        super().__init__(observer)
+    def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
+        new_stats = self.observer.collect_stats(input_tensor)
+        self.observer.merge_stats(new_stats, input_tensor)
+        return new_stats
 
     @torch.no_grad()
     def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
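For intuition, the index arithmetic that the new collect_stats uses to carve one slice per histogram can be distilled into a standalone sketch. This is illustrative only and not part of the diff; slice_for_histogram is a hypothetical helper name, and only the select/unsqueeze walk mirrors the code above.

import itertools
import numpy as np
import torch

def slice_for_histogram(x: torch.Tensor, hist_shape: tuple, hist_num: int) -> torch.Tensor:
    # Pad hist_shape with leading 1s so it lines up with x's dimensions,
    # as _HistogramObserver.collect_stats does with padded_histogram_shape.
    padded = (*itertools.repeat(1, x.dim() - len(hist_shape)), *hist_shape)
    out = x
    for axis, dim in enumerate(padded):
        if dim == 1:
            continue
        # stride of this axis in flat histogram-index space
        numel = np.prod(padded[axis + 1:], dtype=int)
        out = out.select(axis, (hist_num // numel) % dim).unsqueeze(axis)
    return out

x = torch.arange(24, dtype=torch.float).view(2, 3, 4)
# For hist_shape (3, 4), flat histogram index 7 maps to row 7 // 4 = 1 and
# column 7 % 4 = 3, i.e. the slice x[:, 1, 3] (matching the tests below).
assert torch.equal(slice_for_histogram(x, (3, 4), 7).flatten(), x[:, 1, 3])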
diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py
index e2a4b462cea..12eb606312a 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py
@@ -38,7 +38,7 @@
 import pytest
 import numpy as np
 import random
-from aimet_torch.experimental.v2.quantization.encoding_analyzer import MseEncodingAnalyzer, SqnrEncodingAnalyzer, PercentileEncodingAnalyzer, MinMaxEncodingAnalyzer
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import SqnrEncodingAnalyzer, PercentileEncodingAnalyzer, MinMaxEncodingAnalyzer, _HistogramObserver
 
 @pytest.fixture(autouse=True)
 def set_seed():
@@ -50,10 +50,9 @@ class TestEncodingAnalyzer():
     def encoding_analyzers(self):
         min_max_encoding_analyzer = MinMaxEncodingAnalyzer((1,))
         # TODO: Uncomment after implementation is complete
-        # mse_encoding_analyzer = MseEncodingAnalyzer()
-        # percentile_encoding_analyzer = PercentileEncodingAnalyzer(percentile=99)
+        # percentile_encoding_analyzer = PercentileEncodingAnalyzer()
         # sqnr_encoding_analyzer = SqnrEncodingAnalyzer()
-        # encoding_analyzer_list = [min_max_encoding_analyzer, mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
+        # encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
         encoding_analyzer_list = [min_max_encoding_analyzer]
         yield encoding_analyzer_list
 
@@ -175,92 +174,245 @@ def test_update_stats_incompatible_dimension(self):
         with pytest.raises(RuntimeError):
             encoding_analyzer.update_stats(torch.randn(2, 3, 5))
 
-@pytest.mark.skip('Tests skipped due to TDD')
+
 class TestHistogramEncodingAnalyzer:
     @pytest.fixture
     def histogram_based_encoding_analyzers(self, request):
         min_max_shape = request.param[0]
         num_bins = request.param[1]
-        mse_encoding_analyzer = MseEncodingAnalyzer(shape= min_max_shape, num_bins = num_bins)
-        percentile_encoding_analyzer = PercentileEncodingAnalyzer(percentile=99, shape= min_max_shape, num_bins = num_bins)
-        sqnr_encoding_analyzer = SqnrEncodingAnalyzer(shape= min_max_shape, num_bins = num_bins)
-        encoding_analyzer_list = [mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
+        percentile_encoding_analyzer = PercentileEncodingAnalyzer(min_max_shape, num_bins)
+        sqnr_encoding_analyzer = SqnrEncodingAnalyzer(min_max_shape, num_bins)
+        encoding_analyzer_list = [percentile_encoding_analyzer, sqnr_encoding_analyzer]
         yield encoding_analyzer_list
 
     @pytest.mark.parametrize('num_bins', [-1, 0])
     def test_invalid_bin_value(self, num_bins):
         min_max_shape = (1,)
-        with pytest.raises(ValueError):
-            MseEncodingAnalyzer(num_bins=num_bins)
         with pytest.raises(ValueError):
-            PercentileEncodingAnalyzer(num_bins = num_bins, percentile=99, shape=min_max_shape)
+            PercentileEncodingAnalyzer(min_max_shape, num_bins)
 
         with pytest.raises(ValueError):
-            SqnrEncodingAnalyzer(num_bins = num_bins, shape=min_max_shape)
+            SqnrEncodingAnalyzer(min_max_shape, num_bins)
 
     @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
-    def test_merge_stats(self, histogram_based_encoding_analyzers):
+    def test_merge_stats_resize_histogram(self, histogram_based_encoding_analyzers):
         for encoding_analyzer in histogram_based_encoding_analyzers:
-            input_tensor_1 = [2.0, 3.5, 4.2, 5.0]
+            input_tensor_1 = torch.tensor([2.0, 3.5, 4.2, 5.0])
             encoding_analyzer.update_stats(input_tensor_1)
-            assert encoding_analyzer.stats.min == 2
-            assert encoding_analyzer.stats.max == 5
-            assert encoding_analyzer.stats.histogram == [1, 1, 2]
-            assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 2
+            assert encoding_analyzer.observer.stats[0].max == 5
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([1.0, 1.0, 2.0])))
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([2.0, 3.0, 4.0, 5.0])))
 
-            input_tensor_2 = [5.3, 6.4, 7.0, 8.0]
+            # update max
+            input_tensor_2 = torch.tensor([5.3, 6.4, 7.0, 8.0])
             encoding_analyzer.update_stats(input_tensor_2)
-            assert encoding_analyzer.stats.min == 2
-            assert encoding_analyzer.stats.max == 8
-            assert encoding_analyzer.stats.histogram == [2, 3, 3]
-            assert encoding_analyzer.stats.bin_edges == [2, 4, 6, 8]
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 2
+            assert encoding_analyzer.observer.stats[0].max == 8
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([2.0, 3.0, 3.0])))
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([2.0, 4.0, 6.0, 8.0])))
+
+            # update min
+            input_tensor_3 = torch.tensor([-4.2, 0, 2.3, 4.5])
+            encoding_analyzer.update_stats(input_tensor_3)
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == -4.2
+            assert encoding_analyzer.observer.stats[0].max == 8
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([1, 4, 7])))
+            assert torch.allclose(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([-4.2000, -0.133333, 3.933333, 8.000]))
 
     @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
-    def test_merge_stats_same_tensor(self, histogram_based_encoding_analyzers):
+    def test_merge_stats_resize_histogram_with_ambiguous_bins(self, histogram_based_encoding_analyzers):
         for encoding_analyzer in histogram_based_encoding_analyzers:
-            input_tensor_1 = [2.0, 3.5, 4.2, 5.0]
+            input_tensor_1 = torch.tensor([-4.2, 2.4, 7.0, 8.0])
             encoding_analyzer.update_stats(input_tensor_1)
-            assert encoding_analyzer.stats.min == 2
-            assert encoding_analyzer.stats.max == 5
-            assert encoding_analyzer.stats.histogram == [1, 1, 2]
-            assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]
-
-            input_tensor_2 = [2.0, 3.5, 4.2, 5.0]
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == -4.2
+            assert encoding_analyzer.observer.stats[0].max == 8
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([1.0, 1.0, 2.0])))
+            assert torch.allclose(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([-4.2000, -0.133333, 3.933333, 8.000]))
+
+            input_tensor_2 = torch.tensor([-6.7, -2.5, 7.2, 10.3])
+            # hist is [2, 0, 2] for this tensor only
+            encoding_analyzer.update_stats(input_tensor_2)
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == -6.7
+            assert encoding_analyzer.observer.stats[0].max == 10.3
+            '''
+            Ambiguity arises when mapping the 1st and 3rd bins, e.g. values in [-4.2, -0.133) could map to [-6.7, -1.033) or [-1.033, 4.633)
+            '''
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([3.0, 1.0, 4.0])))
+            assert torch.allclose(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([-6.7000, -1.03333, 4.63333, 10.3000]))
 
+    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
+    def test_merge_stats_resize_histogram_with_bin_splitting(self, histogram_based_encoding_analyzers):
+        for encoding_analyzer in histogram_based_encoding_analyzers:
+            input_tensor_1 = torch.tensor([1, 7, 5.3, 6, 5.7, 6.8, 6.2, 2.8, 3.9])
+            encoding_analyzer.update_stats(input_tensor_1)
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 1
+            assert encoding_analyzer.observer.stats[0].max == 7
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([2.0, 1.0, 6.0])))
+            assert torch.allclose(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([1, 3, 5, 7]))
+
+            input_tensor_2 = torch.tensor([0, 9, 7.8, 2.5, 4.6, 6.2, 8.8])
             encoding_analyzer.update_stats(input_tensor_2)
-            assert encoding_analyzer.stats.min == 2
-            assert encoding_analyzer.stats.max == 5
-            assert encoding_analyzer.stats.histogram == [2, 2, 4]
-            assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 0
+            assert encoding_analyzer.observer.stats[0].max == 9
+            # the 6 values in the source histogram's 3rd bucket are split in half between the destination's 2nd and 3rd buckets
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([4.0, 5.0, 7.0])))
+            assert torch.allclose(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([0, 3, 6, 9]))
 
-    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 5)], indirect=True)
-    def test_handle_outliers(self, histogram_based_encoding_analyzers):
+    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 1)], indirect=True)
+    def test_histogram_with_one_bin(self, histogram_based_encoding_analyzers):
+        for encoding_analyzer in histogram_based_encoding_analyzers:
+            input_tensor_1 = torch.tensor([1, 7, 5.3, 6, 5.7, 6.8, 6.2, 2.8, 3.9])
+            encoding_analyzer.update_stats(input_tensor_1)
+            assert encoding_analyzer.observer.stats[0].min == 1
+            assert encoding_analyzer.observer.stats[0].max == 7
+
+    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
+    def test_merge_stats_without_resizing(self, histogram_based_encoding_analyzers):
         for encoding_analyzer in histogram_based_encoding_analyzers:
-            input_tensor = torch.arange(start=0, end=100, step=0.5, dtype=torch.float)
-            data_type = encoding_analyzer.observer.stats.min.dtype
-            outliers = torch.tensor([torch.finfo(data_type).tiny, torch.finfo(data_type).max])
-            input_tensor = torch.cat((input_tensor, outliers), 1)
-            encoding_analyzer.update_stats(input_tensor)
-            assert encoding_analyzer.stats.min == 0
-            assert encoding_analyzer.stats.max == 99.5
-            assert encoding_analyzer.stats.histogram == [40, 40, 40, 40, 40]
-            assert encoding_analyzer.stats.bin_edge == [0, 19.9, 39.8, 59.7, 79.6, 99.5]
+            input_tensor_1 = torch.tensor([2.0, 3.5, 4.2, 5.0])
+            encoding_analyzer.update_stats(input_tensor_1)
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 2
+            assert encoding_analyzer.observer.stats[0].max == 5
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([1, 1, 2])))
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([2, 3, 4, 5])))
+
+            # same min, max values
+            input_tensor_2 = torch.tensor([2.0, 3.3, 4.8, 5])
+            encoding_analyzer.update_stats(input_tensor_2)
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 2
+            assert encoding_analyzer.observer.stats[0].max == 5
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([2, 2, 4])))
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([2, 3, 4, 5])))
+
+            # min and max within current range
+            input_tensor_3 = torch.tensor([3.1, 3.3, 3.7, 3.9])
+            encoding_analyzer.update_stats(input_tensor_3)
+            assert len(encoding_analyzer.observer.stats) == 1
+            assert encoding_analyzer.observer.stats[0].min == 2
+            assert encoding_analyzer.observer.stats[0].max == 5
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].histogram, torch.Tensor([2, 6, 4])))
+            assert torch.all(torch.eq(encoding_analyzer.observer.stats[0].bin_edges, torch.Tensor([2, 3, 4, 5])))
+
+    def test_collect_stats_multidimensional(self):
+        x = torch.arange(24, dtype=torch.float).view(2, 3, 4)
+        shape = (4,)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(4):
+            assert torch.equal(histograms[i].min, x[:,:,i].min())
+            assert torch.equal(histograms[i].max, x[:,:,i].max())
+
+        shape = (3, 1)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(3):
+            assert torch.equal(histograms[i].min, x[:,i,:].min())
+            assert torch.equal(histograms[i].max, x[:,i,:].max())
+
+        shape = (2, 1, 1)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(2):
+            assert torch.equal(histograms[i].min, x[i,:,:].min())
+            assert torch.equal(histograms[i].max, x[i,:,:].max())
+
+
+        shape = (3, 4)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(12):
+            j = i // 4
+            k = i % 4
+            assert torch.equal(histograms[i].min, x[:,j,k].min())
+            assert torch.equal(histograms[i].max, x[:,j,k].max())
+
+        shape = (2, 3, 1)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(6):
+            j = i // 3
+            k = i % 3
+            assert torch.equal(histograms[i].min, x[j,k,:].min())
+            assert torch.equal(histograms[i].max, x[j,k,:].max())
+
+        shape = (2, 1, 4)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(8):
+            j = i // 4
+            k = i % 4
+            assert torch.equal(histograms[i].min, x[j,:,k].min())
+            assert torch.equal(histograms[i].max, x[j,:,k].max())
+
+        shape = (2, 3, 4)
+        observer = _HistogramObserver(shape, num_bins=5)
+        histograms = observer.collect_stats(x)
+        for i in range(24):
+            j = i // 12
+            k = (i // 4) % 3
+            m = i % 4
+            assert torch.equal(histograms[i].min, x[j,k,m].min())
+            assert torch.equal(histograms[i].max, x[j,k,m].max())
+
+
+    def test_histogram_during_merging(self):
+        observer = _HistogramObserver((1,), num_bins=10)
+        input = torch.arange(-50, 51, dtype=torch.float)
+        old_histogram = observer.collect_stats(input)
+        observer.merge_stats(old_histogram, input)
+
+        input = torch.arange(-50, 51, dtype=torch.float) * 1.5
+        new_histogram = observer.collect_stats(input)
+        observer.merge_stats(new_histogram, input)
+
+        merged_histogram = observer.stats[0]
+        assert list(merged_histogram.histogram) == [10, 15, 25, 25, 25, 25, 25, 26, 15, 11]
+        assert list(merged_histogram.bin_edges) == [-75, -60, -45, -30, -15, 0, 15, 30, 45, 60, 75]
+
+        # (old_histogram)
+        #
+        #    10    10    10    10    10    10    10    10    10    11
+        # |-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|
+        # -50 |  -40   -30   -20 |  -10    0     10 |  20    30    40 |  50
+        #     |                  |                   |                 |
+        #     |                  |                   |                 |
+        #     |                  |                   |                 |
+        # (+5)|   (+15)  | (+15) |   (+15)  |  (+15) |  (+15)  | (+16) | (+5)
+        #   10    10 |   10   |   10   |   10   |   10   |   10   |   10    10    11
+        # |--------|--------|--------|--------|--------|--------|--------|--------|--------|--------|
+        # -75      -60      -45      -30      -15       0        15       30       45       60       75
+        #
+        # (new_histogram)
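The diagram above illustrates the proportional bin-splitting that merge_stats performs when the observed range grows. A minimal standalone sketch of that rescaling step, assuming equal-width bins (rescale_histogram is a hypothetical helper, not part of this diff):

import torch

def rescale_histogram(counts, src_min, src_max, dest_min, dest_max, num_bins):
    # Redistribute `counts` binned over [src_min, src_max] onto num_bins
    # equal-width bins spanning the wider range [dest_min, dest_max].
    src_width = (src_max - src_min) / num_bins
    dest_width = (dest_max - dest_min) / num_bins
    out = torch.zeros(num_bins)
    for b, count in enumerate(counts):
        if not count:
            continue
        src_start = src_min + b * src_width
        # destination bin holding the start of this source bin
        i = min(int((src_start - dest_min) / dest_width), num_bins - 1)
        dest_end = dest_min + dest_width * (i + 1)
        # portion of the source bin that fits before dest_end
        head = min(torch.round((dest_end - src_start) / src_width * count), count)
        out[i] += head
        if head < count:
            # remainder spills into the destination bin one step to the right
            i = min(int((src_start + dest_width - dest_min) / dest_width), num_bins - 1)
            out[i] += count - head
    return out

old = torch.tensor([10., 10, 10, 10, 10, 10, 10, 10, 10, 11])
# Reproduces the (+5)(+15)...(+16)(+5) row of the diagram:
print(rescale_histogram(old, -50., 50., -75., 75., 10))
# tensor([ 0.,  5., 15., 15., 15., 15., 15., 16.,  5.,  0.])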
+
+@pytest.mark.skip('Tests skipped due to TDD')
 class TestPercentileEncodingAnalyzer():
     @pytest.mark.parametrize("percentile_value", [-1, 49, 5, 101])
     def test_invalid_percentile_value(self, percentile_value):
         with pytest.raises(ValueError):
-            PercentileEncodingAnalyzer(percentile=percentile_value)
+            encoding_analyzer = PercentileEncodingAnalyzer()
+            encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=percentile_value)
+
     def test_compute_encodings_asymmetric_normalized(self):
-        encoding_analyzer = PercentileEncodingAnalyzer(percentile=99)
+        encoding_analyzer = PercentileEncodingAnalyzer()
         mean = std_dev = 2
         input_tensor = np.random.normal(mean, std_dev, size=(10000))
         encoding_analyzer.update_stats(input_tensor)
 
-        asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False)
+        asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=99)
 
         # 99.7% of values in a normal distribution are within 3 standard deviations of the mean
         assert asymmetric_min < mean - std_dev * 2
         assert asymmetric_min > mean - std_dev * 3
@@ -269,43 +421,43 @@ def test_compute_encodings_asymmetric_normalized(self):
         assert asymmetric_max < mean + std_dev * 3
 
     def test_compute_encodings_asymmetric_sequential(self):
-        encoding_analyzer = PercentileEncodingAnalyzer(num_bins = 500, percentile=99)
+        encoding_analyzer = PercentileEncodingAnalyzer(num_bins = 500)
         input_tensor = torch.arange(start=0, end=1001, step=1, dtype=torch.float)
         encoding_analyzer.update_stats(input_tensor)
 
-        asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False)
+        asymmetric_min, asymmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = False, percentile=99)
 
         # encoding max is the histogram bin edge which contains the 99th percentile value (990.02)
         assert asymmetric_min == 0
         assert asymmetric_max == 990.0
 
     def test_compute_encodings_signed_symmetric_normalized(self):
-        encoding_analyzer = PercentileEncodingAnalyzer(percentile=99)
+        encoding_analyzer = PercentileEncodingAnalyzer()
         mean = std_dev = 2
         input_tensor = np.random.normal(mean, std_dev, size=(10000))
         encoding_analyzer.update_stats(input_tensor)
 
-        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True)
+        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile=99)
 
         largest_absolute_value = max(abs(element) for element in input_tensor)
         assert symmetric_min > -largest_absolute_value
         assert symmetric_max < largest_absolute_value
 
     def test_compute_encodings_signed_symmetric_sequential(self):
-        encoding_analyzer = PercentileEncodingAnalyzer(num_bins = 500, percentile=99)
+        encoding_analyzer = PercentileEncodingAnalyzer(num_bins = 500)
         input_tensor = torch.arange(start=0, end=1001, step=1, dtype=torch.float)
         encoding_analyzer.update_stats(input_tensor)
 
-        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True)
+        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile=99)
 
         assert symmetric_min == -990.0
         assert symmetric_max == 990.0
 
     def test_compute_encodings_100_percentile(self):
-        encoding_analyzer = PercentileEncodingAnalyzer(percentile=100)
+        encoding_analyzer = PercentileEncodingAnalyzer()
         mean = std_dev = 2
         input_tensor = np.random.normal(mean, std_dev, size=(10000))
         encoding_analyzer.update_stats(input_tensor)
 
-        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True)
+        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = 8, is_symmetric = True, percentile=100)
 
         largest_absolute_value = max(abs(element) for element in input_tensor)
         assert abs(symmetric_min) <= largest_absolute_value
         assert symmetric_max <= largest_absolute_value
@@ -315,14 +467,14 @@ def test_compute_encodings_100_percentile(self):
         assert asymmetric_max == max(input_tensor)
 
     def test_compute_encodings_50_percentile(self):
-        encoding_analyzer = PercentileEncodingAnalyzer(percentile=50)
+        encoding_analyzer = PercentileEncodingAnalyzer()
         input_tensor = torch.arange(start=0, end=1001, step=1, dtype=torch.float)
         encoding_analyzer.update_stats(input_tensor)
 
         bw = 8
         updated_min = torch.finfo(asymmetric_min.dtype).tiny * (2 ** (bw - 1))
         updated_max = torch.finfo(asymmetric_min.dtype).tiny * ((2 **(bw - 1)) - 1)
-        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = bw, is_symmetric = True)
+        symmetric_min, symmetric_max = encoding_analyzer.compute_encodings(bitwidth = bw, is_symmetric = True, percentile=50)
 
         assert symmetric_min == min(-updated_min, -updated_max)
         assert symmetric_max == max(updated_min, updated_max)
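Taken together, the calibration flow these tests exercise looks roughly like the sketch below. It is based only on the code in this diff; compute_encodings_from_stats is still a TODO here, so only histogram collection and merging are shown.

import torch
from aimet_torch.experimental.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer

analyzer = PercentileEncodingAnalyzer(shape=(1,), num_bins=10)

# Each update bins the batch and merges it into the running histogram,
# widening the range (and rescaling the old bins) when new extremes appear.
analyzer.update_stats(torch.arange(-50, 51, dtype=torch.float))
analyzer.update_stats(torch.arange(-50, 51, dtype=torch.float) * 1.5)

hist = analyzer.observer.get_stats()[0]
print(hist.bin_edges.tolist())  # [-75.0, -60.0, ..., 60.0, 75.0]
print(hist.histogram.tolist())  # [10.0, 15.0, 25.0, ..., 15.0, 11.0]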