diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py
index 7d99ebbad73..9b54842dd14 100644
--- a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py
+++ b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py
@@ -36,7 +36,7 @@
 # =============================================================================
 import torch
 import pytest
-from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls, CalibrationMethod
+from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls, CalibrationMethod, MseEncodingAnalyzer, PercentileEncodingAnalyzer, SqnrEncodingAnalyzer
 
 
 class TestEncodingAnalyzer():
@@ -162,3 +162,76 @@ def test_update_stats_incompatible_dimension(self):
         encoding_analyzer_1 = get_encoding_analyzer_cls(CalibrationMethod.MinMax, [3, 4])
         with pytest.raises(RuntimeError):
             encoding_analyzer_1.update_stats(torch.randn(2, 3, 5))
+
+@pytest.mark.skip('Tests skipped due to TDD')
+class TestHistogramEncodingAnalyzer:
+    @pytest.fixture
+    def histogram_based_encoding_analyzers(self, request):
+        min_max_shape = request.param[0]
+        num_bins = request.param[1]
+
+        mse_encoding_analyzer = MseEncodingAnalyzer(shape=min_max_shape, num_bins=num_bins)
+        percentile_encoding_analyzer = PercentileEncodingAnalyzer(percentile=99, shape=min_max_shape, num_bins=num_bins)
+        sqnr_encoding_analyzer = SqnrEncodingAnalyzer(shape=min_max_shape, num_bins=num_bins)
+        encoding_analyzer_list = [mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
+        yield encoding_analyzer_list
+
+    @pytest.mark.parametrize('num_bins', [-1, 0])
+    def test_invalid_bin_value(self, num_bins):
+        min_max_shape = (1,)
+        with pytest.raises(ValueError):
+            MseEncodingAnalyzer(num_bins=num_bins, shape=min_max_shape)
+
+        with pytest.raises(ValueError):
+            PercentileEncodingAnalyzer(num_bins=num_bins, percentile=99, shape=min_max_shape)
+
+        with pytest.raises(ValueError):
+            SqnrEncodingAnalyzer(num_bins=num_bins, shape=min_max_shape)
+
+    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
+    def test_merge_stats(self, histogram_based_encoding_analyzers):
+        for encoding_analyzer in histogram_based_encoding_analyzers:
+            input_tensor_1 = [2.0, 3.5, 4.2, 5.0]
+            encoding_analyzer.update_stats(input_tensor_1)
+            assert encoding_analyzer.stats.min == 2
+            assert encoding_analyzer.stats.max == 5
+            assert encoding_analyzer.stats.histogram == [1, 1, 2]
+            assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]
+
+            input_tensor_2 = [5.3, 6.4, 7.0, 8.0]
+            encoding_analyzer.update_stats(input_tensor_2)
+            assert encoding_analyzer.stats.min == 2
+            assert encoding_analyzer.stats.max == 8
+            assert encoding_analyzer.stats.histogram == [2, 3, 3]
+            assert encoding_analyzer.stats.bin_edges == [2, 4, 6, 8]
+
+    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
+    def test_merge_stats_same_tensor(self, histogram_based_encoding_analyzers):
+        for encoding_analyzer in histogram_based_encoding_analyzers:
+            input_tensor_1 = [2.0, 3.5, 4.2, 5.0]
+            encoding_analyzer.update_stats(input_tensor_1)
+            assert encoding_analyzer.stats.min == 2
+            assert encoding_analyzer.stats.max == 5
+            assert encoding_analyzer.stats.histogram == [1, 1, 2]
+            assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]
+
+            input_tensor_2 = [2.0, 3.5, 4.2, 5.0]
+            encoding_analyzer.update_stats(input_tensor_2)
+            assert encoding_analyzer.stats.min == 2
+            assert encoding_analyzer.stats.max == 5
+            assert encoding_analyzer.stats.histogram == [2, 2, 4]
+            assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]
+
+    @pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 5)], indirect=True)
+    def test_handle_outliers(self, histogram_based_encoding_analyzers):
+        for encoding_analyzer in histogram_based_encoding_analyzers:
+            input_tensor = torch.arange(start=0, end=100, step=0.5, dtype=torch.float)
+            data_type = encoding_analyzer.observer.stats.min.dtype
+            outliers = torch.tensor([torch.finfo(data_type).tiny, torch.finfo(data_type).max])
+            input_tensor = torch.cat((input_tensor, outliers), 0)
+            encoding_analyzer.update_stats(input_tensor)
+            assert encoding_analyzer.stats.min == 0
+            assert encoding_analyzer.stats.max == 99.5
+            assert encoding_analyzer.stats.histogram == [40, 40, 40, 40, 40]
+            assert encoding_analyzer.stats.bin_edges == [0, 19.9, 39.8, 59.7, 79.6, 99.5]
+
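Note (illustrative, not part of the patch): the merged values asserted in test_merge_stats are consistent with re-binning the combined observations over the widened min/max range, which is the merge behaviour the asserts imply rather than the analyzer's actual implementation. The sketch below reproduces the expected counts and edges with torch.histc; all names are local to the sketch.

# Sketch: reproduce the expected merged histogram from test_merge_stats by
# re-binning the combined data over the new observed range.
import torch

tensor_1 = torch.tensor([2.0, 3.5, 4.2, 5.0])
tensor_2 = torch.tensor([5.3, 6.4, 7.0, 8.0])
combined = torch.cat((tensor_1, tensor_2))

num_bins = 3
new_min, new_max = combined.min().item(), combined.max().item()  # 2.0, 8.0
counts = torch.histc(combined, bins=num_bins, min=new_min, max=new_max)
edges = torch.linspace(new_min, new_max, num_bins + 1)

print(counts)  # tensor([2., 3., 3.])      -> matches the asserted histogram
print(edges)   # tensor([2., 4., 6., 8.])  -> matches the asserted bin_edges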