diff --git a/tests/pytorch_tests/trainable_infrastructure_tests/base_pytorch_trainable_infra_test.py b/tests/pytorch_tests/trainable_infrastructure_tests/base_pytorch_trainable_infra_test.py index ac208278c..8baa2fcf4 100644 --- a/tests/pytorch_tests/trainable_infrastructure_tests/base_pytorch_trainable_infra_test.py +++ b/tests/pytorch_tests/trainable_infrastructure_tests/base_pytorch_trainable_infra_test.py @@ -105,9 +105,10 @@ def get_weights_quantization_config(self): weights_per_channel_threshold=True, min_threshold=0) - def get_activation_quantization_config(self): - return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + def get_activation_quantization_config(self, quant_method=QuantizationMethod.POWER_OF_TWO, + activation_quant_params=None): + return TrainableQuantizerActivationConfig(activation_quantization_method=quant_method, activation_n_bits=8, - activation_quantization_params={}, + activation_quantization_params=activation_quant_params or {}, enable_activation_quantization=True, min_threshold=0) diff --git a/tests/pytorch_tests/trainable_infrastructure_tests/test_pytorch_trainable_infra_runner.py b/tests/pytorch_tests/trainable_infrastructure_tests/test_pytorch_trainable_infra_runner.py index 3aa36aaaa..e62dd06d7 100644 --- a/tests/pytorch_tests/trainable_infrastructure_tests/test_pytorch_trainable_infra_runner.py +++ b/tests/pytorch_tests/trainable_infrastructure_tests/test_pytorch_trainable_infra_runner.py @@ -34,7 +34,8 @@ from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \ BasePytorchTrainableQuantizer from tests.pytorch_tests.trainable_infrastructure_tests.trainable_pytorch.test_pytorch_base_quantizer import \ - TestPytorchBaseWeightsQuantizer, TestPytorchBaseActivationQuantizer, TestPytorchQuantizerWithoutMarkDecorator + TestPytorchBaseWeightsQuantizer, TestPytorchBaseActivationQuantizer, TestPytorchQuantizerWithoutMarkDecorator, \ + TestPytorchSTEActivationQuantizerQParamFreeze from tests.pytorch_tests.trainable_infrastructure_tests.trainable_pytorch.test_pytorch_get_quantizers import \ TestGetTrainableQuantizer @@ -46,6 +47,9 @@ def test_pytorch_base_quantizer(self): TestPytorchBaseActivationQuantizer(self).run_test() TestPytorchQuantizerWithoutMarkDecorator(self).run_test() + def test_pytorch_ste_activation_quantizers_qparams_freeze(self): + TestPytorchSTEActivationQuantizerQParamFreeze(self).run_test() + def test_pytorch_get_quantizers(self): TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Weights, quant_method=QuantizationMethod.POWER_OF_TWO, diff --git a/tests/pytorch_tests/trainable_infrastructure_tests/trainable_pytorch/test_pytorch_base_quantizer.py b/tests/pytorch_tests/trainable_infrastructure_tests/trainable_pytorch/test_pytorch_base_quantizer.py index 490be870e..f521e2ef9 100644 --- a/tests/pytorch_tests/trainable_infrastructure_tests/trainable_pytorch/test_pytorch_base_quantizer.py +++ b/tests/pytorch_tests/trainable_infrastructure_tests/trainable_pytorch/test_pytorch_base_quantizer.py @@ -14,7 +14,10 @@ # ============================================================================== from typing import List, Any -from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup +import torch + +from mct_quantizers import PytorchActivationQuantizationHolder +from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup, VAR from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \ TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \ @@ -22,6 +25,9 @@ from tests.pytorch_tests.trainable_infrastructure_tests.base_pytorch_trainable_infra_test import \ BasePytorchInfrastructureTest, ZeroWeightsQuantizer, ZeroActivationsQuantizer from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod +from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import ( + STESymmetricActivationTrainableQuantizer, STEUniformActivationTrainableQuantizer, + LSQSymmetricActivationTrainableQuantizer, LSQUniformActivationTrainableQuantizer) class TestPytorchBaseWeightsQuantizer(BasePytorchInfrastructureTest): @@ -51,6 +57,8 @@ def run_test(self): weight_quantization_config = super(TestPytorchBaseWeightsQuantizer, self).get_weights_quantization_config() quantizer = ZeroWeightsQuantizer(weight_quantization_config) self.unit_test.assertTrue(quantizer.quantization_config == weight_quantization_config) + # unless implemented explicitly, by default quant params should not be frozen + self.unit_test.assertTrue(quantizer.freeze_quant_params is False) class TestPytorchBaseActivationQuantizer(BasePytorchInfrastructureTest): @@ -78,6 +86,30 @@ def run_test(self): activation_quantization_config = super(TestPytorchBaseActivationQuantizer, self).get_activation_quantization_config() quantizer = ZeroActivationsQuantizer(activation_quantization_config) self.unit_test.assertTrue(quantizer.quantization_config == activation_quantization_config) + # unless implemented explicitly, by default quant params should not be frozen + self.unit_test.assertTrue(quantizer.freeze_quant_params is False) + + +class TestPytorchSTEActivationQuantizerQParamFreeze(BasePytorchInfrastructureTest): + def run_test(self): + sym_qparams = {'is_signed': True, 'threshold': [1]} + self._run_test(STESymmetricActivationTrainableQuantizer, False, QuantizationMethod.POWER_OF_TWO, sym_qparams) + self._run_test(STESymmetricActivationTrainableQuantizer, True, QuantizationMethod.SYMMETRIC, sym_qparams) + + uniform_qparams = {'range_min': 0, 'range_max': 5} + self._run_test(STEUniformActivationTrainableQuantizer, False, QuantizationMethod.UNIFORM, uniform_qparams) + self._run_test(STEUniformActivationTrainableQuantizer, True, QuantizationMethod.UNIFORM, uniform_qparams) + + def _run_test(self, activation_quantizer_cls, freeze, quant_method, activation_quant_params): + quant_config = self.get_activation_quantization_config(quant_method=quant_method, + activation_quant_params=activation_quant_params) + quantizer = activation_quantizer_cls(quant_config, freeze_quant_params=freeze) + holder = PytorchActivationQuantizationHolder(quantizer) + quantizer.initialize_quantization(torch.Size((5,)), 'foo', holder) + self.unit_test.assertTrue(quantizer.freeze_quant_params is freeze) + self.unit_test.assertTrue(quantizer.quantizer_parameters) + for p in quantizer.quantizer_parameters.values(): + self.unit_test.assertTrue(p[VAR].requires_grad is not freeze) class _TestQuantizer(BasePytorchTrainableQuantizer):