From 104445edd2a9389d44fa04ea9f86685126b7f3bb Mon Sep 17 00:00:00 2001 From: Ariel Lapid <57916763+lapid92@users.noreply.github.com> Date: Wed, 6 Nov 2024 16:49:04 +0200 Subject: [PATCH] Add feature Activation Bias Correction (#1256) * Add feature Activation Bias Correction --- .../core/common/framework_implementation.py | 73 ++++--- .../quantization/node_quantization_config.py | 2 + .../quantization/quantization_config.py | 2 + ...ply_activation_bias_correction_to_graph.py | 81 ++++++++ ...ute_activation_bias_correction_of_graph.py | 190 ++++++++++++++++++ .../statistics_correction.py | 16 +- .../core/keras/keras_implementation.py | 25 ++- ...ute_activation_bias_correction_of_graph.py | 67 ++++++ .../core/pytorch/pytorch_implementation.py | 21 ++ ...ute_activation_bias_correction_of_graph.py | 57 ++++++ model_compression_toolkit/core/runner.py | 8 + .../activation_bias_correction_test.py | 91 +++++++++ .../test_features_runner.py | 83 +++++--- .../activation_bias_correction_test.py | 132 ++++++++++++ .../model_tests/test_feature_models_runner.py | 38 +++- 15 files changed, 828 insertions(+), 58 deletions(-) create mode 100644 model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py create mode 100644 model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py create mode 100644 model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py create mode 100644 model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py create mode 100644 tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py create mode 100644 tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index 76ddd917b..ad7758d46 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self, Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover @abstractmethod @@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray: Returns: Numpy array converted from the input tensor. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s to_numpy method.') # pragma: no cover @abstractmethod @@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any: Returns: Framework's tensor converted from the input Numpy array. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s to_tensor method.') # pragma: no cover @abstractmethod @@ -106,7 +106,7 @@ def model_reader(self, Returns: Graph representing the input model. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s model_reader method.') # pragma: no cover @abstractmethod @@ -131,7 +131,7 @@ def model_builder(self, Returns: A tuple with the model and additional relevant supporting objects. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s model_builder method.') # pragma: no cover @abstractmethod @@ -148,7 +148,7 @@ def run_model_inference(self, Returns: The frameworks model's output. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s run_model_inference method.') # pragma: no cover @abstractmethod @@ -167,9 +167,28 @@ def shift_negative_correction(self, Returns: Graph after SNC. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s apply_shift_negative_correction method.') # pragma: no cover + @abstractmethod + def compute_activation_bias_correction(self, + graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo) -> Graph: + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' + f'framework\'s compute_activation_bias_correction method.') # pragma: no cover + @abstractmethod def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, @@ -184,7 +203,7 @@ def get_substitutions_channel_equalization(self, Returns: A list of the framework substitutions used after we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover @abstractmethod @@ -194,7 +213,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List Returns: A list of the framework substitutions used to prepare the graph. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover @abstractmethod @@ -208,7 +227,7 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization Returns: A list of the framework substitutions used before we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover @abstractmethod @@ -216,7 +235,7 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution: """ Returns: linear collapsing substitution """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover @abstractmethod @@ -224,7 +243,7 @@ def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: """ Returns: conv2d add const collapsing substitution """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover @abstractmethod @@ -239,7 +258,7 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf Returns: A list of the framework substitutions used for statistics correction. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover @abstractmethod @@ -247,7 +266,7 @@ def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]: """ Returns: A list of the framework substitutions used for residual collapsing """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover @@ -263,7 +282,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio Returns: A list of the framework substitutions used after we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover @abstractmethod @@ -272,7 +291,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_virtual_weights_activation_coupling ' f'method.') # pragma: no cover @@ -288,7 +307,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz Returns: A list of the framework substitutions used after we apply second moment statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_after_second_moment_correction ' f'method.') # pragma: no cover @@ -316,7 +335,7 @@ def get_sensitivity_evaluator(self, A function that computes the metric. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover def get_node_prior_info(self, node: BaseNode, @@ -334,7 +353,7 @@ def get_node_prior_info(self, node: BaseNode, NodePriorInfo with information about the node. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_node_prior_info method.') # pragma: no cover def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool: @@ -345,7 +364,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool Returns: True if the node should be considered an interest point, False otherwise. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover def get_mp_node_distance_fn(self, n: BaseNode, @@ -364,7 +383,7 @@ def get_mp_node_distance_fn(self, n: BaseNode, Returns: A distance function between two tensors and a axis on which the distance is computed (if exists). """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover @@ -381,7 +400,7 @@ def is_output_node_compatible_for_hessian_score_computation(self, """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover @abstractmethod @@ -398,7 +417,7 @@ def get_node_mac_operations(self, Returns: The MAC count of the operation """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_node_mac_operations method.') # pragma: no cover @abstractmethod @@ -419,7 +438,7 @@ def apply_second_moment_correction(self, Returns: A Graph after second moment correction. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s apply_second_moment_correction method.') # pragma: no cover @abstractmethod @@ -436,7 +455,7 @@ def sensitivity_eval_inference(self, Returns: The output of the model inference on the given input. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s sensitivity_eval_inference method.') # pragma: no cover def get_inferable_quantizers(self, node: BaseNode): @@ -452,9 +471,9 @@ def get_inferable_quantizers(self, node: BaseNode): """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_inferable_quantizers method.') # pragma: no cover - + @staticmethod def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int): """ diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index 0c68fb7b4..a790cbc77 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -95,7 +95,9 @@ def __init__(self, self.activation_error_method = qc.activation_error_method self.activation_n_bits = op_cfg.activation_n_bits self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2 + self.activation_bias_correction_term = None self.enable_activation_quantization = op_cfg.enable_activation_quantization + self.quantization_preserving = op_cfg.quantization_preserving self.signedness = op_cfg.signedness self.activation_channel_equalization = qc.activation_channel_equalization self.input_scaling = qc.input_scaling diff --git a/model_compression_toolkit/core/common/quantization/quantization_config.py b/model_compression_toolkit/core/common/quantization/quantization_config.py index 8af7ee658..cf3a39976 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/quantization_config.py @@ -84,6 +84,8 @@ class QuantizationConfig: shift_negative_threshold_recalculation: bool = False shift_negative_params_search: bool = False concat_threshold_update: bool = False + activation_bias_correction: bool = False + activation_bias_correction_threshold: float = 0.0 # Default quantization configuration the library use. diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py new file mode 100644 index 000000000..293e3dcce --- /dev/null +++ b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py @@ -0,0 +1,81 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from model_compression_toolkit.core import CoreConfig, QuantizationConfig +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig +from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig + + +def apply_activation_bias_correction_to_graph(graph: Graph, + core_config: CoreConfig, + fw_impl: FrameworkImplementation) -> Graph: + """ + Get a graph, where each node has a final activation quantization configuration (with an activation bias + correction term in it), and apply the activation bias correction for each node in the graph. + + Args: + graph: Graph to apply activation bias correction to. + core_config: CoreConfig containing parameters of how the model should be quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + + Returns: + Graph with activation bias correction apply to it's nodes. + """ + + for n in graph.nodes: + # Activation bias correction is only relevant for nodes with kernel op + kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0] + if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \ + n.final_activation_quantization_cfg.activation_bias_correction_term is not None: + # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was + # calculated during model preparation, and is used now in the node's bias term. + _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config) + return graph + + +def _apply_activation_bias_correction_to_node(node: BaseNode, + fw_impl: FrameworkImplementation, + qc: QuantizationConfig): + """ + Set new bias to node using the activation bias correction term that is stored in the + final activation quantization configuration. + + Args: + node: Node to set its corrected bias after activation bias correction. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + qc: QuantizationConfig containing parameters of how the model should be quantized. + + """ + correction = node.final_activation_quantization_cfg.activation_bias_correction_term + bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights + + if bias is None: + # If the layer has no bias, we set the bias as -correction. + node.set_weights_by_keys(fw_impl.constants.BIAS, - correction) + + # Mark the use_bias attribute of the node. + node.framework_attr[fw_impl.constants.USE_BIAS] = True + + # Configure the quantization of the bias as disabled. + node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, + WeightsAttrQuantizationConfig( + qc, + AttributeQuantizationConfig( + enable_weights_quantization=False))) + else: + # If the layer has bias, we subtract the correction from original bias + node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py new file mode 100644 index 000000000..bf753709b --- /dev/null +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -0,0 +1,190 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +from typing import Any, Callable + +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo + + +def get_previous_node_with_activation_quantization(linear_node: BaseNode, + graph: Graph) -> Any: + """ + Search recursively for the previous node with activation quantization. + + Args: + linear_node: Node to search for its previous node. + graph: Graph the node is in. + + Returns: + The previous node (if found) or None if it was not found or there are multiple incoming edges to one of + nodes during the search (which means, the substitution can not be applied). + """ + + prev_nodes = graph.get_prev_nodes(linear_node) + + if len(prev_nodes) != 1: + return None # pragma: no cover + + prev_node = prev_nodes[0] + + activation_quantization_config = prev_node.final_activation_quantization_cfg + + # Search for node with activation quantization + if (activation_quantization_config.enable_activation_quantization and + not activation_quantization_config.quantization_preserving): + return prev_node + else: + return get_previous_node_with_activation_quantization(prev_node, graph) + + +def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray: + """ + Calculate the centers of bins given their edges. + + Args: + bin_edges: Array of bin edges. + + Returns: + np.ndarray: Array of bin centers. + """ + # Calculate the centers by averaging continuous bin edges + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 + return bin_centers + + +def compute_activation_bias_correction(graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation, + linear_node: BaseNode, + prev_node: BaseNode, + kernel_size: str) -> Graph: + """ + Compute the activation bias correction term, and store it in the final activation + quantization configuration. + + Args: + graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + linear_node: Node to compute the activation bias correction for. + prev_node: Node to compute the activation error caused by his activation quantization. + kernel_size: The framework specific attribute name of the convolution layer's kernel size. + + Returns: + Graph with activation bias correction term for each node. + """ + + # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). This feature supports only + # Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1). + # For Dense/Linear layers, which lack a 'kernel_size' attribute, the result will be None, and no restriction + # applies in that case. + if linear_node.framework_attr.get(kernel_size) not in [None, 1, (1, 1)]: + # If the kernel size is not 1 or (1, 1), return the current graph unmodified + return graph + + prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg + + # Check if the previous node's has activation quantization configuration and if the previous node have the + # histogram collector. + if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'): + return graph # pragma: no cover + + float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram() + + # Calculate the centers of the float bins + float_centers = calculate_bin_centers(float_bins) + + # Quantize the bin edges and calculate the centers of the quantized bins + quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins)) + quant_bins = fw_impl.to_numpy(quant_bins) + quant_centers = calculate_bin_centers(quant_bins) + + # Calculate the mean of the both the float and the quantized bin centers, weighted by the bin counts + mean_float_centers = np.sum(float_centers * float_count) / np.sum(float_count) + mean_quant_centers = np.sum(quant_centers * float_count) / np.sum(float_count) + + # Compute the difference between the mean quantized center and the mean float center + mean_diff = mean_quant_centers - mean_float_centers + + # Calculate the normalized bias as a percentage of the float center norm + float_centers_norm1 = np.abs(mean_float_centers) + normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1 + + # If the normalized bias is below the activation bias correction threshold, return the graph unmodified. + # By default, the threshold is set to 0, allowing all nodes to proceed in this case. + if normalized_bias < quant_config.activation_bias_correction_threshold: + return graph + + kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0]) + + # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output + # size matching the number of output channels. + if kernel is not None: + + # Get the axes that are not the output channel. + output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type) + axis_not_output_channel = list(range(len(kernel.shape))) + axis_not_output_channel.remove(output_channel_index) + + # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters. + if output_channel_index == input_channel_index: + axis_not_output_channel.remove(3) # 3 is the depth multiplier index. + + activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel)) + linear_node.final_activation_quantization_cfg.activation_bias_correction_term = ( + activation_bias_correction_term.flatten()) + return graph + + +def compute_activation_bias_correction_of_graph(graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation, + activation_bias_correction_node_matchers: Callable, + kernel_size: str) -> Graph: + """ + Compute the activation bias correction term for the graph. + + Args: + graph: Graph with nodes to compute the activation bias correction. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + activation_bias_correction_node_matchers: Function to match the layers for activation bias correction. + kernel_size: The framework specific attribute name of the convolution layer's kernel size. + + + Returns: + Graph with activation bias correction term for each relevant node. + """ + linear_node_types = activation_bias_correction_node_matchers() + + for n in graph.nodes: + if linear_node_types.apply(n): + prev_node = get_previous_node_with_activation_quantization(n, graph) + if prev_node is not None: + graph = compute_activation_bias_correction(graph=graph, + quant_config=quant_config, + fw_info=fw_info, + fw_impl=fw_impl, + linear_node=n, + prev_node=prev_node, + kernel_size=kernel_size) + return graph diff --git a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py index e11aa0290..b07f3f3d1 100644 --- a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +++ b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py @@ -18,6 +18,8 @@ from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.quantization.core_config import CoreConfig +from model_compression_toolkit.core.common.statistics_correction.apply_activation_bias_correction_to_graph import \ + apply_activation_bias_correction_to_graph from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \ apply_bias_correction_to_graph from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \ @@ -73,7 +75,7 @@ def apply_statistics_correction(transformed_graph: Graph, fw_impl: FrameworkImplementation, tb_w: TensorboardWriter = None, ) -> Graph: """ - Apply statistics moment correction on graph. + Apply statistics correction on graph. Args: transformed_graph: Graph to apply statistics correction. representative_data_gen (Callable): Dataset used for calibration. @@ -84,7 +86,7 @@ def apply_statistics_correction(transformed_graph: Graph, tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc. Returns: - Graph after statistics correction correction. + Graph after statistics correction. """ ############################################# @@ -104,4 +106,14 @@ def apply_statistics_correction(transformed_graph: Graph, if tb_w is not None: tb_w.add_graph(transformed_graph, 'after_statistics_correction') + ############################################# + # Apply Activation Bias Correction + ############################################# + if core_config.quantization_config.activation_bias_correction: + transformed_graph = apply_activation_bias_correction_to_graph(graph=transformed_graph, + core_config=core_config, + fw_impl=fw_impl) + if tb_w is not None: + tb_w.add_graph(transformed_graph, 'after_activation_bias_correction') + return transformed_graph diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 6020b137f..85fab7637 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -28,6 +28,8 @@ from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \ ActivationHessianScoresCalculatorKeras from model_compression_toolkit.core.keras.hessian.weights_hessian_scores_calculator_keras import WeightsHessianScoresCalculatorKeras +from model_compression_toolkit.core.keras.statistics_correction.keras_compute_activation_bias_correction_of_graph import \ + keras_compute_activation_bias_correction_of_graph from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ get_inferable_quantizers from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \ @@ -84,7 +86,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \ keras_residual_collapsing from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \ - InputScalingWithPad + InputScalingWithPad from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \ ReLUBoundToPowerOfTwo @@ -218,6 +220,25 @@ def shift_negative_correction(self, core_config, fw_info) + def compute_activation_bias_correction(self, + graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo): + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + return keras_compute_activation_bias_correction_of_graph(graph=graph, + quant_config=quant_config, + fw_info=fw_info, + fw_impl=self) def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, @@ -309,7 +330,7 @@ def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: """ return keras_op2d_add_const_collapsing() - def get_substitutions_post_statistics_collection(self, + def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[common.BaseSubstitution]: """ Return a list of the framework substitutions used after we collect statistics. diff --git a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py new file mode 100644 index 000000000..ce1ac9c23 --- /dev/null +++ b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py @@ -0,0 +1,67 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +from packaging import version + +from model_compression_toolkit.core.keras.constants import KERNEL_SIZE + +if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose +else: + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose + +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common import Graph +from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher +from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \ + compute_activation_bias_correction_of_graph + + +def activation_bias_correction_node_matchers(): + # Match linear layers where we can add a correction. + linear_node = NodeOperationMatcher(Conv2D) | \ + NodeOperationMatcher(Dense) | \ + NodeOperationMatcher(DepthwiseConv2D) | \ + NodeOperationMatcher(Conv2DTranspose) + return linear_node + + +def keras_compute_activation_bias_correction_of_graph(graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation) -> Graph: + """ + Compute the activation bias correction term for graph based on a Keras model. + + Args: + graph: Graph with nodes to compute the activation bias correction. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + + Returns: + Graph with activation bias correction term for each relevant node. + """ + graph = compute_activation_bias_correction_of_graph(graph=graph, + quant_config=quant_config, + fw_info=fw_info, + fw_impl=fw_impl, + activation_bias_correction_node_matchers= + activation_bias_correction_node_matchers, + kernel_size=KERNEL_SIZE) + return graph diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 38e03dd08..5ec26a66d 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -92,6 +92,8 @@ from model_compression_toolkit.core.pytorch.reader.reader import model_reader from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \ pytorch_apply_second_moment_correction +from model_compression_toolkit.core.pytorch.statistics_correction.pytorch_compute_activation_bias_correction_of_graph import \ + pytorch_compute_activation_bias_correction_of_graph from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ get_inferable_quantizers @@ -212,6 +214,25 @@ def shift_negative_correction(self, core_config, fw_info) + def compute_activation_bias_correction(self, + graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo): + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + return pytorch_compute_activation_bias_correction_of_graph(graph=graph, + quant_config=quant_config, + fw_info=fw_info, + fw_impl=self) def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, diff --git a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py new file mode 100644 index 000000000..149050cf2 --- /dev/null +++ b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py @@ -0,0 +1,57 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from torch.nn import Conv2d, Linear, ConvTranspose2d + +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.common import Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher +from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \ + compute_activation_bias_correction_of_graph +from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE + + +def activation_bias_correction_node_matchers(): + # Match linear layers where we can add a correction. + linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) | NodeOperationMatcher(ConvTranspose2d) + return linear_node + + +def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, + quant_config: QuantizationConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation) -> Graph: + """ + Compute the activation bias correction term for graph based on a PyTorch model. + + Args: + graph: Graph with nodes to compute the activation bias correction. + quant_config: QuantizationConfig of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + + Returns: + Graph with activation bias correction term for each relevant node. + """ + graph = compute_activation_bias_correction_of_graph(graph=graph, + quant_config=quant_config, + fw_info=fw_info, + fw_impl=fw_impl, + activation_bias_correction_node_matchers= + activation_bias_correction_node_matchers, + kernel_size=KERNEL_SIZE) + return graph diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 66b25f080..65ef60176 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -164,6 +164,14 @@ def core_runner(in_model: Any, tg, bit_widths_config) + ###################################### + # Compute Activation Bias Correction + ###################################### + if core_config.quantization_config.activation_bias_correction: + tg = fw_impl.compute_activation_bias_correction(graph=tg, + quant_config=core_config.quantization_config, + fw_info=fw_info) + # Edit the graph again after finalizing the configurations. # This is since some actions regard the final configuration and should be edited. edit_network_graph(tg, fw_info, core_config.debug_config.network_editor) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py new file mode 100644 index 000000000..e48e828ed --- /dev/null +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -0,0 +1,91 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.keras.constants import KERNEL_SIZE +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest + +import tensorflow as tf +import numpy as np + +from tests.keras_tests.utils import get_layers_from_model_by_type + +keras = tf.keras +layers = keras.layers + +""" +This test checks the Activation Bias Correction feature. +""" + + +class BaseActivationBiasCorrectionTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature. + """ + + def __init__(self, unit_test, + prev_layer, + bypass_layer_list, + linear_layer, + activation_bias_correction_threshold=0.0): + super().__init__(unit_test) + self.prev_layer = prev_layer + self.bypass_layer_list = bypass_layer_list + self.linear_layer = linear_layer + self.activation_bias_correction_threshold = activation_bias_correction_threshold + + def get_quantization_config(self): + return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + shift_negative_activation_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=self.activation_bias_correction_threshold) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = self.prev_layer(inputs) + + for bypass_layer in self.bypass_layer_list: + x = bypass_layer(x) + + outputs = self.linear_layer(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_linear_layers = get_layers_from_model_by_type(float_model, type(self.linear_layer)) + quantized_linear_layers = get_layers_from_model_by_type(quantized_model, type(self.linear_layer)) + + bias = float_linear_layers[-1].bias + bias_after_activation_bias_correction = quantized_linear_layers[-1].layer.bias + + y = float_model.predict(input_x) + y_hat = quantized_model.predict(input_x) + + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + + if getattr(float_linear_layers[-1], KERNEL_SIZE, None) in [None, 1, (1, 1)]: + if self.activation_bias_correction_threshold > 1e8: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias " + f"value in case of activation_bias_correction_threshold " + f"{self.activation_bias_correction_threshold}.") + + else: + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias " + f"value.") + else: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value " + f"in case of conv with kernel different than 1 or (1, 1).") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 45c8cef3e..59336b057 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -28,6 +28,8 @@ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.gptq import RoundingType from model_compression_toolkit.target_platform_capabilities import constants as C +from tests.keras_tests.feature_networks_tests.feature_networks.activation_bias_correction_test import \ + BaseActivationBiasCorrectionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \ ActivationDecompositionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_relu_bound_to_power_of_2_test import \ @@ -63,7 +65,8 @@ RequiresMixedPrecision, RequiresMixedPrecisionWeights from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \ MixedPrecisionBopsBasicTest, MixedPrecisionBopsAllWeightsLayersTest, MixedPrecisionWeightsOnlyBopsTest, \ - MixedPrecisionActivationOnlyBopsTest, MixedPrecisionBopsAndWeightsUtilizationTest, MixedPrecisionBopsAndActivationUtilizationTest, \ + MixedPrecisionActivationOnlyBopsTest, MixedPrecisionBopsAndWeightsUtilizationTest, \ + MixedPrecisionBopsAndActivationUtilizationTest, \ MixedPrecisionBopsAndTotalUtilizationTest, MixedPrecisionBopsWeightsActivationUtilizationTest, \ MixedPrecisionBopsMultipleOutEdgesTest from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_tests import \ @@ -140,7 +143,8 @@ MixedPrecisionSearchPartWeightsLayersTest, MixedPrecisionDepthwiseTest, MixedPrecisionSearchLastLayerDistanceTest, \ MixedPrecisionSearchActivationNonConfNodesTest, MixedPrecisionSearchTotalMemoryNonConfNodesTest, \ MixedPrecisionCombinedNMSTest -from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import \ + MatmulToDenseSubstitutionTest from tests.keras_tests.feature_networks_tests.feature_networks.metadata_test import MetadataTest from tests.keras_tests.feature_networks_tests.feature_networks.tpc_test import TpcTest from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \ @@ -150,8 +154,10 @@ AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest -from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import SigMulSubstitutionTest -from tests.keras_tests.feature_networks_tests.feature_networks.conv_func_substitutions_test import ConvFuncSubstitutionsTest +from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import \ + SigMulSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.conv_func_substitutions_test import \ + ConvFuncSubstitutionsTest from model_compression_toolkit.qat.common.qat_config import TrainingMethod layers = tf.keras.layers @@ -167,7 +173,7 @@ def test_remove_identity(self): def test_per_tensor_weight_quantization(self): PerTensorWeightQuantizationTest(self).run_test() - + def test_single_relu_replacement(self): SingleReluReplacementTest(self).run_test() @@ -240,7 +246,7 @@ def test_mixed_precision_weights_only_activation_conf(self): def test_requires_mixed_recision(self): RequiresMixedPrecisionWeights(self, weights_memory=True).run_test() - RequiresMixedPrecision(self,activation_memory=True).run_test() + RequiresMixedPrecision(self, activation_memory=True).run_test() RequiresMixedPrecision(self, total_memory=True).run_test() RequiresMixedPrecision(self, bops=True).run_test() RequiresMixedPrecision(self).run_test() @@ -516,6 +522,27 @@ def test_conv2d_bn_concat(self): def test_activation_scaling_relu6(self): ReLUBoundToPOTNetTest(self).run_test() + def test_activation_bias_correction(self): + for use_bias in [False, True]: + for linear_layer in [layers.Dense(30, use_bias=use_bias), + layers.Conv2D(filters=3, kernel_size=1, use_bias=use_bias), + layers.DepthwiseConv2D(kernel_size=1, use_bias=use_bias, + bias_initializer='glorot_uniform', depth_multiplier=1), + layers.Conv2DTranspose(filters=3, kernel_size=1, use_bias=use_bias), + layers.Conv2D(filters=3, kernel_size=2, use_bias=use_bias)]: + for activation_bias_correction_threshold in [0.0, 1e-6, 1e9]: + for activation in ['gelu', 'swish']: + for bypass_layer in [[layers.ZeroPadding2D(2)], + [layers.Dropout(0.5), layers.Dropout(0.5)], + [layers.MaxPooling2D(pool_size=(2, 2), strides=2)], + [layers.Flatten(), layers.Reshape(target_shape=(8, 8, 3))]]: + BaseActivationBiasCorrectionTest(self, + prev_layer=layers.Activation(activation), + bypass_layer_list=bypass_layer, + linear_layer=linear_layer, + activation_bias_correction_threshold= + activation_bias_correction_threshold).run_test() + def test_layer_activation_softmax_shift(self): SoftmaxShiftTest(self, layers.Dense(20, activation='softmax'), None).run_test() @@ -535,9 +562,12 @@ def test_conv2dbn_folding(self): Conv2DBNFoldingTest(self).run_test() def test_bn_forward_folding(self): - BNForwardFoldingTest(self, layers.Conv2D(2, 1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() - BNForwardFoldingTest(self, layers.DepthwiseConv2D(1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() - BNForwardFoldingTest(self, layers.Conv2DTranspose(2, 1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.Conv2D(2, 1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.DepthwiseConv2D(1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.Conv2DTranspose(2, 1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.Conv2D(2, 2), False, is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.DepthwiseConv2D((3, 1)), False, is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.Conv2DTranspose(2, (1, 3)), False, is_dwconv=True).run_test() @@ -582,10 +612,14 @@ def test_const_quantization(self): for qmethod in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC, QuantizationMethod.UNIFORM]: for error_method in [QuantizationErrorMethod.MSE, QuantizationErrorMethod.NOCLIPPING]: ConstQuantizationTest(self, func, c, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, + error_method=error_method).run_test() AdvancedConstQuantizationTest(self).run_test() ConstQuantizationMultiInputTest(self).run_test() @@ -607,7 +641,8 @@ def test_const_representation(self): for func in [layers.Add(), layers.Multiply(), layers.Subtract()]: ConstRepresentationTest(self, func, c, is_list_input=True).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True, is_list_input=True).run_test() - ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, is_list_input=True).run_test() + ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, + is_list_input=True).run_test() ConstRepresentationTest(self, func, c, use_kwargs=True, is_list_input=True).run_test() ConstRepresentationMultiInputTest(self).run_test() @@ -697,7 +732,6 @@ def test_gptq(self): # GradientPTQLearnRateZeroConvGroupTest(self).run_test() # GradientPTQWeightsUpdateConvGroupTest(self).run_test() - def test_gptq_conv_group_dilation(self): GradientPTQLearnRateZeroConvGroupDilationTest(self).run_test() GradientPTQWeightsUpdateConvGroupDilationTest(self).run_test() @@ -794,7 +828,8 @@ def test_qat(self): QuantizationAwareTrainingQuantizersTest(self).run_test() QuantizationAwareTrainingQuantizerHolderTest(self).run_test() QATWrappersMixedPrecisionCfgTest(self).run_test() - QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, expected_mp_cfg=[0, 4, 1, 1]).run_test() + QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, + expected_mp_cfg=[0, 4, 1, 1]).run_test() def test_bn_attributes_quantization(self): BNAttributesQuantization(self, quantize_linear=False).run_test() @@ -883,8 +918,8 @@ def test_exceptions_manual_selection(self): # Invalid inputs to API with self.assertRaises(Exception) as context: ManualBitWidthSelectionTest(self, - [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')], - [2, 4]).run_test() + [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')], + [2, 4]).run_test() # Check that the correct exception message was raised self.assertEqual(str(context.exception), "Configuration Error: The number of provided bit_width values 2 must match the number of filters 3, or a single bit_width value should be provided for all filters.") @@ -899,15 +934,15 @@ def test_manual_bit_width_selection(self): ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 4).run_test() ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 2).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], - [2, 4]).run_test() + [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], - [4, 4]).run_test() + [4, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Add)], - [2, 4]).run_test() + [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Conv2D)], - [4, 4]).run_test() + [4, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Dense)], - 4).run_test() + 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('input'), 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('conv1'), 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('fc'), 4).run_test() @@ -916,7 +951,7 @@ def test_manual_bit_width_selection(self): ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 4).run_test() ManualBitWidthSelectionTest(self, [NodeNameFilter('add1'), NodeNameFilter('conv1')], [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeNameFilter('add2'), NodeNameFilter('relu1')], 4).run_test() - ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')],[4, 2]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')], [4, 2]).run_test() if __name__ == '__main__': diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py new file mode 100644 index 000000000..460aad783 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py @@ -0,0 +1,132 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import torch +from torch.nn import GELU, Hardswish, AdaptiveAvgPool2d, ZeroPad2d, Linear, Conv2d + +import model_compression_toolkit as mct +from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model +from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest + +""" +This test checks the Activation Bias Correction feature. +""" + + +class ActivationBiasCorrectionNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature. + """ + + def __init__(self, + prev_layer, + linear_layer, + bypass_layers): + super(ActivationBiasCorrectionNet, self).__init__() + self.activation_layer = prev_layer + self.linear_layer = linear_layer + self.bypass_layers = torch.nn.ModuleList(bypass_layers) + + def forward(self, x): + x = self.activation_layer(x) + + for bypass_layer in self.bypass_layers: + x = bypass_layer(x) + x = self.linear_layer(x) + return x + +class ActivationBiasCorrectionPadNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature with pooling/padding layers as a bypass layers. + """ + + def __init__(self): + super(ActivationBiasCorrectionPadNet, self).__init__() + self.activation_layer = Hardswish() + self.pooling_layer = AdaptiveAvgPool2d(output_size=6) + self.padding_layer = ZeroPad2d(padding=2) + self.linear_layer = Linear(10, 10) + + def forward(self, x): + x = self.activation_layer(x) + x = self.pooling_layer(x) + x = self.padding_layer(x) + x = self.linear_layer(x) + return x + + +class ActivationBiasCorrectionReshapeNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature with reshape layers as a bypass layers. + """ + + def __init__(self): + super(ActivationBiasCorrectionReshapeNet, self).__init__() + self.activation_layer = GELU() + self.linear_layer = Conv2d(in_channels=8, out_channels=1, kernel_size=1) + + def forward(self, x): + x = self.activation_layer(x) + x = x.flatten() + x = x.reshape(8, 2, -1) + x = self.linear_layer(x) + return x + + +class BaseActivationBiasCorrectionTest(BasePytorchFeatureNetworkTest): + def __init__(self, unit_test, + model, + activation_bias_correction_threshold=0.0): + super().__init__(unit_test) + self.model = model + self.activation_bias_correction_threshold = activation_bias_correction_threshold + + def get_quantization_config(self): + return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + shift_negative_activation_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold= + self.activation_bias_correction_threshold) + + def create_networks(self): + return self.model + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + bias = float_model.linear_layer.bias.cpu().detach().numpy() + bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() + + set_model(float_model) + y = float_model(to_torch_tensor(input_x[0])) + y_hat = quantized_model(to_torch_tensor(input_x[0])) + + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + + if getattr(float_model.linear_layer, KERNEL_SIZE, None) in [None, 1, (1, 1)]: + if self.activation_bias_correction_threshold > 1e8: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias " + f"value in case of activation_bias_correction_threshold " + f"{self.activation_bias_correction_threshold}.") + else: + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias " + f"value.") + else: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value " + f"in case of conv with kernel different than 1 or (1, 1).") diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 44f8193bb..b5c02f350 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -30,6 +30,9 @@ from model_compression_toolkit.trainable_infrastructure import TrainingMethod from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest +from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import ( + BaseActivationBiasCorrectionTest, ActivationBiasCorrectionNet, ActivationBiasCorrectionPadNet, + ActivationBiasCorrectionReshapeNet) from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest from tests.pytorch_tests.model_tests.feature_models.add_same_test import AddSameNetTest from tests.pytorch_tests.model_tests.feature_models.bn_attributes_quantization_test import BNAttributesQuantization @@ -302,13 +305,13 @@ def test_permute_substitution(self): This test checks the permute substitution feature """ PermuteSubstitutionTest(self).run_test() - + def test_reshape_substitution(self): """ This test checks the reshape substitution feature """ ReshapeSubstitutionTest(self).run_test() - + def test_constant_conv_substitution(self): """ This test checks the constant conv substitution feature @@ -438,6 +441,35 @@ def test_shift_negative_activation_net(self): for activation_layer in [torch.nn.Hardswish, torch.nn.GELU]: ShiftNegaviteActivationNetTest(self, activation_layer=activation_layer).run_test(seed=3) + def test_activation_bias_correction_net(self): + """ + This test checks the activation bias correction feature. + """ + model_list = [ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Linear(8, 20), + bypass_layers=[nn.Dropout(0.5), nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 20, 1), + bypass_layers=[nn.Dropout(0.5), nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 8, (1, 1)), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.Hardswish(), + linear_layer=nn.ConvTranspose2d(3, 8, 1), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 8, 3), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionPadNet(), + ActivationBiasCorrectionReshapeNet()] + + for activation_bias_correction_threshold in [0.0, 1e-6, 1e9]: + for model in model_list: + BaseActivationBiasCorrectionTest(self, + model=model, + activation_bias_correction_threshold= + activation_bias_correction_threshold).run_test() + def test_split_concat_net(self): """ This test checks: @@ -715,7 +747,7 @@ def test_bn_attributes_quantization(self): """ BNAttributesQuantization(self, quantize_linear=False).run_test() BNAttributesQuantization(self, quantize_linear=True).run_test() - + def test_concat_threshold_update(self): ConcatUpdateTest(self).run_test()