From ffe7c227ca6600916d2187507d2dbdd4efebdec7 Mon Sep 17 00:00:00 2001
From: reuvenp
Date: Sat, 12 Oct 2024 14:04:08 +0300
Subject: [PATCH] add keras gradual quantization tests

---
 .../gptq/keras/gptq_training.py | 1 -
 tests_pytest/keras/gptq/__init__.py | 0
 .../gptq/test_gradual_act_quantization.py | 102 ++++++++++++++++++
 .../gptq/test_gradual_act_quantization.py | 6 +-
 .../test_linear_annealing.py | 8 +-
 5 files changed, 109 insertions(+), 8 deletions(-)
 create mode 100644 tests_pytest/keras/gptq/__init__.py
 create mode 100644 tests_pytest/keras/gptq/test_gradual_act_quantization.py

diff --git a/model_compression_toolkit/gptq/keras/gptq_training.py b/model_compression_toolkit/gptq/keras/gptq_training.py
index a04698b63..4c15af8cc 100644
--- a/model_compression_toolkit/gptq/keras/gptq_training.py
+++ b/model_compression_toolkit/gptq/keras/gptq_training.py
@@ -44,7 +44,6 @@
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
-from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 import numpy as np
diff --git a/tests_pytest/keras/gptq/__init__.py b/tests_pytest/keras/gptq/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests_pytest/keras/gptq/test_gradual_act_quantization.py b/tests_pytest/keras/gptq/test_gradual_act_quantization.py
new file mode 100644
index 000000000..cde508c7c
--- /dev/null
+++ b/tests_pytest/keras/gptq/test_gradual_act_quantization.py
@@ -0,0 +1,102 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from unittest.mock import Mock
+import pytest
+import numpy as np
+import tensorflow as tf
+
+from model_compression_toolkit.gptq.common.gradual_activation_quantization import GradualActivationQuantizerWrapper, \
+    get_gradual_activation_quantizer_wrapper_factory
+from model_compression_toolkit.trainable_infrastructure.keras.annealing_schedulers import KerasLinearAnnealingScheduler
+from model_compression_toolkit.gptq import GradientPTQConfig, GradualActivationQuantizationConfig, QFractionLinearAnnealingConfig
+
+
+
+@pytest.fixture
+def x():
+    return tf.random.normal((2, 5, 6, 7), seed=42, dtype=tf.float32)
+
+
+class Quantizer:
+    def __call__(self, x, training):
+        self.training = training
+        return 3 * x + 1
+
+
+class TestGradualActivationQuantization:
+
+    def test_gradual_act_quant_wrapper(self, x):
+        quantizer = Quantizer()
+        qw = GradualActivationQuantizerWrapper(quantizer, q_fraction_scheduler=lambda t: t / (t + 1))
+
+        y0, y1, y2 = [qw(x, training=True) for _ in range(3)]
+        np.testing.assert_array_almost_equal(y0.numpy(), x.numpy())  # t=0
+        np.testing.assert_allclose(y1.numpy(), 0.5 * x.numpy() + (1.5 * x.numpy() + 0.5), rtol=1e-5, atol=1e-8)  # t=1
+        np.testing.assert_allclose(y2.numpy(), x.numpy() / 3 + (2 * x.numpy() + 2 / 3), rtol=1e-5, atol=1e-8)  # t=2
+        assert quantizer.training is True
+
+        _ = qw(x, training=False)
+        assert quantizer.training is False  # correct flag was propagated
+
+    def test_factory_no_qdrop(self):
+        quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg=None, get_grad_steps_fn=None)
+        assert quantizer_wrapper is quantizer
+
+    @pytest.mark.parametrize('end_step', (20, None))
+    def test_factory_linear(self, x, end_step):
+        qdrop_cfg = GradualActivationQuantizationConfig(
+            QFractionLinearAnnealingConfig(initial_q_fraction=0.3, target_q_fraction=0.8, start_step=10,
+                                           end_step=end_step)
+        )
+
+        def get_total_steps():
+            if end_step is None:
+                return 50
+            assert False  # should not be called if end_step is passed
+
+        quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg, get_total_steps)
+
+        scheduler = quantizer_wrapper.q_fraction_scheduler
+        assert isinstance(scheduler, KerasLinearAnnealingScheduler)
+        exp_end_step = 50 if end_step is None else end_step
+        assert scheduler.t_start == 10
+        assert scheduler.t_end == exp_end_step
+        assert scheduler.initial_val == 0.3
+        assert scheduler.target_val == 0.8
+
+        y = [quantizer_wrapper(x, training=True) for _ in range(exp_end_step + 1)]
+
+        np.testing.assert_allclose(y[9].numpy(), 0.7 * x.numpy() + 0.3 * quantizer(x, training=True).numpy(), rtol=1e-5, atol=1e-8)
+        np.testing.assert_allclose(y[10].numpy(), 0.7 * x.numpy() + 0.3 * quantizer(x, training=True).numpy(), rtol=1e-5, atol=1e-8)
+        np.testing.assert_allclose(y[-1].numpy(), 0.2 * x.numpy() + 0.8 * quantizer(x, training=True).numpy(), rtol=1e-5, atol=1e-8)
+
+    def test_factory_linear_common_case(self, x):
+        # validate that we actually implemented the right thing - on first call float input, on last call fully quantized
+        qdrop_cfg = GradualActivationQuantizationConfig(
+            QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=0, end_step=None)
+        )
+        quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg, lambda: 15)
+        y0, *_, y_last = [quantizer_wrapper(x, training=True) for _ in range(16)]
+        np.testing.assert_array_almost_equal(y0.numpy(), x.numpy())
+        np.testing.assert_allclose(y_last.numpy(), quantizer(x, training=True).numpy())
+
+    def _run_factory_test(self, qdrop_cfg, get_grad_steps_fn):
+        # Mocks are used to just pass anything
+        gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), loss=Mock(),
+                                     gradual_activation_quantization_config=qdrop_cfg)
+        factory = get_gradual_activation_quantizer_wrapper_factory(gptq_cfg, get_grad_steps_fn, KerasLinearAnnealingScheduler)
+        quantizer = Quantizer()
+        quantizer_wrapper = factory(quantizer)
+        return quantizer_wrapper, quantizer
diff --git a/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py b/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py
index 19f02a660..08c440d6c 100644
--- a/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py
+++ b/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py
@@ -18,7 +18,7 @@
 import torch
 
 from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
-from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import PytorchLinearAnnealingScheduler
 from model_compression_toolkit.gptq import GradientPTQConfig, GradualActivationQuantizationConfig, QFractionLinearAnnealingConfig
 from model_compression_toolkit.gptq.common.gradual_activation_quantization import (
     GradualActivationQuantizerWrapper, get_gradual_activation_quantizer_wrapper_factory)
@@ -68,7 +68,7 @@ def get_total_steps():
         quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg, get_total_steps)
 
         scheduler = quantizer_wrapper.q_fraction_scheduler
-        assert isinstance(scheduler, LinearAnnealingScheduler)
+        assert isinstance(scheduler, PytorchLinearAnnealingScheduler)
         exp_end_step = 50 if end_step is None else end_step
         assert scheduler.t_start == 10
         assert scheduler.t_end == exp_end_step
@@ -94,7 +94,7 @@ def _run_factory_test(self, qdrop_cfg, get_grad_steps_fn):
         # Mocks are used to just pass anything
         gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), loss=Mock(),
                                      gradual_activation_quantization_config=qdrop_cfg)
-        factory = get_gradual_activation_quantizer_wrapper_factory(gptq_cfg, get_grad_steps_fn, LinearAnnealingScheduler)
+        factory = get_gradual_activation_quantizer_wrapper_factory(gptq_cfg, get_grad_steps_fn, PytorchLinearAnnealingScheduler)
         quantizer = Quantizer()
         quantizer_wrapper = factory(quantizer)
         return quantizer_wrapper, quantizer
diff --git a/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py b/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py
index d6edca605..d5dfecf00 100644
--- a/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py
+++ b/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py
@@ -15,11 +15,11 @@
 import torch
 import pytest
 
-from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import PytorchLinearAnnealingScheduler
 
 
 def test_linear_annealing():
-    scheduler = LinearAnnealingScheduler(t_start=10, t_end=35, initial_val=3.4, target_val=-1.6)
+    scheduler = PytorchLinearAnnealingScheduler(t_start=10, t_end=35, initial_val=3.4, target_val=-1.6)
 
     for t in [0, 9, 10]:
         assert _isclose(scheduler(t), 3.4)
@@ -32,7 +32,7 @@
 
 
 def test_linear_annealing_ascending():
-    scheduler = LinearAnnealingScheduler(t_start=0, t_end=5, initial_val=-0.5, target_val=1.5)
+    scheduler = PytorchLinearAnnealingScheduler(t_start=0, t_end=5, initial_val=-0.5, target_val=1.5)
     assert _isclose(scheduler(0), -0.5)
     assert _isclose(scheduler(1), -0.1)
     assert _isclose(scheduler(4), 1.1)
@@ -42,7 +42,7 @@
 @pytest.mark.parametrize('start', [5, -1])
 def test_invalid(start):
     with pytest.raises(ValueError):
-        LinearAnnealingScheduler(t_start=start, t_end=4, initial_val=1, target_val=0)
+        PytorchLinearAnnealingScheduler(t_start=start, t_end=4, initial_val=1, target_val=0)
 
 
 def _isclose(x, y):
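
The behaviour these tests pin down can be summarised as a convex combination of the float activation and its quantized version, with the quantized fraction annealed linearly over the training steps. The sketch below is a minimal, standalone illustration of that arithmetic, based only on what the assertions above check; the function names (linear_annealing, gradual_output) are illustrative and are not part of the MCT API.

# Illustrative sketch only -- mirrors the expected values asserted in the tests,
# not the library implementation itself.

def linear_annealing(t, t_start, t_end, initial_val, target_val):
    # Constant at initial_val up to t_start, constant at target_val from t_end on,
    # linearly interpolated in between (the behaviour test_linear_annealing checks).
    if t <= t_start:
        return initial_val
    if t >= t_end:
        return target_val
    return initial_val + (target_val - initial_val) * (t - t_start) / (t_end - t_start)

def gradual_output(x, quantize, q_fraction):
    # Convex combination of the float and quantized activations, e.g. the expected
    # value 0.7 * x + 0.3 * quantizer(x) asserted for steps 9 and 10 above.
    return (1 - q_fraction) * x + q_fraction * quantize(x)

# Toy check with the same quantizer used in the tests (3 * x + 1) and the
# t / (t + 1) schedule from test_gradual_act_quant_wrapper, for x = 2.0:
quantize = lambda v: 3 * v + 1
for step in range(3):
    q = step / (step + 1)
    print(step, gradual_output(2.0, quantize, q))
# step 0 -> 2.0 (pure float), step 1 -> 4.5, step 2 -> 5.333...

The Keras and Pytorch factory tests exercise this same contract through their framework-specific schedulers (KerasLinearAnnealingScheduler / PytorchLinearAnnealingScheduler), which is why the two test files are near mirrors of each other.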