Skip to content

Commit

Permalink
Adding tests common to Histogram based Encoding Analyzers (#2667)
Browse files Browse the repository at this point in the history
* Adding tests common to Histogram based Encoding Analyzers
---------

Signed-off-by: Priyanka Dangi <quic_pdangi@quicinc.com>
  • Loading branch information
quic-pdangi authored Jan 25, 2024
1 parent 5c4e9b2 commit 6717a30
Showing 1 changed file with 74 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
# =============================================================================
import torch
import pytest
from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls, CalibrationMethod
from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls, CalibrationMethod, MseEncodingAnalyzer, PercentileEncodingAnalyzer, SqnrEncodingAnalyzer


class TestEncodingAnalyzer():
Expand Down Expand Up @@ -162,3 +162,76 @@ def test_update_stats_incompatible_dimension(self):
encoding_analyzer_1 = get_encoding_analyzer_cls(CalibrationMethod.MinMax, [3, 4])
with pytest.raises(RuntimeError):
encoding_analyzer_1.update_stats(torch.randn(2, 3, 5))

@pytest.mark.skip('Tests skipped due to TDD')
class TestHistogramEncodingAnalyzer:
@pytest.fixture
def histogram_based_encoding_analyzers(self, request):
min_max_shape = request.param[0]
num_bins = request.param[1]

mse_encoding_analyzer = MseEncodingAnalyzer(shape= min_max_shape, num_bins = num_bins)
percentile_encoding_analyzer = PercentileEncodingAnalyzer(percentile=99, shape= min_max_shape, num_bins = num_bins)
sqnr_encoding_analyzer = SqnrEncodingAnalyzer(shape= min_max_shape, num_bins = num_bins)
encoding_analyzer_list = [mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
yield encoding_analyzer_list

@pytest.mark.parametrize('num_bins', [-1, 0])
def test_invalid_bin_value(self, num_bins):
min_max_shape = (1,)
with pytest.raises(ValueError):
MseEncodingAnalyzer(num_bins=num_bins)

with pytest.raises(ValueError):
PercentileEncodingAnalyzer(num_bins = num_bins, percentile=99, shape=min_max_shape)

with pytest.raises(ValueError):
SqnrEncodingAnalyzer(num_bins = num_bins, shape=min_max_shape)

@pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
def test_merge_stats(self, histogram_based_encoding_analyzers):
for encoding_analyzer in histogram_based_encoding_analyzers:
input_tensor_1 = [2.0, 3.5, 4.2, 5.0]
encoding_analyzer.update_stats(input_tensor_1)
assert encoding_analyzer.stats.min == 2
assert encoding_analyzer.stats.max == 5
assert encoding_analyzer.stats.histogram == [1, 1, 2]
assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]

input_tensor_2 = [5.3, 6.4, 7.0, 8.0]
encoding_analyzer.update_stats(input_tensor_2)
assert encoding_analyzer.stats.min == 2
assert encoding_analyzer.stats.max == 8
assert encoding_analyzer.stats.histogram == [2, 3, 3]
assert encoding_analyzer.stats.bin_edges == [2, 4, 6, 8]

@pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 3)], indirect=True)
def test_merge_stats_same_tensor(self, histogram_based_encoding_analyzers):
for encoding_analyzer in histogram_based_encoding_analyzers:
input_tensor_1 = [2.0, 3.5, 4.2, 5.0]
encoding_analyzer.update_stats(input_tensor_1)
assert encoding_analyzer.stats.min == 2
assert encoding_analyzer.stats.max == 5
assert encoding_analyzer.stats.histogram == [1, 1, 2]
assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]

input_tensor_2 = [2.0, 3.5, 4.2, 5.0]
encoding_analyzer.update_stats(input_tensor_2)
assert encoding_analyzer.stats.min == 2
assert encoding_analyzer.stats.max == 5
assert encoding_analyzer.stats.histogram == [2, 2, 4]
assert encoding_analyzer.stats.bin_edges == [2, 3, 4, 5]

@pytest.mark.parametrize("histogram_based_encoding_analyzers", [((1,), 5)], indirect=True)
def test_handle_outliers(self, histogram_based_encoding_analyzers):
for encoding_analyzer in histogram_based_encoding_analyzers:
input_tensor = torch.arange(start=0, end=100, step=0.5, dtype=torch.float)
data_type = encoding_analyzer.observer.stats.min.dtype
outliers = torch.tensor([torch.finfo(data_type).tiny, torch.finfo(data_type).max])
input_tensor = torch.cat((input_tensor, outliers), 1)
encoding_analyzer.update_stats(input_tensor)
assert encoding_analyzer.stats.min == 0
assert encoding_analyzer.stats.max == 99.5
assert encoding_analyzer.stats.histogram == [40, 40, 40, 40, 40]
assert encoding_analyzer.stats.bin_edge == [0, 19.9, 39.8, 59.7, 79.6, 99.5]

0 comments on commit 6717a30

Please sign in to comment.