Skip to content

Commit

Permalink
add tests for ste activation qparams freezing
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Sep 4, 2024
1 parent aa1dd8b commit 4fa358d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ def get_weights_quantization_config(self):
weights_per_channel_threshold=True,
min_threshold=0)

def get_activation_quantization_config(self):
return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
def get_activation_quantization_config(self, quant_method=QuantizationMethod.POWER_OF_TWO,
activation_quant_params=None):
return TrainableQuantizerActivationConfig(activation_quantization_method=quant_method,
activation_n_bits=8,
activation_quantization_params={},
activation_quantization_params=activation_quant_params or {},
enable_activation_quantization=True,
min_threshold=0)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
BasePytorchTrainableQuantizer
from tests.pytorch_tests.trainable_infrastructure_tests.trainable_pytorch.test_pytorch_base_quantizer import \
TestPytorchBaseWeightsQuantizer, TestPytorchBaseActivationQuantizer, TestPytorchQuantizerWithoutMarkDecorator
TestPytorchBaseWeightsQuantizer, TestPytorchBaseActivationQuantizer, TestPytorchQuantizerWithoutMarkDecorator, \
TestPytorchSTEActivationQuantizerQParamFreeze
from tests.pytorch_tests.trainable_infrastructure_tests.trainable_pytorch.test_pytorch_get_quantizers import \
TestGetTrainableQuantizer

Expand All @@ -46,6 +47,9 @@ def test_pytorch_base_quantizer(self):
TestPytorchBaseActivationQuantizer(self).run_test()
TestPytorchQuantizerWithoutMarkDecorator(self).run_test()

def test_pytorch_ste_activation_quantizers_qparams_freeze(self):
TestPytorchSTEActivationQuantizerQParamFreeze(self).run_test()

def test_pytorch_get_quantizers(self):
TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Weights,
quant_method=QuantizationMethod.POWER_OF_TWO,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@
# ==============================================================================
from typing import List, Any

from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
import torch

from mct_quantizers import PytorchActivationQuantizationHolder
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup, VAR
from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
BasePytorchTrainableQuantizer
from tests.pytorch_tests.trainable_infrastructure_tests.base_pytorch_trainable_infra_test import \
BasePytorchInfrastructureTest, ZeroWeightsQuantizer, ZeroActivationsQuantizer
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import (
STESymmetricActivationTrainableQuantizer, STEUniformActivationTrainableQuantizer,
LSQSymmetricActivationTrainableQuantizer, LSQUniformActivationTrainableQuantizer)


class TestPytorchBaseWeightsQuantizer(BasePytorchInfrastructureTest):
Expand Down Expand Up @@ -51,6 +57,8 @@ def run_test(self):
weight_quantization_config = super(TestPytorchBaseWeightsQuantizer, self).get_weights_quantization_config()
quantizer = ZeroWeightsQuantizer(weight_quantization_config)
self.unit_test.assertTrue(quantizer.quantization_config == weight_quantization_config)
# unless implemented explicitly, by default quant params should not be frozen
self.unit_test.assertTrue(quantizer.freeze_quant_params is False)


class TestPytorchBaseActivationQuantizer(BasePytorchInfrastructureTest):
Expand Down Expand Up @@ -78,6 +86,30 @@ def run_test(self):
activation_quantization_config = super(TestPytorchBaseActivationQuantizer, self).get_activation_quantization_config()
quantizer = ZeroActivationsQuantizer(activation_quantization_config)
self.unit_test.assertTrue(quantizer.quantization_config == activation_quantization_config)
# unless implemented explicitly, by default quant params should not be frozen
self.unit_test.assertTrue(quantizer.freeze_quant_params is False)


class TestPytorchSTEActivationQuantizerQParamFreeze(BasePytorchInfrastructureTest):
def run_test(self):
sym_qparams = {'is_signed': True, 'threshold': [1]}
self._run_test(STESymmetricActivationTrainableQuantizer, False, QuantizationMethod.POWER_OF_TWO, sym_qparams)
self._run_test(STESymmetricActivationTrainableQuantizer, True, QuantizationMethod.SYMMETRIC, sym_qparams)

uniform_qparams = {'range_min': 0, 'range_max': 5}
self._run_test(STEUniformActivationTrainableQuantizer, False, QuantizationMethod.UNIFORM, uniform_qparams)
self._run_test(STEUniformActivationTrainableQuantizer, True, QuantizationMethod.UNIFORM, uniform_qparams)

def _run_test(self, activation_quantizer_cls, freeze, quant_method, activation_quant_params):
quant_config = self.get_activation_quantization_config(quant_method=quant_method,
activation_quant_params=activation_quant_params)
quantizer = activation_quantizer_cls(quant_config, freeze_quant_params=freeze)
holder = PytorchActivationQuantizationHolder(quantizer)
quantizer.initialize_quantization(torch.Size((5,)), 'foo', holder)
self.unit_test.assertTrue(quantizer.freeze_quant_params is freeze)
self.unit_test.assertTrue(quantizer.quantizer_parameters)
for p in quantizer.quantizer_parameters.values():
self.unit_test.assertTrue(p[VAR].requires_grad is not freeze)


class _TestQuantizer(BasePytorchTrainableQuantizer):
Expand Down

0 comments on commit 4fa358d

Please sign in to comment.