diff --git a/model_compression_toolkit/qat/common/qat_config.py b/model_compression_toolkit/qat/common/qat_config.py
index 0bbddb76b..1817191c4 100644
--- a/model_compression_toolkit/qat/common/qat_config.py
+++ b/model_compression_toolkit/qat/common/qat_config.py
@@ -45,9 +45,12 @@ class TrainingMethod(Enum):
 
     DQA - DNN Quantization with Attention. Includes a smooth quantization introduces by DQA method
 
+    LSQ - Learned Step Size Quantization. Includes PowerOfTwo, symmetric & uniform quantizers: https://arxiv.org/pdf/1902.08153.pdf
+
     """
     STE = "STE",
-    DQA = "DQA"
+    DQA = "DQA",
+    LSQ = "LSQ"
 
 
 class QATConfig:
diff --git a/model_compression_toolkit/qat/keras/quantizer/__init__.py b/model_compression_toolkit/qat/keras/quantizer/__init__.py
index 2a7ff806e..981616857 100644
--- a/model_compression_toolkit/qat/keras/quantizer/__init__.py
+++ b/model_compression_toolkit/qat/keras/quantizer/__init__.py
@@ -15,3 +15,5 @@
 
 import model_compression_toolkit.qat.keras.quantizer.ste_rounding.symmetric_ste
 import model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste
+import model_compression_toolkit.qat.keras.quantizer.lsq.symmetric_lsq
+import model_compression_toolkit.qat.keras.quantizer.lsq.uniform_lsq
diff --git a/model_compression_toolkit/qat/keras/quantizer/lsq/__init__.py b/model_compression_toolkit/qat/keras/quantizer/lsq/__init__.py
new file mode 100644
index 000000000..2147ec284
--- /dev/null
+++ b/model_compression_toolkit/qat/keras/quantizer/lsq/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py b/model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py
new file mode 100644
index 000000000..8c508de21
--- /dev/null
+++ b/model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py
@@ -0,0 +1,254 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework.tensor_shape import TensorShape
+from model_compression_toolkit.constants import SIGNED
+
+from model_compression_toolkit.qat import TrainingMethod
+
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
+from mct_quantizers import QuantizationTarget, mark_quantizer
+from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
+from model_compression_toolkit import constants as C
+
+from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
+    TrainableQuantizerActivationConfig
+from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
+    ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.qat.keras.quantizer.quant_utils import ste_round, grad_scale
+
+
+def symmetric_lsq_quantizer(x: tf.Tensor,
+                            thresholds: tf.Tensor,
+                            num_bits: int,
+                            sign: bool,
+                            min_int: int,
+                            max_int: int,
+                            scale_factor: float) -> tf.Tensor:
+    """
+    Symmetric quantizer according to the LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
+    Args:
+        x: input to quantize
+        thresholds: thresholds of quantization levels
+        num_bits: number of bits for quantization
+        sign: whether x is signed or not
+        min_int: min clipping integer value
+        max_int: max clipping integer value
+        scale_factor: gradient scale factor of the LSQ algorithm
+    Returns:
+        A quantized tensor
+    """
+    delta = thresholds / (2 ** (num_bits - int(sign)))
+    delta_scaled = grad_scale(delta, scale_factor)
+    rounded = ste_round(x / delta_scaled)
+    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
+    quantized = delta_scaled * clipped
+    return quantized
+
+
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
+                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                identifier=TrainingMethod.LSQ)
+class LSQWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
+    """
+    Trainable constrained quantizer to quantize a layer's weights.
+    """
+
+    def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
+        """
+        Initialize an LSQWeightQATQuantizer object with parameters to use
+        for the quantization.
+
+        Args:
+            quantization_config: trainable quantizer config class
+        """
+        super().__init__(quantization_config)
+        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+        self.threshold_values = np.array(quantization_config.weights_quantization_params[C.THRESHOLD])
+        self.threshold_shape = self.threshold_values.shape
+        self.per_channel = self.quantization_config.weights_per_channel_threshold
+        self.channel_axis = self.quantization_config.weights_channels_axis
+        self.threshold_values = np.reshape(np.asarray(self.threshold_values), [-1]) if self.per_channel else float(self.threshold_values)
+        self.num_bits = self.quantization_config.weights_n_bits
+        n_pos_bits = self.num_bits - int(C.WEIGHTS_SIGNED)
+        self.min_int = -int(C.WEIGHTS_SIGNED) * (2 ** n_pos_bits)
+        self.max_int = 2 ** n_pos_bits - 1
+        self.scale_factor = 1.0 / np.sqrt(self.max_int * np.size(self.threshold_values))
+        if self.power_of_two:
+            self.threshold_values = np.power(2.0, np.ceil(np.log2(np.maximum(self.threshold_values, C.MIN_THRESHOLD))))
+
+    def initialize_quantization(self,
+                                tensor_shape: TensorShape,
+                                name: str,
+                                layer: KerasTrainableQuantizationWrapper):
+        """
+        Add quantizer parameters to the quantizer parameters dictionary
+
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
+        """
+        ptq_threshold_tensor = layer.add_weight(
+            name + THRESHOLD_TENSOR,
+            shape=len(self.threshold_values) if self.per_channel else (),
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=True)
+        ptq_threshold_tensor.assign(self.threshold_values)
+
+        # save the quantizer added parameters for later calculations
+        self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
+
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
+        """
+        Quantize a tensor.
+        Args:
+            inputs: Input tensor to quantize.
+            training: Whether the graph is in training mode.
+
+        Returns:
+            The quantized tensor.
+        """
+
+        thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR)
+        q_tensor = symmetric_lsq_quantizer(inputs, thresholds, self.num_bits, C.WEIGHTS_SIGNED, self.min_int, self.max_int, self.scale_factor)
+        return q_tensor
+
+    def convert2inferable(self) -> Union[WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer]:
+        """
+        Convert quantizer to inferable quantizer.
+
+        Returns:
+            BaseKerasInferableQuantizer object.
+        """
+        if self.power_of_two:
+            thresholds = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()))
+            return WeightsPOTInferableQuantizer(num_bits=self.num_bits,
+                                                threshold=list(thresholds.flatten()),
+                                                per_channel=self.per_channel,
+                                                channel_axis=self.channel_axis,
+                                                input_rank=len(self.threshold_shape))
+        else:
+            thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()
+            return WeightsSymmetricInferableQuantizer(num_bits=self.num_bits,
+                                                      threshold=list(thresholds.flatten()),
+                                                      per_channel=self.per_channel,
+                                                      channel_axis=self.channel_axis,
+                                                      input_rank=len(self.threshold_shape))
+
+
+@mark_quantizer(quantization_target=QuantizationTarget.Activation,
+                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                identifier=TrainingMethod.LSQ)
+class LSQActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
+    """
+    Trainable constrained quantizer to quantize layer activations.
+ """ + + def __init__(self, quantization_config: TrainableQuantizerActivationConfig): + """ + Initialize a LSQActivationQATQuantizer object with parameters to use + for the quantization. + + Args: + quantization_config: trainable quantizer config class + """ + super().__init__(quantization_config) + self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO + self.threshold_values = float(quantization_config.activation_quantization_params[C.THRESHOLD]) + self.threshold_shape = np.asarray(self.threshold_values).shape + self.sign = quantization_config.activation_quantization_params[SIGNED] + self.num_bits = quantization_config.activation_n_bits + n_pos_bits = self.num_bits - int(self.sign) + self.min_int = -int(self.sign) * (2 ** n_pos_bits) + self.max_int = (2 ** n_pos_bits) - 1 + if self.power_of_two: + self.threshold_values = np.power(2.0, np.ceil(np.log2(np.maximum(self.threshold_values, C.MIN_THRESHOLD)))) + + + def initialize_quantization(self, + tensor_shape: TensorShape, + name: str, + layer: KerasTrainableQuantizationWrapper): + """ + Add quantizer parameters to the quantizer parameters dictionary + + Args: + tensor_shape: tensor shape of the quantized tensor. + name: Tensor name. + layer: Layer to quantize. + """ + ptq_threshold_tensor = layer.add_weight( + name + THRESHOLD_TENSOR, + shape=(), + initializer=tf.keras.initializers.Constant(1.0), + trainable=True) + ptq_threshold_tensor.assign(self.threshold_values) + + # save the quantizer added parameters for later calculations + self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS) + + def __call__(self, + inputs: tf.Tensor, + training: bool): + """ + Quantize a tensor. + Args: + inputs: Input tensor to quantize. + training: Whether the graph is in training mode. + + Returns: + The quantized tensor. + """ + + thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR) + n_channels = inputs.shape[-1] + scale_factor = 1.0 / np.sqrt(self.max_int * n_channels) + q_tensor = symmetric_lsq_quantizer(inputs, thresholds, self.num_bits, self.sign, self.min_int, self.max_int, scale_factor) + return q_tensor + + def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]: + """ + Convert quantizer to inferable quantizer. + + Returns: + BaseKerasInferableQuantizer object. + """ + + if self.power_of_two: + thresholds = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR).numpy())) + return ActivationPOTInferableQuantizer(num_bits=self.num_bits, + # In activation quantization is per-tensor only - thus we pass + # the threshold as a list with a len of 1 + threshold=[thresholds], + signed=self.sign) + else: + thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy() + return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits, + # In activation quantization is per-tensor only - thus we + # pass the threshold as a list with a len of 1 + threshold=[thresholds], + signed=self.sign) diff --git a/model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py b/model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py new file mode 100644 index 000000000..d89e701ba --- /dev/null +++ b/model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py @@ -0,0 +1,250 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework.tensor_shape import TensorShape
+from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
+from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
+from model_compression_toolkit.qat import TrainingMethod
+
+from mct_quantizers import mark_quantizer, QuantizationMethod, QuantizationTarget
+from mct_quantizers.keras.quantizers import \
+    BaseKerasInferableQuantizer, WeightsUniformInferableQuantizer, ActivationUniformInferableQuantizer
+
+from model_compression_toolkit import constants as C
+
+from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
+    TrainableQuantizerActivationConfig
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
+from model_compression_toolkit.qat.keras.quantizer.quant_utils import ste_round, grad_scale, adjust_range_to_include_zero
+
+
+def uniform_lsq_quantizer(x: tf.Tensor,
+                          min_range: tf.Tensor,
+                          max_range: tf.Tensor,
+                          num_bits: int,
+                          min_int: int,
+                          max_int: int,
+                          scale_factor: float) -> tf.Tensor:
+    """
+    Uniform quantizer according to the LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
+    Args:
+        x: input to quantize
+        min_range: min range of quantization values
+        max_range: max range of quantization values
+        num_bits: number of bits for quantization
+        min_int: min clipping integer value
+        max_int: max clipping integer value
+        scale_factor: gradient scale factor of the LSQ algorithm
+    Returns:
+        A quantized tensor
+    """
+    min_range, max_range = adjust_range_to_include_zero(min_range, max_range, num_bits)
+    delta = (max_range - min_range) / (2 ** num_bits - 1)
+    delta_scaled = grad_scale(delta, scale_factor)
+    rounded = ste_round((x - min_range) / delta_scaled)
+    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
+    quantized = delta_scaled * clipped + min_range
+    return quantized
+
+
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
+                quantization_method=[QuantizationMethod.UNIFORM],
+                identifier=TrainingMethod.LSQ)
+class LSQUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
+    """
+    Trainable constrained quantizer to quantize a layer's weights.
+    """
+
+    def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
+        """
+        Initialize an LSQUniformWeightQATQuantizer object with parameters to use
+        for the quantization.
+
+        Args:
+            quantization_config: a trainable quantizer config class with attributes for the quantization.
+ + """ + super().__init__(quantization_config) + self.num_bits = self.quantization_config.weights_n_bits + self.per_channel = self.quantization_config.weights_per_channel_threshold + self.channel_axis = self.quantization_config.weights_channels_axis + max_values = np.array(quantization_config.weights_quantization_params[RANGE_MAX]) + min_values = np.array(quantization_config.weights_quantization_params[RANGE_MIN]) + self.min_max_shape = np.asarray(max_values).shape + self.max_values = np.reshape(max_values, [-1]) if self.per_channel else float(max_values) + self.min_values = np.reshape(min_values, [-1]) if self.per_channel else float(min_values) + self.min_int = 0 + self.max_int = 2**self.num_bits - 1 + self.scale_factor = 1.0 / np.sqrt(self.max_int * self.max_values.size) + + + def initialize_quantization(self, + tensor_shape: TensorShape, + name: str, + layer: KerasTrainableQuantizationWrapper): + """ + Add quantizer parameters to the quantizer parameters dictionary + + Args: + tensor_shape: tensor shape of the quantized tensor. + name: Tensor name. + layer: Layer to quantize. + """ + fq_min = layer.add_weight( + name + FQ_MIN, + shape=len(self.min_values) if self.per_channel else (), + initializer=tf.keras.initializers.Constant(-1.0), + trainable=True) + fq_min.assign(self.min_values) + + fq_max = layer.add_weight( + name + FQ_MAX, + shape=len(self.max_values) if self.per_channel else (), + initializer=tf.keras.initializers.Constant(1.0), + trainable=True) + fq_max.assign(self.max_values) + + # save the quantizer added parameters for later calculations + self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS) + self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS) + + def __call__(self, inputs: tf.Tensor, + training: bool): + """ + Quantize a tensor. + Args: + inputs: Input tensor to quantize. + training: Whether the graph is in training mode. + + Returns: + The quantized tensor. + """ + + min_range = self.get_quantizer_variable(FQ_MIN) + max_range = self.get_quantizer_variable(FQ_MAX) + q_tensor = uniform_lsq_quantizer(inputs, min_range, max_range, self.num_bits, self.min_int, self.max_int, self.scale_factor) + return q_tensor + + def convert2inferable(self) -> BaseKerasInferableQuantizer: + """ + Convert quantizer to inferable quantizer. + + Returns: + BaseKerasInferableQuantizer object. + """ + min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(), + self.get_quantizer_variable(FQ_MAX).numpy(), + self.num_bits) + return WeightsUniformInferableQuantizer(num_bits=self.num_bits, + min_range=list(min_range.flatten()), + max_range=list(max_range.flatten()), + per_channel=self.per_channel, + channel_axis=self.channel_axis, + input_rank=len(self.min_max_shape)) + + +@mark_quantizer(quantization_target=QuantizationTarget.Activation, + quantization_method=[QuantizationMethod.UNIFORM], + identifier=TrainingMethod.LSQ) +class LSQUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer): + """ + Trainable constrained quantizer to quantize layer activations. + """ + + def __init__(self, quantization_config: TrainableQuantizerActivationConfig): + """ + Initialize a LSQUniformActivationQATQuantizer object with parameters to use + for the quantization. 
+
+        Args:
+            quantization_config: trainable quantizer config class
+        """
+        super().__init__(quantization_config)
+
+        self.num_bits = quantization_config.activation_n_bits
+        self.min_range = np.array(quantization_config.activation_quantization_params[C.RANGE_MIN])
+        self.max_range = np.array(quantization_config.activation_quantization_params[C.RANGE_MAX])
+        self.min_int = 0
+        self.max_int = 2 ** self.num_bits - 1
+
+    def initialize_quantization(self,
+                                tensor_shape: TensorShape,
+                                name: str,
+                                layer: KerasTrainableQuantizationWrapper):
+        """
+        Add quantizer parameters to the quantizer parameters dictionary
+
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
+        """
+        fq_min = layer.add_weight(
+            name + FQ_MIN,
+            shape=(),
+            initializer=tf.keras.initializers.Constant(-1.0),
+            trainable=True)
+        fq_min.assign(self.min_range)
+
+        fq_max = layer.add_weight(
+            name + FQ_MAX,
+            shape=(),
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=True)
+        fq_max.assign(self.max_range)
+
+        # save the quantizer added parameters for later calculations
+        self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
+        self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
+
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
+        """
+        Quantize a tensor.
+        Args:
+            inputs: Input tensor to quantize.
+            training: Whether the graph is in training mode.
+
+        Returns:
+            The quantized tensor.
+        """
+
+        min_range = self.get_quantizer_variable(FQ_MIN)
+        max_range = self.get_quantizer_variable(FQ_MAX)
+        n_channels = inputs.shape[-1]
+        scale_factor = 1.0 / np.sqrt(self.max_int * n_channels)
+        q_tensor = uniform_lsq_quantizer(inputs, min_range, max_range, self.num_bits, self.min_int, self.max_int, scale_factor)
+        return q_tensor
+
+    def convert2inferable(self) -> BaseKerasInferableQuantizer:
+        """
+        Convert quantizer to inferable quantizer.
+
+        Returns:
+            BaseKerasInferableQuantizer object.
+        """
+        min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(),
+                                                         self.get_quantizer_variable(FQ_MAX).numpy(),
+                                                         self.num_bits)
+        return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
+                                                   # Activation quantization is per-tensor only, so we pass
+                                                   # the min/max as lists of length 1
+                                                   min_range=[min_range],
+                                                   max_range=[max_range])
diff --git a/model_compression_toolkit/qat/keras/quantizer/quant_utils.py b/model_compression_toolkit/qat/keras/quantizer/quant_utils.py
index b001e363c..533488c49 100644
--- a/model_compression_toolkit/qat/keras/quantizer/quant_utils.py
+++ b/model_compression_toolkit/qat/keras/quantizer/quant_utils.py
@@ -17,6 +17,23 @@ from typing import Tuple
 
 
+def ste_round(x: tf.Tensor) -> tf.Tensor:
+    """
+    Return the rounded values of a tensor, with a straight-through estimator (identity) gradient.
+    """
+    error = tf.stop_gradient(tf.math.round(x) - x)
+    return error + x
+
+
+def grad_scale(x: tf.Tensor, scale=1.0) -> tf.Tensor:
+    """
+    Return x in the forward pass and scale the gradient by `scale` in the backward pass.
+ """ + x_scaled = scale * x + error = tf.stop_gradient(x - x_scaled) + return error + x_scaled + + def adjust_range_to_include_zero(range_min: tf.Tensor, range_max: tf.Tensor, n_bits: int) -> Tuple[tf.Tensor, tf.Tensor]: diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py index ed2f7eabf..d8ec93f52 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py @@ -167,6 +167,7 @@ class QATWrappersTest(BaseKerasFeatureNetworkTest): def __init__(self, unit_test, layer, weight_bits=2, activation_bits=4, finalize=True, weights_quantization_method=mct.target_platform.QuantizationMethod.POWER_OF_TWO, activation_quantization_method=mct.target_platform.QuantizationMethod.POWER_OF_TWO, + training_method=mct.qat.TrainingMethod.STE, per_channel=True, test_loading=False): self.layer = layer @@ -177,6 +178,7 @@ def __init__(self, unit_test, layer, weight_bits=2, activation_bits=4, finalize= self.activation_quantization_method = activation_quantization_method self.per_channel = per_channel self.test_loading = test_loading + self.training_method = training_method super().__init__(unit_test) def get_tpc(self): @@ -190,11 +192,15 @@ def create_networks(self): outputs = self.layer(inputs) return keras.Model(inputs=inputs, outputs=outputs) + def get_qat_config(self): + return mct.qat.QATConfig(weight_training_method=self.training_method, activation_training_method=self.training_method) + def run_test(self, **kwargs): model_float = self.create_networks() ptq_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init(model_float, self.representative_data_gen, fw_info=self.get_fw_info(), + qat_config=self.get_qat_config(), target_platform_capabilities=self.get_tpc()) # PTQ model @@ -214,8 +220,7 @@ def run_test(self, **kwargs): quantization_info=quantization_info) out_qat_model = qat_model(in_tensor) - self.unit_test.assertTrue( - np.isclose(np.linalg.norm(out_qat_model - out_ptq_model) / np.linalg.norm(out_ptq_model), 0, atol=1e-6)) + self.unit_test.assertTrue(np.isclose(np.linalg.norm(out_qat_model - out_ptq_model) / np.linalg.norm(out_ptq_model), 0, atol=1e-6)) if self.finalize: # QAT finalize model @@ -233,9 +238,7 @@ def run_test(self, **kwargs): input_x=self.representative_data_gen(), quantization_info=quantization_info) out_qat_finalize_model = qat_finalize_model(in_tensor) - self.unit_test.assertTrue( - np.isclose(np.linalg.norm(out_qat_finalize_model - out_ptq_model) / np.linalg.norm(out_ptq_model), 0, - atol=1e-6)) + self.unit_test.assertTrue(np.isclose(np.linalg.norm(out_qat_finalize_model - out_ptq_model) / np.linalg.norm(out_ptq_model), 0, atol=1e-6)) def compare(self, qat_model, finalize=False, input_x=None, quantization_info=None): all_trainable_quantizers = get_all_subclasses(BaseKerasQATTrainableQuantizer) @@ -254,7 +257,7 @@ def compare(self, qat_model, finalize=False, input_x=None, quantization_info=Non else: self.unit_test.assertTrue(isinstance(layer.activation_holder_quantizer, BaseKerasTrainableQuantizer)) q = [_q for _q in all_trainable_quantizers if - _q.identifier == mct.qat.TrainingMethod.STE + _q.identifier == self.training_method and _q.quantization_target == QuantizationTarget.Activation and self.activation_quantization_method in _q.quantization_method and type(_q.identifier) == TrainingMethod] @@ -275,7 +278,7 @@ def compare(self, 
                 self.unit_test.assertTrue(isinstance(layer.weights_quantizers[KERNEL], q[0]))
             else:
                 self.unit_test.assertTrue(isinstance(quantizer, BaseKerasTrainableQuantizer))
-                q = [_q for _q in all_trainable_quantizers if _q.identifier == mct.qat.TrainingMethod.STE
+                q = [_q for _q in all_trainable_quantizers if _q.identifier == self.training_method
                     and _q.quantization_target == QuantizationTarget.Weights
                     and self.weights_quantization_method in _q.quantization_method
                     and type(_q.identifier) == TrainingMethod]
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index 74e2f36fe..71dfb72df 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -124,6 +124,7 @@
     MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationKPINonConfNodesTest, \
     MixedPercisionSearchTotalKPINonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest
 from tests.keras_tests.feature_networks_tests.feature_networks.old_api_test import OldApiTest
+from model_compression_toolkit.qat.common.qat_config import TrainingMethod
 
 layers = tf.keras.layers
@@ -690,6 +691,18 @@ def test_qat(self):
                         weights_quantization_method=QuantizationMethod.SYMMETRIC,
                         activation_quantization_method=QuantizationMethod.SYMMETRIC).run_test()
         QATWrappersTest(self, layers.Conv2DTranspose(3, 4, activation='relu')).run_test()
+        QATWrappersTest(self, layers.Conv2D(3, 4, activation='relu'),
+                        weights_quantization_method=QuantizationMethod.SYMMETRIC,
+                        activation_quantization_method=QuantizationMethod.SYMMETRIC,
+                        training_method=TrainingMethod.LSQ).run_test()
+        QATWrappersTest(self, layers.Conv2D(3, 4, activation='relu'),
+                        weights_quantization_method=QuantizationMethod.UNIFORM,
+                        activation_quantization_method=QuantizationMethod.UNIFORM,
+                        training_method=TrainingMethod.LSQ).run_test()
+        QATWrappersTest(self, layers.Dense(3, activation='relu'),
+                        weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
+                        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
+                        training_method=TrainingMethod.LSQ).run_test()
         # DW-Conv2D are tested under the tests below because an extra check is needed to verify the
         # quantization per channel of its kernel TODO: should be part of the quantizers tests
         QuantizationAwareTrainingQuantizersTest(self).run_test()
diff --git a/tests/trainable_infrastructure_tests/keras/test_keras_trainable_infra_runner.py b/tests/trainable_infrastructure_tests/keras/test_keras_trainable_infra_runner.py
index 6e47a74f4..e7a9a6141 100644
--- a/tests/trainable_infrastructure_tests/keras/test_keras_trainable_infra_runner.py
+++ b/tests/trainable_infrastructure_tests/keras/test_keras_trainable_infra_runner.py
@@ -22,13 +22,14 @@
     STEActivationQATQuantizer
 from model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste import STEUniformWeightQATQuantizer, \
     STEUniformActivationQATQuantizer
+from model_compression_toolkit.qat.keras.quantizer.lsq.uniform_lsq import LSQUniformActivationQATQuantizer, LSQUniformWeightQATQuantizer
+from model_compression_toolkit.qat.keras.quantizer.lsq.symmetric_lsq import LSQActivationQATQuantizer, LSQWeightQATQuantizer
 from model_compression_toolkit.trainable_infrastructure import BaseKerasTrainableQuantizer
 from tests.trainable_infrastructure_tests.keras.trainable_keras.test_get_quantizers import \
     TestGetTrainableQuantizer
 from tests.trainable_infrastructure_tests.keras.trainable_keras.test_keras_base_quantizer import TestKerasBaseWeightsQuantizer, \
     TestKerasBaseActivationsQuantizer, TestKerasQuantizerWithoutMarkDecorator
-
 layers = tf.keras.layers
@@ -70,6 +71,26 @@ def test_get_quantizers(self):
                                   quantizer_base_class=BaseKerasTrainableQuantizer,
                                   quantizer_id=TrainingMethod.STE,
                                   expected_quantizer_class=STEUniformActivationQATQuantizer).run_test()
+        TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Weights,
+                                  quant_method=QuantizationMethod.SYMMETRIC,
+                                  quantizer_base_class=BaseKerasTrainableQuantizer,
+                                  quantizer_id=TrainingMethod.LSQ,
+                                  expected_quantizer_class=LSQWeightQATQuantizer).run_test()
+        TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Weights,
+                                  quant_method=QuantizationMethod.UNIFORM,
+                                  quantizer_base_class=BaseKerasTrainableQuantizer,
+                                  quantizer_id=TrainingMethod.LSQ,
+                                  expected_quantizer_class=LSQUniformWeightQATQuantizer).run_test()
+        TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Activation,
+                                  quant_method=QuantizationMethod.SYMMETRIC,
+                                  quantizer_base_class=BaseKerasTrainableQuantizer,
+                                  quantizer_id=TrainingMethod.LSQ,
+                                  expected_quantizer_class=LSQActivationQATQuantizer).run_test()
+        TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Activation,
+                                  quant_method=QuantizationMethod.UNIFORM,
+                                  quantizer_base_class=BaseKerasTrainableQuantizer,
+                                  quantizer_id=TrainingMethod.LSQ,
+                                  expected_quantizer_class=LSQUniformActivationQATQuantizer).run_test()
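
The helpers this patch builds on are small enough to exercise in isolation. The following standalone sketch (not part of the patch) re-implements ste_round, grad_scale and symmetric_lsq_quantizer as they appear above and runs them on arbitrary toy values, assuming TensorFlow 2.x; it shows that the forward pass lands on the quantization grid while the threshold still receives a gradient.

import tensorflow as tf

def ste_round(x: tf.Tensor) -> tf.Tensor:
    # Forward: round(x). Backward: identity (straight-through estimator).
    return x + tf.stop_gradient(tf.math.round(x) - x)

def grad_scale(x: tf.Tensor, scale=1.0) -> tf.Tensor:
    # Forward: x. Backward: the incoming gradient is multiplied by `scale`.
    x_scaled = scale * x
    return x_scaled + tf.stop_gradient(x - x_scaled)

def symmetric_lsq_quantizer(x, thresholds, num_bits, sign, min_int, max_int, scale_factor):
    # The step size is derived from the (trainable) threshold, as in the patch.
    delta = thresholds / (2 ** (num_bits - int(sign)))
    delta_scaled = grad_scale(delta, scale_factor)
    rounded = ste_round(x / delta_scaled)
    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
    return delta_scaled * clipped

x = tf.constant([-1.3, -0.2, 0.4, 0.9])
t = tf.Variable(1.0)  # trainable threshold
with tf.GradientTape() as tape:
    y = symmetric_lsq_quantizer(x, t, num_bits=2, sign=True,
                                min_int=-2, max_int=1, scale_factor=0.5)
    loss = tf.reduce_sum(y)
print(y.numpy())               # [-1.  0.  0.5  0.5]: a signed 2-bit grid with step 0.5
print(tape.gradient(loss, t))  # non-zero, so the threshold is learned by backprop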
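
The uniform variant applies the same STE and gradient-scaling machinery to a trainable [min, max] range instead of a symmetric threshold. A minimal sketch, reusing the helpers from the block above; the zero-point snapping done by adjust_range_to_include_zero in the patch is elided here under the assumption that the chosen range already places zero on the grid.

def uniform_lsq_quantizer(x, min_range, max_range, num_bits, min_int, max_int, scale_factor):
    # Step size spans the range with 2**num_bits levels; min_range is the grid origin.
    delta = (max_range - min_range) / (2 ** num_bits - 1)
    delta_scaled = grad_scale(delta, scale_factor)
    rounded = ste_round((x - min_range) / delta_scaled)
    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
    return delta_scaled * clipped + min_range

x = tf.constant([-1.2, -0.3, 0.1, 0.6])
y = uniform_lsq_quantizer(x, tf.Variable(-1.0), tf.Variable(0.5),
                          num_bits=2, min_int=0, max_int=3, scale_factor=0.5)
print(y.numpy())  # [-1.  -0.5  0.   0.5]: 4 uniform levels on [-1, 0.5], zero included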
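
End to end, LSQ is selected through QATConfig exactly as the updated tests do. A hedged usage sketch: model, representative_data_gen and tpc are placeholders the caller must supply, and the finalize call assumes the standard MCT Keras QAT API.

import model_compression_toolkit as mct

qat_config = mct.qat.QATConfig(weight_training_method=mct.qat.TrainingMethod.LSQ,
                               activation_training_method=mct.qat.TrainingMethod.LSQ)

qat_model, quantization_info, custom_objects = \
    mct.qat.keras_quantization_aware_training_init(model,                    # float Keras model (placeholder)
                                                   representative_data_gen,  # calibration data generator (placeholder)
                                                   qat_config=qat_config,
                                                   target_platform_capabilities=tpc)  # placeholder TPC

# ...fine-tune qat_model with a regular Keras training loop, then freeze the quantizers:
final_model = mct.qat.keras_quantization_aware_training_finalize(qat_model)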