Skip to content

Commit

Permalink
Updating tests
Browse files Browse the repository at this point in the history
Signed-off-by: Priyanka Dangi <quic_pdangi@quicinc.com>
  • Loading branch information
quic-pdangi authored and quic-akhobare committed Feb 2, 2024
1 parent efe38e7 commit f011a84
Showing 1 changed file with 3 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import pytest
import numpy as np
import random
from aimet_torch.experimental.v2.quantization.encoding_analyzer import MseEncodingAnalyzer, SqnrEncodingAnalyzer, PercentileEncodingAnalyzer, MinMaxEncodingAnalyzer, _HistogramObserver
from aimet_torch.experimental.v2.quantization.encoding_analyzer import SqnrEncodingAnalyzer, PercentileEncodingAnalyzer, MinMaxEncodingAnalyzer, _HistogramObserver

@pytest.fixture(autouse=True)
def set_seed():
Expand All @@ -50,10 +50,9 @@ class TestEncodingAnalyzer():
def encoding_analyzers(self):
min_max_encoding_analyzer = MinMaxEncodingAnalyzer((1,))
# TODO: Uncomment after implementation is complete
# mse_encoding_analyzer = MseEncodingAnalyzer()
# percentile_encoding_analyzer = PercentileEncodingAnalyzer()
# sqnr_encoding_analyzer = SqnrEncodingAnalyzer()
# encoding_analyzer_list = [min_max_encoding_analyzer, mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
# encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
encoding_analyzer_list = [min_max_encoding_analyzer]
yield encoding_analyzer_list

Expand Down Expand Up @@ -182,17 +181,14 @@ def histogram_based_encoding_analyzers(self, request):
min_max_shape = request.param[0]
num_bins = request.param[1]

mse_encoding_analyzer = MseEncodingAnalyzer(min_max_shape, num_bins)
percentile_encoding_analyzer = PercentileEncodingAnalyzer(min_max_shape, num_bins)
sqnr_encoding_analyzer = SqnrEncodingAnalyzer(min_max_shape, num_bins)
encoding_analyzer_list = [mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer]
encoding_analyzer_list = [percentile_encoding_analyzer, sqnr_encoding_analyzer]
yield encoding_analyzer_list

@pytest.mark.parametrize('num_bins', [-1, 0])
def test_invalid_bin_value(self, num_bins):
min_max_shape = (1,)
with pytest.raises(ValueError):
MseEncodingAnalyzer(min_max_shape, num_bins)

with pytest.raises(ValueError):
PercentileEncodingAnalyzer(min_max_shape, num_bins)
Expand Down Expand Up @@ -370,8 +366,6 @@ def test_collect_stats_multidimensional(self):
m = i % 4
assert torch.equal(histograms[i].min, x[j,k,m].min())
assert torch.equal(histograms[i].max, x[j,k,m].max())

assert False


def test_histogram_during_merging(self):
Expand Down

0 comments on commit f011a84

Please sign in to comment.