From f011a8426c7e0f08c4fd52b5e138fd916f88bba9 Mon Sep 17 00:00:00 2001 From: Priyanka Dangi Date: Thu, 1 Feb 2024 16:09:11 -0800 Subject: [PATCH] Updating tests Signed-off-by: Priyanka Dangi --- .../python/experimental/v2/test_encoding_analyzer.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py index 12f1420a45b..12eb606312a 100644 --- a/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py +++ b/TrainingExtensions/torch/test/python/experimental/v2/test_encoding_analyzer.py @@ -38,7 +38,7 @@ import pytest import numpy as np import random -from aimet_torch.experimental.v2.quantization.encoding_analyzer import MseEncodingAnalyzer, SqnrEncodingAnalyzer, PercentileEncodingAnalyzer, MinMaxEncodingAnalyzer, _HistogramObserver +from aimet_torch.experimental.v2.quantization.encoding_analyzer import SqnrEncodingAnalyzer, PercentileEncodingAnalyzer, MinMaxEncodingAnalyzer, _HistogramObserver @pytest.fixture(autouse=True) def set_seed(): @@ -50,10 +50,9 @@ class TestEncodingAnalyzer(): def encoding_analyzers(self): min_max_encoding_analyzer = MinMaxEncodingAnalyzer((1,)) # TODO: Uncomment after implementation is complete - # mse_encoding_analyzer = MseEncodingAnalyzer() # percentile_encoding_analyzer = PercentileEncodingAnalyzer() # sqnr_encoding_analyzer = SqnrEncodingAnalyzer() - # encoding_analyzer_list = [min_max_encoding_analyzer, mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer] + # encoding_analyzer_list = [min_max_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer] encoding_analyzer_list = [min_max_encoding_analyzer] yield encoding_analyzer_list @@ -182,17 +181,14 @@ def histogram_based_encoding_analyzers(self, request): min_max_shape = request.param[0] num_bins = request.param[1] - mse_encoding_analyzer = MseEncodingAnalyzer(min_max_shape, num_bins) percentile_encoding_analyzer = PercentileEncodingAnalyzer(min_max_shape, num_bins) sqnr_encoding_analyzer = SqnrEncodingAnalyzer(min_max_shape, num_bins) - encoding_analyzer_list = [mse_encoding_analyzer, percentile_encoding_analyzer, sqnr_encoding_analyzer] + encoding_analyzer_list = [percentile_encoding_analyzer, sqnr_encoding_analyzer] yield encoding_analyzer_list @pytest.mark.parametrize('num_bins', [-1, 0]) def test_invalid_bin_value(self, num_bins): min_max_shape = (1,) - with pytest.raises(ValueError): - MseEncodingAnalyzer(min_max_shape, num_bins) with pytest.raises(ValueError): PercentileEncodingAnalyzer(min_max_shape, num_bins) @@ -370,8 +366,6 @@ def test_collect_stats_multidimensional(self): m = i % 4 assert torch.equal(histograms[i].min, x[j,k,m].min()) assert torch.equal(histograms[i].max, x[j,k,m].max()) - - assert False def test_histogram_during_merging(self):