From aa1dd8b3dd20b4ba52187017120b45bb709d15ed Mon Sep 17 00:00:00 2001 From: irenab Date: Wed, 21 Aug 2024 20:05:43 +0300 Subject: [PATCH] add a flag to freeze quant params in base trainable quantizer and ste activation quanizers --- .../pytorch/quantizer/quantization_builder.py | 14 +++++----- .../common/base_trainable_quantizer.py | 15 +++++----- .../ste/symmetric_ste.py | 12 ++++---- .../activation_quantizers/ste/uniform_ste.py | 13 +++++---- .../pytorch/base_pytorch_quantizer.py | 8 ++++++ ...est_activation_quantization_holder_gptq.py | 28 +++++++++++++------ .../model_tests/feature_models/qat_test.py | 18 +++++++----- 7 files changed, 68 insertions(+), 40 deletions(-) diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py b/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py index 7a72e8fea..ba5e590c1 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py @@ -25,8 +25,9 @@ from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer from model_compression_toolkit.logger import Logger +from model_compression_toolkit.trainable_infrastructure import TrainingMethod, BasePytorchActivationTrainableQuantizer from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \ - get_trainable_quantizer_weights_config + get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \ get_trainable_quantizer_class @@ -68,12 +69,11 @@ def quantization_builder(n: common.BaseNode, quant_method = n.final_activation_quantization_cfg.activation_quantization_method - quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation, + quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Activation, + quantizer_id=TrainingMethod.STE, quant_method=quant_method, - quantizer_base_class=BasePyTorchInferableQuantizer) - - kwargs = get_activation_inferable_quantizer_kwargs(n.final_activation_quantization_cfg) - - activation_quantizers.append(quantizer_class(**kwargs)) + quantizer_base_class=BasePytorchActivationTrainableQuantizer) + cfg = get_trainable_quantizer_activation_config(n, None) + activation_quantizers.append(quantizer_class(cfg, freeze_quant_params=True)) return weights_quantizers, activation_quantizers diff --git a/model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py b/model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py index aefd8c561..b8e7d0933 100644 --- a/model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +++ b/model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py @@ -14,18 +14,16 @@ # ============================================================================== from abc import ABC, abstractmethod from enum import Enum -from typing import Union, List, Any from inspect import signature - -from model_compression_toolkit.logger import Logger +from typing import Union, List, Any from mct_quantizers.common.base_inferable_quantizer import BaseInferableQuantizer, \ QuantizationTarget -from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \ - TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig from mct_quantizers.common.constants import QUANTIZATION_METHOD, \ QUANTIZATION_TARGET - +from model_compression_toolkit.logger import Logger +from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \ + TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig VAR = 'var' GROUP = 'group' @@ -43,12 +41,14 @@ class VariableGroup(Enum): class BaseTrainableQuantizer(BaseInferableQuantizer, ABC): def __init__(self, - quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig]): + quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig], + freeze_quant_params: bool = False): """ This class is a base quantizer which validates the provided quantization config and defines an abstract function which any quantizer needs to implment. Args: quantization_config: quantizer config class contains all the information about the quantizer configuration. + freeze_quant_params: whether to freeze all learnable quantization parameters during training. """ # verify the quantizer class that inherits this class only has a config argument and key-word arguments @@ -85,6 +85,7 @@ def __init__(self, f"Unrecognized 'QuantizationTarget': {static_quantization_target}.") # pragma: no cover self.quantizer_parameters = {} + self.freeze_quant_params = freeze_quant_params @classmethod def get_sig(cls): diff --git a/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/symmetric_ste.py b/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/symmetric_ste.py index e4ab53630..1899a41db 100644 --- a/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/symmetric_ste.py +++ b/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/symmetric_ste.py @@ -18,7 +18,8 @@ import torch from torch import nn -from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper +from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper, \ + PytorchActivationQuantizationHolder from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer from model_compression_toolkit import constants as C from model_compression_toolkit.core.pytorch.utils import to_torch_tensor @@ -39,14 +40,15 @@ class STESymmetricActivationTrainableQuantizer(BasePytorchActivationTrainableQua Trainable constrained quantizer to quantize a layer activations. """ - def __init__(self, quantization_config: TrainableQuantizerActivationConfig): + def __init__(self, quantization_config: TrainableQuantizerActivationConfig, freeze_quant_params: bool = False): """ Initialize a STESymmetricActivationTrainableQuantizer object with parameters to use for symmetric or power of two quantization. Args: quantization_config: trainable quantizer config class + freeze_quant_params: whether to freeze learnable quantization parameters """ - super().__init__(quantization_config) + super().__init__(quantization_config, freeze_quant_params) self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO self.sign = quantization_config.activation_quantization_params['is_signed'] np_threshold_values = quantization_config.activation_quantization_params[C.THRESHOLD] @@ -56,7 +58,7 @@ def __init__(self, quantization_config: TrainableQuantizerActivationConfig): def initialize_quantization(self, tensor_shape: torch.Size, name: str, - layer: PytorchQuantizationWrapper): + layer: PytorchActivationQuantizationHolder): """ Add quantizer parameters to the quantizer parameters dictionary @@ -66,7 +68,7 @@ def initialize_quantization(self, layer: Layer to quantize. """ layer.register_parameter(name, nn.Parameter(to_torch_tensor(self.threshold_tensor), - requires_grad=True)) + requires_grad=not self.freeze_quant_params)) # save the quantizer added parameters for later calculations self.add_quantizer_variable(THRESHOLD_TENSOR, layer.get_parameter(name), VariableGroup.QPARAMS) diff --git a/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/uniform_ste.py b/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/uniform_ste.py index 12c754697..dd18ef5a0 100644 --- a/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/uniform_ste.py +++ b/model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/ste/uniform_ste.py @@ -36,14 +36,15 @@ class STEUniformActivationTrainableQuantizer(BasePytorchActivationTrainableQuant Trainable constrained quantizer to quantize a layer activations. """ - def __init__(self, quantization_config: TrainableQuantizerActivationConfig): + def __init__(self, quantization_config: TrainableQuantizerActivationConfig, freeze_quant_params: bool = False): """ Initialize a STEUniformActivationTrainableQuantizer object with parameters to use for uniform quantization. Args: - quantization_config: trainable quantizer config class + quantization_config: trainable quantizer config class. + freeze_quant_params: whether to freeze learnable quantization parameters. """ - super().__init__(quantization_config) + super().__init__(quantization_config, freeze_quant_params) np_min_range = quantization_config.activation_quantization_params[C.RANGE_MIN] np_max_range = quantization_config.activation_quantization_params[C.RANGE_MAX] @@ -56,7 +57,7 @@ def initialize_quantization(self, name: str, layer: PytorchQuantizationWrapper): """ - Add quantizer parameters to the quantizer parameters dictionary + Add quantizer parameters to the quantizer parameters dictionary. Args: tensor_shape: tensor shape of the quantized tensor. @@ -64,9 +65,9 @@ def initialize_quantization(self, layer: Layer to quantize. """ layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(to_torch_tensor(self.min_range_tensor), - requires_grad=True)) + requires_grad=not self.freeze_quant_params)) layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(to_torch_tensor(self.max_range_tensor), - requires_grad=True)) + requires_grad=not self.freeze_quant_params)) # Save the quantizer parameters for later calculations self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS) diff --git a/model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py b/model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py index f41658a1c..01c9dcb30 100644 --- a/model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +++ b/model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py @@ -46,6 +46,14 @@ def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]: quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP] if quantizer_parameter.requires_grad and parameter_group == group: quantizer_trainable.append(quantizer_parameter) + + # sanity check to catch inconsistent initialization + if self.freeze_quant_params and group == VariableGroup.QPARAMS and quantizer_trainable: + Logger.critical( + 'Found trainable quantization params despite self.freeze_quant_params=True. ' + 'Quantization parameters were probably not initialized correctly in the Quantizer.' + ) # pragma: no cover + return quantizer_trainable else: diff --git a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py index 874caa5a7..e8ec00189 100644 --- a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py +++ b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py @@ -2,9 +2,11 @@ import unittest import torch +from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup + from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper -from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer from torch.nn import Conv2d +from torch.fx import symbolic_trace import numpy as np import model_compression_toolkit as mct @@ -13,7 +15,8 @@ from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc -from torch.fx import symbolic_trace +from model_compression_toolkit.trainable_infrastructure import TrainingMethod +from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import STESymmetricActivationTrainableQuantizer from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters @@ -73,7 +76,12 @@ def test_adding_holder_instead_quantize_wrapper(self): # check that 4 activation quantization holders where generated self.assertTrue(len(activation_quantization_holders_in_model) == 3) for a in activation_quantization_holders_in_model: - self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer)) + self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer)) + self.assertEquals(a.activation_holder_quantizer.identifier, TrainingMethod.STE) + # activation quantization params for gptq should be frozen (non-learnable) + self.assertTrue(a.activation_holder_quantizer.freeze_quant_params is True) + self.assertEquals(a.activation_holder_quantizer.get_trainable_variables(VariableGroup.QPARAMS), []) + for name, module in gptq_model.named_modules(): if isinstance(module, PytorchQuantizationWrapper): self.assertTrue(len(module.weights_quantizers) > 0) @@ -87,7 +95,7 @@ def test_adding_holder_after_relu(self): # check that 3 activation quantization holders where generated self.assertTrue(len(activation_quantization_holders_in_model) == 3) for a in activation_quantization_holders_in_model: - self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer)) + self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer)) for name, module in gptq_model.named_modules(): if isinstance(module, PytorchQuantizationWrapper): self.assertTrue(len(module.weights_quantizers) > 0) @@ -102,14 +110,18 @@ def test_adding_holders_after_reuse(self): # check that 4 activation quantization holders where generated self.assertTrue(len(activation_quantization_holders_in_model) == 3) for a in activation_quantization_holders_in_model: - self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer)) + self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer)) for name, module in gptq_model.named_modules(): if isinstance(module, PytorchQuantizationWrapper): self.assertTrue(len(module.weights_quantizers) > 0) # Test that two holders are getting inputs from reused conv2d (the layer that is wrapped) - fx_model = symbolic_trace(gptq_model) - self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2]) - self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5]) + + # FIXME there is no reuse support and the test doesn't test what it says it tests. It doesn't even look + # at correct layers. After moving to trainable quantizer the test makes even less sense since now fx traces + # all quantization operations instead of fake_quant layer. + # fx_model = symbolic_trace(gptq_model) + # self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2]) + # self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5]) def _get_gptq_model(self, input_shape, in_model): pytorch_impl = GPTQPytorchImplemantation() diff --git a/tests/pytorch_tests/model_tests/feature_models/qat_test.py b/tests/pytorch_tests/model_tests/feature_models/qat_test.py index fd61ff2cb..75def1874 100644 --- a/tests/pytorch_tests/model_tests/feature_models/qat_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/qat_test.py @@ -23,26 +23,26 @@ from torch import Tensor import model_compression_toolkit as mct -import model_compression_toolkit.trainable_infrastructure.common.training_method from mct_quantizers import PytorchActivationQuantizationHolder, QuantizationTarget, PytorchQuantizationWrapper +from mct_quantizers.common.base_inferable_quantizer import QuantizerID from mct_quantizers.common.get_all_subclasses import get_all_subclasses from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device from model_compression_toolkit.core.pytorch.utils import to_torch_tensor -from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer +from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import \ + BasePytorchQATWeightTrainableQuantizer +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc, \ + get_op_quantization_configs +from model_compression_toolkit.trainable_infrastructure import TrainingMethod +from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers.base_activation_quantizer import \ BasePytorchActivationTrainableQuantizer from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers.ste.symmetric_ste import \ STESymmetricActivationTrainableQuantizer -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \ - get_op_quantization_configs from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model, \ generate_tp_model_with_activation_mp from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest from tests.pytorch_tests.tpc_pytorch import get_mp_activation_pytorch_tpc_dict -from mct_quantizers.common.base_inferable_quantizer import QuantizerID -from model_compression_toolkit.trainable_infrastructure import TrainingMethod def dummy_train(qat_ready_model, x, y): @@ -179,6 +179,10 @@ def compare(self, ptq_model, qat_ready_model, qat_finalized_model, input_x=None, and self.activation_quantization_method in _q.quantization_method] self.unit_test.assertTrue(len(q) == 1) self.unit_test.assertTrue(isinstance(layer.activation_holder_quantizer, q[0])) + # quantization params in qat should be trainable (not frozen) + self.unit_test.assertFalse(layer.activation_holder_quantizer.freeze_quant_params) + trainable_params = layer.activation_holder_quantizer.get_trainable_variables(VariableGroup.QPARAMS) + self.unit_test.assertTrue(len(trainable_params) > 0) elif isinstance(layer, PytorchQuantizationWrapper) and isinstance(layer.layer, nn.Conv2d): q = [_q for _q in all_qat_weight_quantizers if _q.identifier == self.training_method and _q.quantization_target == QuantizationTarget.Weights