From 5905dfae382e6a817248ab169f2a43a958e396cc Mon Sep 17 00:00:00 2001 From: ariell Date: Mon, 4 Nov 2024 14:33:58 +0200 Subject: [PATCH 1/5] Add feture Activation Bias Correction --- .../core/common/framework_implementation.py | 21 +- .../quantization/quantization_config.py | 2 + ...ply_activation_bias_correction_to_graph.py | 80 +++++++ ...ute_activation_bias_correction_of_graph.py | 209 ++++++++++++++++++ .../statistics_correction.py | 16 +- .../core/keras/keras_implementation.py | 25 ++- ...ute_activation_bias_correction_of_graph.py | 88 ++++++++ .../core/pytorch/pytorch_implementation.py | 21 ++ ...ute_activation_bias_correction_of_graph.py | 71 ++++++ model_compression_toolkit/core/runner.py | 8 + .../activation_bias_correction_test.py | 161 ++++++++++++++ .../test_features_runner.py | 9 + .../activation_bias_correction_test.py | 159 +++++++++++++ .../model_tests/test_feature_models_runner.py | 24 +- 14 files changed, 886 insertions(+), 8 deletions(-) create mode 100644 model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py create mode 100644 model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py create mode 100644 model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py create mode 100644 model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py create mode 100644 tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py create mode 100644 tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index 76ddd917b..3e63c561b 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -170,6 +170,25 @@ def shift_negative_correction(self, raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' f'framework\'s apply_shift_negative_correction method.') # pragma: no cover + @abstractmethod + def compute_activation_bias_correction(self, + graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo) -> Graph: + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + core_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. 
+ """ + raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + f'framework\'s compute_activation_bias_correction method.') # pragma: no cover + @abstractmethod def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, @@ -454,7 +473,7 @@ def get_inferable_quantizers(self, node: BaseNode): raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' f'framework\'s get_inferable_quantizers method.') # pragma: no cover - + @staticmethod def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int): """ diff --git a/model_compression_toolkit/core/common/quantization/quantization_config.py b/model_compression_toolkit/core/common/quantization/quantization_config.py index 8af7ee658..a940c4504 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/quantization_config.py @@ -70,6 +70,8 @@ class QuantizationConfig: weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE relu_bound_to_power_of_2: bool = False weights_bias_correction: bool = True + activation_bias_correction: bool = False + activation_bias_correction_threshold: float = 0.0 weights_second_moment_correction: bool = False input_scaling: bool = False softmax_shift: bool = False diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py new file mode 100644 index 000000000..83fcc14ef --- /dev/null +++ b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py @@ -0,0 +1,80 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import copy + +from model_compression_toolkit.core import CoreConfig, QuantizationConfig +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig +from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig + + +def apply_activation_bias_correction_to_graph(graph: Graph, + core_config: CoreConfig, + fw_impl: FrameworkImplementation) -> Graph: + """ + Get a graph, where each node has a final activation quantization configuration (with a activation bias + correction term in it), and apply the activation bias correction for each node in the graph. + + Args: + graph: Graph to apply activation bias correction to. + core_config: CoreConfig containing parameters of how the model should be quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. 
+ + Returns: + Graph with activation bias correction apply to it's nodes. + """ + + graph = copy.deepcopy(graph) + for n in graph.nodes: + # Activation bias correction is only relevant for nodes with kernel op + kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0] + if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \ + hasattr(n.final_activation_quantization_cfg, 'activation_bias_correction_term'): + # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was + # calculated during model preparation, and is used now in the node's bias term. + _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config) + return graph + + +def _apply_activation_bias_correction_to_node(node: BaseNode, + fw_impl: FrameworkImplementation, + qc: QuantizationConfig): + """ + Set new bias to node using the activation bias correction term that is stored in the + final activation quantization configuration. + + Args: + node: Node to set its corrected bias after activation bias correction. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + qc: QuantizationConfig containing parameters of how the model should be quantized. + + """ + correction = node.final_activation_quantization_cfg.activation_bias_correction_term + bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights + + if bias is not None: # If the layer has bias, we subtract the correction from original bias + node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) + else: + # If the layer has no bias, we consider it as if it has and its value is 0 and add a "dummy" attribute + # configuration with disabled quantization. + node.set_weights_by_keys(fw_impl.constants.BIAS, - correction) + node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node. + node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, + WeightsAttrQuantizationConfig( + qc, + AttributeQuantizationConfig( + enable_weights_quantization=False))) diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py new file mode 100644 index 000000000..e8ad63679 --- /dev/null +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -0,0 +1,209 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from typing import List, Tuple, Any, Callable + +import numpy as np + +from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher + + +def get_next_nodes_to_correct(node: BaseNode, + graph: Graph, + linear_node_types: NodeOperationMatcher, + bypass_node_types: NodeOperationMatcher, + bypass_nodes: List = None) -> Tuple[Any, Any]: + """ + Search for the previous node which is not a bypass node of a given node. Go over the previous nodes of the node + and recursively search for a node. + + Args: + node: Node to search for its previous node. + graph: Graph the node is in. + linear_node_types: Types of linear nodes to consider. + bypass_node_types: Types of nodes for bypassing to consider. + bypass_nodes: a list of bypass nodes found while running this function + + Returns: The previous node (if found) and a list of bypass nodes (if any), or Nones if it were not found or there + are multiple incoming edges to one of nodes during the search (which means, the substitution can not be applied). + """ + + prev_nodes = graph.get_prev_nodes(node) + + if len(prev_nodes) != 1: + return None, None + + prev_node = prev_nodes[0] + + # If the previous node is not a bypass type, return it as the valid node along with any bypass nodes + if not bypass_node_types.apply(prev_node): + return prev_node, bypass_nodes + + # If the previous node is a bypass node type, add it to the bypass_nodes list and continue searching + if bypass_node_types.apply(prev_node): + if bypass_nodes: + bypass_nodes.append(prev_node) + else: + bypass_nodes = [prev_node] + return get_next_nodes_to_correct(node=prev_node, + graph=graph, + linear_node_types=linear_node_types, + bypass_node_types=bypass_node_types, + bypass_nodes=bypass_nodes) + return None, None + + +def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray: + """ + Calculate the centers of bins given their edges. + + Parameters: + bin_edges: Array of bin edges. + + Returns: + np.ndarray: Array of bin centers. + """ + # Ensure bin_edges is a numpy array + bin_edges = np.array(bin_edges, dtype=np.float32) + + # Calculate the centers by averaging continuous bin edges + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 + return bin_centers + + +def compute_activation_bias_correction(graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation, + linear_node: BaseNode, + prev_node: BaseNode, + kernel_size: str) -> Graph: + """ + Compute the activation bias correction term, and store it in the final activation + quantization configuration. + + Args: + graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration. + core_config: Configuration object containing parameters of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + linear_node: Node to compute the activation bias correction for. + prev_node: Node to compute the activation error caused by his activation quantization. 
+ kernel_size: The framework specific attribute name of the convolution layer's kernel size. + + Returns: + Graph with activation bias correction term for each node. + """ + + # Check if 'kernel_size' is a key in the framework-specific attributes of the linear_node, if it is then the + # linear_node is a convolution + if kernel_size in linear_node.framework_attr.keys(): + # Retrieve the value of 'kernel_size' and check if it is not 1 or (1, 1). This feature supports only kernel + # size of 1 or (1, 1). + if linear_node.framework_attr.get(kernel_size) not in [1, (1, 1)]: + # If the kernel size is not 1 or (1, 1), return the current graph unmodified + return graph + + prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg + + # Check if the previous node's has activation quantization configuration and if the previous node have the + # histogram collector + if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'): + return graph + + float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram() + + # Calculate the centers of the float bins + float_centers = calculate_bin_centers(float_bins) + + # Quantize the bin edges and calculate the centers of the quantized bins + quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins)) + quant_bins = fw_impl.to_numpy(quant_bins) + quant_centers = calculate_bin_centers(quant_bins) + + # Calculate the mean of the both the float and the quantized bin centers, weighted by the bin counts + mean_float_centers = np.sum(float_centers * float_count) / np.sum(float_count) + mean_quant_centers = np.sum(quant_centers * float_count) / np.sum(float_count) + + # Compute the difference between the mean quantized center and the mean float center + mean_diff = mean_quant_centers - mean_float_centers + + # Check if activation bias correction is enabled based on the configured threshold + if core_config.quantization_config.activation_bias_correction_threshold > 0: + + # Calculate the normalized bias as a percentage of the float center norm + float_centers_norm1 = np.abs(mean_float_centers) + normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1 + + # If the normalized bias is below the activation bias correction threshold, return the unmodified graph + if normalized_bias < core_config.quantization_config.activation_bias_correction_threshold: + return graph + + # The correction term is a function of the layer type. + kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0]) + + if kernel is not None: + output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type) + axis_not_output_channel = list(range(len(kernel.shape))) + axis_not_output_channel.remove(output_channel_index) + + if output_channel_index == input_channel_index: + axis_not_output_channel.remove(3) # 3 is the depth multiplier index + + activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel)) + linear_node.final_activation_quantization_cfg.activation_bias_correction_term = activation_bias_correction_term.flatten() + return graph + + +def compute_activation_bias_correction_of_graph(graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation, + activation_bias_correction_node_matchers: Callable, + kernel_size: str) -> Graph: + """ + Compute the activation bias correction term for the graph. 
+ + Args: + graph: Graph with nodes to compute the activation bias correction. + core_config: Configuration object containing parameters of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + activation_bias_correction_node_matchers: Function to match the layers for activation bias correction. + kernel_size: The framework specific attribute name of the convolution layer's kernel size. + + + Returns: + Graph with activation bias correction term for each relevant node. + """ + linear_node_types, bypass_node_types = activation_bias_correction_node_matchers() + + for n in graph.nodes: + if linear_node_types.apply(n): + prev_node, _ = get_next_nodes_to_correct(node=n, + graph=graph, + linear_node_types=linear_node_types, + bypass_node_types=bypass_node_types) + graph = compute_activation_bias_correction(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=fw_impl, + linear_node=n, + prev_node=prev_node, + kernel_size=kernel_size) + return graph diff --git a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py index e11aa0290..b07f3f3d1 100644 --- a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +++ b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py @@ -18,6 +18,8 @@ from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.quantization.core_config import CoreConfig +from model_compression_toolkit.core.common.statistics_correction.apply_activation_bias_correction_to_graph import \ + apply_activation_bias_correction_to_graph from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \ apply_bias_correction_to_graph from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \ @@ -73,7 +75,7 @@ def apply_statistics_correction(transformed_graph: Graph, fw_impl: FrameworkImplementation, tb_w: TensorboardWriter = None, ) -> Graph: """ - Apply statistics moment correction on graph. + Apply statistics correction on graph. Args: transformed_graph: Graph to apply statistics correction. representative_data_gen (Callable): Dataset used for calibration. @@ -84,7 +86,7 @@ def apply_statistics_correction(transformed_graph: Graph, tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc. Returns: - Graph after statistics correction correction. + Graph after statistics correction. 
""" ############################################# @@ -104,4 +106,14 @@ def apply_statistics_correction(transformed_graph: Graph, if tb_w is not None: tb_w.add_graph(transformed_graph, 'after_statistics_correction') + ############################################# + # Apply Activation Bias Correction + ############################################# + if core_config.quantization_config.activation_bias_correction: + transformed_graph = apply_activation_bias_correction_to_graph(graph=transformed_graph, + core_config=core_config, + fw_impl=fw_impl) + if tb_w is not None: + tb_w.add_graph(transformed_graph, 'after_activation_bias_correction') + return transformed_graph diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 6020b137f..024c0bd54 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -28,6 +28,8 @@ from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \ ActivationHessianScoresCalculatorKeras from model_compression_toolkit.core.keras.hessian.weights_hessian_scores_calculator_keras import WeightsHessianScoresCalculatorKeras +from model_compression_toolkit.core.keras.statistics_correction.keras_compute_activation_bias_correction_of_graph import \ + keras_compute_activation_bias_correction_of_graph from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ get_inferable_quantizers from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \ @@ -84,7 +86,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \ keras_residual_collapsing from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \ - InputScalingWithPad + InputScalingWithPad from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \ ReLUBoundToPowerOfTwo @@ -218,6 +220,25 @@ def shift_negative_correction(self, core_config, fw_info) + def compute_activation_bias_correction(self, + graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo): + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + core_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + return keras_compute_activation_bias_correction_of_graph(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=self) def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, @@ -309,7 +330,7 @@ def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: """ return keras_op2d_add_const_collapsing() - def get_substitutions_post_statistics_collection(self, + def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[common.BaseSubstitution]: """ Return a list of the framework substitutions used after we collect statistics. 
diff --git a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py new file mode 100644 index 000000000..7dcdc6683 --- /dev/null +++ b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py @@ -0,0 +1,88 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +from packaging import version + +from model_compression_toolkit.core.keras.constants import KERNEL_SIZE + +if version.parse(tf.__version__) >= version.parse("2.13"): + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D +else: + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ + MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D + +from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common import Graph +from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher +from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \ + compute_activation_bias_correction_of_graph + + +def activation_bias_correction_node_matchers(): + # Match linear layers where we can add a correction. + linear_node = NodeOperationMatcher(Conv2D) | \ + NodeOperationMatcher(Dense) | \ + NodeOperationMatcher(DepthwiseConv2D) + + # Match bypass layers what don't affect the quantization process. + bypass_node = (NodeOperationMatcher(Cropping2D) | + NodeOperationMatcher(GlobalAveragePooling2D) | + NodeOperationMatcher(Dropout) | + NodeOperationMatcher(ZeroPadding2D) | + NodeOperationMatcher(MaxPooling2D) | + NodeOperationMatcher(tf.split) | + NodeOperationMatcher(tf.cast) | + NodeOperationMatcher(tf.unstack) | + NodeOperationMatcher(tf.__operators__.getitem) | + NodeOperationMatcher(tf.strided_slice) | + NodeOperationMatcher(Reshape) | + NodeOperationMatcher(tf.reshape) | + NodeOperationMatcher(Permute) | + NodeOperationMatcher(tf.transpose) | + NodeOperationMatcher(Flatten) | + NodeOperationMatcher(tf.gather) | + NodeOperationMatcher(tf.compat.v1.gather)) + + return linear_node, bypass_node + + +def keras_compute_activation_bias_correction_of_graph(graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation) -> Graph: + """ + Compute the activation bias correction term for graph based on a Keras model. 
+ + Args: + graph: Graph with nodes to compute the activation bias correction. + core_config: Configuration object containing parameters of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + + Returns: + Graph with activation bias correction term for each relevant node. + """ + graph = compute_activation_bias_correction_of_graph(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=fw_impl, + activation_bias_correction_node_matchers= + activation_bias_correction_node_matchers, + kernel_size=KERNEL_SIZE) + return graph diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 38e03dd08..2b5e3bae8 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -92,6 +92,8 @@ from model_compression_toolkit.core.pytorch.reader.reader import model_reader from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \ pytorch_apply_second_moment_correction +from model_compression_toolkit.core.pytorch.statistics_correction.pytorch_compute_activation_bias_correction_of_graph import \ + pytorch_compute_activation_bias_correction_of_graph from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ get_inferable_quantizers @@ -212,6 +214,25 @@ def shift_negative_correction(self, core_config, fw_info) + def compute_activation_bias_correction(self, + graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo): + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + core_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + return pytorch_compute_activation_bias_correction_of_graph(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=self) def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, diff --git a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py new file mode 100644 index 000000000..f2d17e4e5 --- /dev/null +++ b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py @@ -0,0 +1,71 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from torch import reshape, transpose, flatten, permute +from torch.nn import Conv2d, Linear, Dropout, ZeroPad2d, AdaptiveAvgPool2d +from torch.nn.functional import avg_pool2d, pad + +from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core.common import Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher +from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \ + compute_activation_bias_correction_of_graph +from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE + + +def activation_bias_correction_node_matchers(): + # Match linear layers where we can add a correction. + linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) + + # Match bypass layers what don't affect the quantization process. + bypass_node = NodeOperationMatcher(reshape) | \ + NodeOperationMatcher(avg_pool2d) | \ + NodeOperationMatcher(transpose) | \ + NodeOperationMatcher(Dropout) | \ + NodeOperationMatcher(flatten) | \ + NodeOperationMatcher(ZeroPad2d) | \ + NodeOperationMatcher(pad) | \ + NodeOperationMatcher(AdaptiveAvgPool2d) | \ + NodeOperationMatcher(permute) + + return linear_node, bypass_node + + +def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation) -> Graph: + """ + Compute the activation bias correction term for graph based on a PyTorch model. + + Args: + graph: Graph with nodes to compute the activation bias correction. + core_config: Configuration object containing parameters of how the model should be quantized. + fw_info: Framework info like lists of nodes their kernel should quantized. + fw_impl: FrameworkImplementation object with a specific framework methods implementation. + + Returns: + Graph with activation bias correction term for each relevant node. + """ + graph = compute_activation_bias_correction_of_graph(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=fw_impl, + activation_bias_correction_node_matchers= + activation_bias_correction_node_matchers, + kernel_size=KERNEL_SIZE) + return graph diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 66b25f080..b111219bb 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -164,6 +164,14 @@ def core_runner(in_model: Any, tg, bit_widths_config) + ###################################### + # Compute Activation Bias Correction + ###################################### + if core_config.quantization_config.activation_bias_correction: + tg = fw_impl.compute_activation_bias_correction(graph=tg, + core_config=core_config, + fw_info=fw_info) + # Edit the graph again after finalizing the configurations. # This is since some actions regard the final configuration and should be edited. 
edit_network_graph(tg, fw_info, core_config.debug_config.network_editor) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py new file mode 100644 index 000000000..86dc0bff8 --- /dev/null +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from model_compression_toolkit.core import QuantizationConfig +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest + +import tensorflow as tf +import numpy as np + +from tests.keras_tests.utils import get_layers_from_model_by_type + +keras = tf.keras +layers = keras.layers + +""" +This test checks the Activation Bias Correction feature. +""" + +class BaseActivationBiasCorrectionTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('gelu')(inputs) + x = layers.Dropout(0.5)(x) + x = layers.Dropout(0.5)(x) + outputs = layers.Dense(30)(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_dense_layers = get_layers_from_model_by_type(float_model, layers.Dense) + quantized_dense_layers = get_layers_from_model_by_type(quantized_model, layers.Dense) + + bias = float_dense_layers[-1].bias + bias_after_activation_bias_correction = quantized_dense_layers[-1].layer.bias + + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionConvTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature with Conv2D layer. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A small value set to the activation bias correction threshold only to activate the threshold + # filtering without changing the bias correction values. 
+ return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e-6) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('swish')(inputs) + x = layers.ZeroPadding2D(2)(x) + outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True)(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + bias = float_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias + + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionDWConvTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature with DepthWiseConv2D layer. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = tf.nn.gelu(inputs) + x = layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x) + outputs = layers.DepthwiseConv2D(kernel_size=1, use_bias=True, bias_initializer='glorot_uniform', + depth_multiplier=1)(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_dw_conv_layers = get_layers_from_model_by_type(float_model, layers.DepthwiseConv2D) + quantized_dw_conv_layers = get_layers_from_model_by_type(quantized_model, layers.DepthwiseConv2D) + + bias = float_dw_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_dw_conv_layers[-1].layer.bias + + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionReshapeConvTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A large value is assigned to the activation bias correction threshold to enable threshold filtering, + # which adjusts the bias correction values to zero. 
+ return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e9) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('swish')(inputs) + x = layers.Flatten()(x) + x = layers.Reshape(target_shape=(8, 8, 3))(x) + outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True, bias_initializer='ones')(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + bias = float_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias + + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value in " + f"case of activation_bias_correction_threshold 1e9.") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 45c8cef3e..10ea22c28 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -28,6 +28,9 @@ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.gptq import RoundingType from model_compression_toolkit.target_platform_capabilities import constants as C +from tests.keras_tests.feature_networks_tests.feature_networks.activation_bias_correction_test import \ + BaseActivationBiasCorrectionTest, BaseActivationBiasCorrectionConvTest, BaseActivationBiasCorrectionDWConvTest, \ + BaseActivationBiasCorrectionReshapeConvTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \ ActivationDecompositionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_relu_bound_to_power_of_2_test import \ @@ -516,6 +519,12 @@ def test_conv2d_bn_concat(self): def test_activation_scaling_relu6(self): ReLUBoundToPOTNetTest(self).run_test() + def test_activation_bias_correction(self): + BaseActivationBiasCorrectionTest(self).run_test() + BaseActivationBiasCorrectionConvTest(self).run_test() + BaseActivationBiasCorrectionDWConvTest(self).run_test() + BaseActivationBiasCorrectionReshapeConvTest(self).run_test() + def test_layer_activation_softmax_shift(self): SoftmaxShiftTest(self, layers.Dense(20, activation='softmax'), None).run_test() diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py new file mode 100644 index 000000000..cdd2a6970 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py @@ -0,0 +1,159 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import torch +from torch.nn import GELU, Hardswish, AdaptiveAvgPool2d, ZeroPad2d, Linear, Conv2d + +import model_compression_toolkit as mct +from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest + +""" +This test checks the Activation Bias Correction feature. +""" + + +class ActivationBiasCorrectionNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature. + """ + + def __init__(self, + linear_layer, + bypass_layers): + super(ActivationBiasCorrectionNet, self).__init__() + self.activation_layer = GELU() + self.linear_layer = linear_layer + self.bypass_layers = torch.nn.ModuleList(bypass_layers) + + def forward(self, x): + x = self.activation_layer(x) + + for bypass_layer in self.bypass_layers: + x = bypass_layer(x) + x = self.linear_layer(x) + return x + + +class ActivationBiasCorrectionPadNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature with pooling/padding layers as a bypass layers. + """ + + def __init__(self): + super(ActivationBiasCorrectionPadNet, self).__init__() + self.activation_layer = Hardswish() + self.pooling_layer = AdaptiveAvgPool2d(output_size=6) + self.padding_layer = ZeroPad2d(padding=2) + self.linear_layer = Linear(10, 10) + + def forward(self, x): + x = self.activation_layer(x) + x = self.pooling_layer(x) + x = self.padding_layer(x) + x = self.linear_layer(x) + return x + + +class ActivationBiasCorrectionReshapeNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature with reshape layers as a bypass layers. 
+ """ + + def __init__(self): + super(ActivationBiasCorrectionReshapeNet, self).__init__() + self.activation_layer = GELU() + self.linear_layer = Conv2d(in_channels=8, out_channels=1, kernel_size=1) + + def forward(self, x): + x = self.activation_layer(x) + x = x.flatten() + x = x.reshape(8, 2, -1) + x = self.linear_layer(x) + return x + + +class BaseActivationBiasCorrectionTest(BasePytorchFeatureNetworkTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + bias = float_model.linear_layer.bias.cpu().detach().numpy() + bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionNetTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test, linear_layer, bypass_layers): + super().__init__(unit_test) + self.linear_layer = linear_layer + self.bypass_layers = bypass_layers + + def create_networks(self): + return ActivationBiasCorrectionNet(linear_layer=self.linear_layer, + bypass_layers=self.bypass_layers) + + +class BaseActivationBiasCorrectionPadNetTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A small value set to the activation bias correction threshold only to activate the threshold + # filtering without changing the bias correction values. + return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e-6) + + def create_networks(self): + return ActivationBiasCorrectionPadNet() + + +class BaseActivationBiasCorrectionBigThrTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A large value is assigned to the activation bias correction threshold to enable threshold filtering, + # which adjusts the bias correction values to zero. 
+ return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e9) + + def create_networks(self): + return ActivationBiasCorrectionPadNet() + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + bias = float_model.linear_layer.bias.cpu().detach().numpy() + bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value in " + f"case of activation_bias_correction_threshold 1e9.") + + +class BaseActivationBiasCorrectionReshapeNetTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def create_networks(self): + return ActivationBiasCorrectionReshapeNet() diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 44f8193bb..d1fe8833a 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -30,6 +30,9 @@ from model_compression_toolkit.trainable_infrastructure import TrainingMethod from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest +from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import \ + BaseActivationBiasCorrectionPadNetTest, BaseActivationBiasCorrectionNetTest, \ + BaseActivationBiasCorrectionReshapeNetTest, BaseActivationBiasCorrectionBigThrTest from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest from tests.pytorch_tests.model_tests.feature_models.add_same_test import AddSameNetTest from tests.pytorch_tests.model_tests.feature_models.bn_attributes_quantization_test import BNAttributesQuantization @@ -302,13 +305,13 @@ def test_permute_substitution(self): This test checks the permute substitution feature """ PermuteSubstitutionTest(self).run_test() - + def test_reshape_substitution(self): """ This test checks the reshape substitution feature """ ReshapeSubstitutionTest(self).run_test() - + def test_constant_conv_substitution(self): """ This test checks the constant conv substitution feature @@ -438,6 +441,21 @@ def test_shift_negative_activation_net(self): for activation_layer in [torch.nn.Hardswish, torch.nn.GELU]: ShiftNegaviteActivationNetTest(self, activation_layer=activation_layer).run_test(seed=3) + def test_activation_bias_correction_net(self): + """ + This test checks the activation bias correction feature. 
+ """ + BaseActivationBiasCorrectionPadNetTest(self).run_test() + BaseActivationBiasCorrectionReshapeNetTest(self).run_test() + BaseActivationBiasCorrectionBigThrTest(self).run_test() + for linear_layer in [nn.Linear(8, 20), + nn.Conv2d(3, 20, 1), + nn.Conv2d(3, 8, (1, 1))]: + for bypass_layers in [[nn.Dropout(0.5)], [nn.Dropout(0.5), nn.Dropout(0.5)]]: + BaseActivationBiasCorrectionNetTest(self, + linear_layer=linear_layer, + bypass_layers=bypass_layers).run_test() + def test_split_concat_net(self): """ This test checks: @@ -715,7 +733,7 @@ def test_bn_attributes_quantization(self): """ BNAttributesQuantization(self, quantize_linear=False).run_test() BNAttributesQuantization(self, quantize_linear=True).run_test() - + def test_concat_threshold_update(self): ConcatUpdateTest(self).run_test() From 0810a1b15b4faa78f74b63ca880cf1e08d4201e1 Mon Sep 17 00:00:00 2001 From: ariell Date: Mon, 4 Nov 2024 15:44:15 +0200 Subject: [PATCH 2/5] Fix no previous node bug --- .../common/quantization/quantization_config.py | 4 ++-- ...compute_activation_bias_correction_of_graph.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/model_compression_toolkit/core/common/quantization/quantization_config.py b/model_compression_toolkit/core/common/quantization/quantization_config.py index a940c4504..cf3a39976 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/quantization_config.py @@ -70,8 +70,6 @@ class QuantizationConfig: weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE relu_bound_to_power_of_2: bool = False weights_bias_correction: bool = True - activation_bias_correction: bool = False - activation_bias_correction_threshold: float = 0.0 weights_second_moment_correction: bool = False input_scaling: bool = False softmax_shift: bool = False @@ -86,6 +84,8 @@ class QuantizationConfig: shift_negative_threshold_recalculation: bool = False shift_negative_params_search: bool = False concat_threshold_update: bool = False + activation_bias_correction: bool = False + activation_bias_correction_threshold: float = 0.0 # Default quantization configuration the library use. 
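
Note: a user opts in to the feature through the two new QuantizationConfig fields shown above. A minimal usage sketch, assuming the standard Keras PTQ entry point (mct.ptq.keras_post_training_quantization) and a user-supplied float model and representative dataset:

import model_compression_toolkit as mct

# float_model and representative_data_gen are assumed to be provided by the user.
qc = mct.core.QuantizationConfig(activation_bias_correction=True,
                                 activation_bias_correction_threshold=1e-6)  # 0.0 leaves the threshold filter off
core_config = mct.core.CoreConfig(quantization_config=qc)

quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    float_model,
    representative_data_gen,
    core_config=core_config)

With a threshold greater than zero, corrections whose normalized magnitude falls below the threshold are skipped, as implemented in compute_activation_bias_correction above.
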
diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py index e8ad63679..6935fd468 100644 --- a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -199,11 +199,12 @@ def compute_activation_bias_correction_of_graph(graph: Graph, graph=graph, linear_node_types=linear_node_types, bypass_node_types=bypass_node_types) - graph = compute_activation_bias_correction(graph=graph, - core_config=core_config, - fw_info=fw_info, - fw_impl=fw_impl, - linear_node=n, - prev_node=prev_node, - kernel_size=kernel_size) + if prev_node is not None: + graph = compute_activation_bias_correction(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=fw_impl, + linear_node=n, + prev_node=prev_node, + kernel_size=kernel_size) return graph From 1d40f9b4371722b7c55acdbc103c3ddf0c9497bd Mon Sep 17 00:00:00 2001 From: ariell Date: Mon, 4 Nov 2024 16:38:21 +0200 Subject: [PATCH 3/5] Increase coverage --- ...ute_activation_bias_correction_of_graph.py | 6 +-- .../activation_bias_correction_test.py | 38 ++++++++++++++++++- .../test_features_runner.py | 3 +- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py index 6935fd468..c828250cb 100644 --- a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -46,7 +46,7 @@ def get_next_nodes_to_correct(node: BaseNode, prev_nodes = graph.get_prev_nodes(node) if len(prev_nodes) != 1: - return None, None + return None, None # pragma: no cover prev_node = prev_nodes[0] @@ -65,7 +65,7 @@ def get_next_nodes_to_correct(node: BaseNode, linear_node_types=linear_node_types, bypass_node_types=bypass_node_types, bypass_nodes=bypass_nodes) - return None, None + return None, None # pragma: no cover def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray: @@ -124,7 +124,7 @@ def compute_activation_bias_correction(graph: Graph, # Check if the previous node's has activation quantization configuration and if the previous node have the # histogram collector if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'): - return graph + return graph # pragma: no cover float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram() diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py index 86dc0bff8..4c98fefd4 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -38,6 +38,7 @@ def __init__(self, unit_test): def get_quantization_config(self): return QuantizationConfig(weights_bias_correction=False, weights_second_moment_correction=False, + shift_negative_activation_correction=False, 
activation_bias_correction=True) def create_networks(self): @@ -45,7 +46,7 @@ def create_networks(self): x = layers.Activation('gelu')(inputs) x = layers.Dropout(0.5)(x) x = layers.Dropout(0.5)(x) - outputs = layers.Dense(30)(x) + outputs = layers.Dense(30, use_bias=False)(x) return keras.Model(inputs=inputs, outputs=outputs) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): @@ -159,3 +160,38 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), msg=f"Error in activation bias correction: expected no change in the bias value in " f"case of activation_bias_correction_threshold 1e9.") + + +class BaseActivationBiasCorrectionConv8Test(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A large value is assigned to the activation bias correction threshold to enable threshold filtering, + # which adjusts the bias correction values to zero. + return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('swish')(inputs) + x = layers.Flatten()(x) + x = layers.Reshape(target_shape=(8, 8, 3))(x) + outputs = layers.Conv2D(filters=3, kernel_size=2, use_bias=True, bias_initializer='ones')(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + bias = float_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias + + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value in " + f"case of conv2d with kernel 2.") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 10ea22c28..e48973f99 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -30,7 +30,7 @@ from model_compression_toolkit.target_platform_capabilities import constants as C from tests.keras_tests.feature_networks_tests.feature_networks.activation_bias_correction_test import \ BaseActivationBiasCorrectionTest, BaseActivationBiasCorrectionConvTest, BaseActivationBiasCorrectionDWConvTest, \ - BaseActivationBiasCorrectionReshapeConvTest + BaseActivationBiasCorrectionReshapeConvTest, BaseActivationBiasCorrectionConv8Test from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \ ActivationDecompositionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_relu_bound_to_power_of_2_test import \ @@ -524,6 +524,7 @@ def test_activation_bias_correction(self): BaseActivationBiasCorrectionConvTest(self).run_test() BaseActivationBiasCorrectionDWConvTest(self).run_test() BaseActivationBiasCorrectionReshapeConvTest(self).run_test() + BaseActivationBiasCorrectionConv8Test(self).run_test() def 
test_layer_activation_softmax_shift(self): SoftmaxShiftTest(self, layers.Dense(20, activation='softmax'), None).run_test() From 2967f1766754f71ad95dd1cc7b2f4774768a1258 Mon Sep 17 00:00:00 2001 From: ariell Date: Wed, 6 Nov 2024 11:14:20 +0200 Subject: [PATCH 4/5] Fix tests and node matchers --- .../core/common/framework_implementation.py | 58 +++--- .../quantization/node_quantization_config.py | 2 + ...ply_activation_bias_correction_to_graph.py | 23 +-- ...ute_activation_bias_correction_of_graph.py | 115 +++++------ .../core/keras/keras_implementation.py | 6 +- ...ute_activation_bias_correction_of_graph.py | 39 +--- .../core/pytorch/pytorch_implementation.py | 6 +- ...ute_activation_bias_correction_of_graph.py | 28 +-- model_compression_toolkit/core/runner.py | 2 +- .../activation_bias_correction_test.py | 184 ++++-------------- .../test_features_runner.py | 87 ++++++--- .../activation_bias_correction_test.py | 86 +++----- .../model_tests/test_feature_models_runner.py | 40 ++-- 13 files changed, 257 insertions(+), 419 deletions(-) diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index 3e63c561b..ad7758d46 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self, Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover @abstractmethod @@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray: Returns: Numpy array converted from the input tensor. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s to_numpy method.') # pragma: no cover @abstractmethod @@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any: Returns: Framework's tensor converted from the input Numpy array. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s to_tensor method.') # pragma: no cover @abstractmethod @@ -106,7 +106,7 @@ def model_reader(self, Returns: Graph representing the input model. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s model_reader method.') # pragma: no cover @abstractmethod @@ -131,7 +131,7 @@ def model_builder(self, Returns: A tuple with the model and additional relevant supporting objects. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s model_builder method.') # pragma: no cover @abstractmethod @@ -148,7 +148,7 @@ def run_model_inference(self, Returns: The frameworks model's output. 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s run_model_inference method.') # pragma: no cover @abstractmethod @@ -167,26 +167,26 @@ def shift_negative_correction(self, Returns: Graph after SNC. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s apply_shift_negative_correction method.') # pragma: no cover @abstractmethod def compute_activation_bias_correction(self, graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo) -> Graph: """ Compute activation bias correction on a graph. Args: graph: Graph to apply activation bias correction on. - core_config: QuantizationConfig of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: FrameworkInfo object with information about the specific framework's model. Returns: Graph after activation bias correction computing. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s compute_activation_bias_correction method.') # pragma: no cover @abstractmethod @@ -203,7 +203,7 @@ def get_substitutions_channel_equalization(self, Returns: A list of the framework substitutions used after we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover @abstractmethod @@ -213,7 +213,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List Returns: A list of the framework substitutions used to prepare the graph. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover @abstractmethod @@ -227,7 +227,7 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization Returns: A list of the framework substitutions used before we collect statistics. 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover @abstractmethod @@ -235,7 +235,7 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution: """ Returns: linear collapsing substitution """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover @abstractmethod @@ -243,7 +243,7 @@ def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: """ Returns: conv2d add const collapsing substitution """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover @abstractmethod @@ -258,7 +258,7 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf Returns: A list of the framework substitutions used for statistics correction. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover @abstractmethod @@ -266,7 +266,7 @@ def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]: """ Returns: A list of the framework substitutions used for residual collapsing """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover @@ -282,7 +282,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio Returns: A list of the framework substitutions used after we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover @abstractmethod @@ -291,7 +291,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_virtual_weights_activation_coupling ' f'method.') # pragma: no cover @@ -307,7 +307,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz Returns: A list of the framework substitutions used after we apply second moment statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_after_second_moment_correction ' f'method.') # pragma: no cover @@ -335,7 +335,7 @@ def get_sensitivity_evaluator(self, A function that computes the metric. 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover def get_node_prior_info(self, node: BaseNode, @@ -353,7 +353,7 @@ def get_node_prior_info(self, node: BaseNode, NodePriorInfo with information about the node. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_node_prior_info method.') # pragma: no cover def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool: @@ -364,7 +364,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool Returns: True if the node should be considered an interest point, False otherwise. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover def get_mp_node_distance_fn(self, n: BaseNode, @@ -383,7 +383,7 @@ def get_mp_node_distance_fn(self, n: BaseNode, Returns: A distance function between two tensors and a axis on which the distance is computed (if exists). """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover @@ -400,7 +400,7 @@ def is_output_node_compatible_for_hessian_score_computation(self, """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover @abstractmethod @@ -417,7 +417,7 @@ def get_node_mac_operations(self, Returns: The MAC count of the operation """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_node_mac_operations method.') # pragma: no cover @abstractmethod @@ -438,7 +438,7 @@ def apply_second_moment_correction(self, Returns: A Graph after second moment correction. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s apply_second_moment_correction method.') # pragma: no cover @abstractmethod @@ -455,7 +455,7 @@ def sensitivity_eval_inference(self, Returns: The output of the model inference on the given input. 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s sensitivity_eval_inference method.') # pragma: no cover def get_inferable_quantizers(self, node: BaseNode): @@ -471,7 +471,7 @@ def get_inferable_quantizers(self, node: BaseNode): """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_inferable_quantizers method.') # pragma: no cover @staticmethod diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index 0c68fb7b4..a790cbc77 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -95,7 +95,9 @@ def __init__(self, self.activation_error_method = qc.activation_error_method self.activation_n_bits = op_cfg.activation_n_bits self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2 + self.activation_bias_correction_term = None self.enable_activation_quantization = op_cfg.enable_activation_quantization + self.quantization_preserving = op_cfg.quantization_preserving self.signedness = op_cfg.signedness self.activation_channel_equalization = qc.activation_channel_equalization self.input_scaling = qc.input_scaling diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py index 83fcc14ef..293e3dcce 100644 --- a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -import copy - from model_compression_toolkit.core import CoreConfig, QuantizationConfig from model_compression_toolkit.core.common import BaseNode, Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation @@ -26,7 +24,7 @@ def apply_activation_bias_correction_to_graph(graph: Graph, core_config: CoreConfig, fw_impl: FrameworkImplementation) -> Graph: """ - Get a graph, where each node has a final activation quantization configuration (with a activation bias + Get a graph, where each node has a final activation quantization configuration (with an activation bias correction term in it), and apply the activation bias correction for each node in the graph. Args: @@ -38,12 +36,11 @@ def apply_activation_bias_correction_to_graph(graph: Graph, Graph with activation bias correction apply to it's nodes. 
""" - graph = copy.deepcopy(graph) for n in graph.nodes: # Activation bias correction is only relevant for nodes with kernel op kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0] if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \ - hasattr(n.final_activation_quantization_cfg, 'activation_bias_correction_term'): + n.final_activation_quantization_cfg.activation_bias_correction_term is not None: # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was # calculated during model preparation, and is used now in the node's bias term. _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config) @@ -66,15 +63,19 @@ def _apply_activation_bias_correction_to_node(node: BaseNode, correction = node.final_activation_quantization_cfg.activation_bias_correction_term bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights - if bias is not None: # If the layer has bias, we subtract the correction from original bias - node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) - else: - # If the layer has no bias, we consider it as if it has and its value is 0 and add a "dummy" attribute - # configuration with disabled quantization. + if bias is None: + # If the layer has no bias, we set the bias as -correction. node.set_weights_by_keys(fw_impl.constants.BIAS, - correction) - node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node. + + # Mark the use_bias attribute of the node. + node.framework_attr[fw_impl.constants.USE_BIAS] = True + + # Configure the quantization of the bias as disabled. node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, WeightsAttrQuantizationConfig( qc, AttributeQuantizationConfig( enable_weights_quantization=False))) + else: + # If the layer has bias, we subtract the correction from original bias + node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py index c828250cb..765226c6d 100644 --- a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -12,82 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from typing import List, Tuple, Any, Callable - import numpy as np +from typing import Any, Callable -from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common import BaseNode, Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo -from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher -def get_next_nodes_to_correct(node: BaseNode, - graph: Graph, - linear_node_types: NodeOperationMatcher, - bypass_node_types: NodeOperationMatcher, - bypass_nodes: List = None) -> Tuple[Any, Any]: +def get_previous_node_with_activation_quantization(linear_node: BaseNode, + graph: Graph) -> Any: """ - Search for the previous node which is not a bypass node of a given node. Go over the previous nodes of the node - and recursively search for a node. + Search recursively for the previous node with activation quantization. Args: - node: Node to search for its previous node. + linear_node: Node to search for its previous node. graph: Graph the node is in. - linear_node_types: Types of linear nodes to consider. - bypass_node_types: Types of nodes for bypassing to consider. - bypass_nodes: a list of bypass nodes found while running this function - Returns: The previous node (if found) and a list of bypass nodes (if any), or Nones if it were not found or there - are multiple incoming edges to one of nodes during the search (which means, the substitution can not be applied). + Returns: + The previous node (if found) or None if it was not found or there are multiple incoming edges to one of + nodes during the search (which means, the substitution can not be applied). """ - prev_nodes = graph.get_prev_nodes(node) + prev_nodes = graph.get_prev_nodes(linear_node) if len(prev_nodes) != 1: - return None, None # pragma: no cover + return None # pragma: no cover prev_node = prev_nodes[0] - # If the previous node is not a bypass type, return it as the valid node along with any bypass nodes - if not bypass_node_types.apply(prev_node): - return prev_node, bypass_nodes + activation_quantization_config = prev_node.final_activation_quantization_cfg - # If the previous node is a bypass node type, add it to the bypass_nodes list and continue searching - if bypass_node_types.apply(prev_node): - if bypass_nodes: - bypass_nodes.append(prev_node) - else: - bypass_nodes = [prev_node] - return get_next_nodes_to_correct(node=prev_node, - graph=graph, - linear_node_types=linear_node_types, - bypass_node_types=bypass_node_types, - bypass_nodes=bypass_nodes) - return None, None # pragma: no cover + # Search for node with activation quantization + if (activation_quantization_config.enable_activation_quantization and + not activation_quantization_config.quantization_preserving): + return prev_node + else: + return get_previous_node_with_activation_quantization(prev_node, graph) def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray: """ Calculate the centers of bins given their edges. - Parameters: - bin_edges: Array of bin edges. + Args: + bin_edges: Array of bin edges. Returns: - np.ndarray: Array of bin centers. + np.ndarray: Array of bin centers. 
""" - # Ensure bin_edges is a numpy array - bin_edges = np.array(bin_edges, dtype=np.float32) - # Calculate the centers by averaging continuous bin edges bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 return bin_centers def compute_activation_bias_correction(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, linear_node: BaseNode, @@ -99,7 +80,7 @@ def compute_activation_bias_correction(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. linear_node: Node to compute the activation bias correction for. @@ -110,19 +91,16 @@ def compute_activation_bias_correction(graph: Graph, Graph with activation bias correction term for each node. """ - # Check if 'kernel_size' is a key in the framework-specific attributes of the linear_node, if it is then the - # linear_node is a convolution - if kernel_size in linear_node.framework_attr.keys(): - # Retrieve the value of 'kernel_size' and check if it is not 1 or (1, 1). This feature supports only kernel - # size of 1 or (1, 1). - if linear_node.framework_attr.get(kernel_size) not in [1, (1, 1)]: - # If the kernel size is not 1 or (1, 1), return the current graph unmodified - return graph + # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). + # This feature supports only Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1). + if linear_node.framework_attr.get(kernel_size) not in [None, 1, (1, 1)]: + # If the kernel size is not 1 or (1, 1), return the current graph unmodified + return graph prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg # Check if the previous node's has activation quantization configuration and if the previous node have the - # histogram collector + # histogram collector. if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'): return graph # pragma: no cover @@ -143,25 +121,27 @@ def compute_activation_bias_correction(graph: Graph, # Compute the difference between the mean quantized center and the mean float center mean_diff = mean_quant_centers - mean_float_centers - # Check if activation bias correction is enabled based on the configured threshold - if core_config.quantization_config.activation_bias_correction_threshold > 0: - - # Calculate the normalized bias as a percentage of the float center norm - float_centers_norm1 = np.abs(mean_float_centers) - normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1 + # Calculate the normalized bias as a percentage of the float center norm + float_centers_norm1 = np.abs(mean_float_centers) + normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1 - # If the normalized bias is below the activation bias correction threshold, return the unmodified graph - if normalized_bias < core_config.quantization_config.activation_bias_correction_threshold: - return graph + # If the normalized bias is below the activation bias correction threshold, return the graph unmodified. + # By default, the threshold is set to 0, allowing all nodes to proceed in this case. 
+ if normalized_bias < quant_config.activation_bias_correction_threshold: + return graph - # The correction term is a function of the layer type. kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0]) + # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output + # size matching the number of output channels. if kernel is not None: + + # Get the axes that are not the output channel output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type) axis_not_output_channel = list(range(len(kernel.shape))) axis_not_output_channel.remove(output_channel_index) + # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters if output_channel_index == input_channel_index: axis_not_output_channel.remove(3) # 3 is the depth multiplier index @@ -171,7 +151,7 @@ def compute_activation_bias_correction(graph: Graph, def compute_activation_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, activation_bias_correction_node_matchers: Callable, @@ -181,7 +161,7 @@ def compute_activation_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. activation_bias_correction_node_matchers: Function to match the layers for activation bias correction. @@ -191,17 +171,14 @@ def compute_activation_bias_correction_of_graph(graph: Graph, Returns: Graph with activation bias correction term for each relevant node. """ - linear_node_types, bypass_node_types = activation_bias_correction_node_matchers() + linear_node_types = activation_bias_correction_node_matchers() for n in graph.nodes: if linear_node_types.apply(n): - prev_node, _ = get_next_nodes_to_correct(node=n, - graph=graph, - linear_node_types=linear_node_types, - bypass_node_types=bypass_node_types) + prev_node = get_previous_node_with_activation_quantization(n, graph) if prev_node is not None: graph = compute_activation_bias_correction(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=fw_impl, linear_node=n, diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 024c0bd54..85fab7637 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -222,21 +222,21 @@ def shift_negative_correction(self, def compute_activation_bias_correction(self, graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo): """ Compute activation bias correction on a graph. Args: graph: Graph to apply activation bias correction on. - core_config: QuantizationConfig of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: FrameworkInfo object with information about the specific framework's model. Returns: Graph after activation bias correction computing. 
""" return keras_compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=self) diff --git a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py index 7dcdc6683..ce1ac9c23 100644 --- a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py @@ -19,13 +19,11 @@ from model_compression_toolkit.core.keras.constants import KERNEL_SIZE if version.parse(tf.__version__) >= version.parse("2.13"): - from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ - MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose else: - from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ - MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose -from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common import Graph @@ -38,32 +36,13 @@ def activation_bias_correction_node_matchers(): # Match linear layers where we can add a correction. linear_node = NodeOperationMatcher(Conv2D) | \ NodeOperationMatcher(Dense) | \ - NodeOperationMatcher(DepthwiseConv2D) - - # Match bypass layers what don't affect the quantization process. - bypass_node = (NodeOperationMatcher(Cropping2D) | - NodeOperationMatcher(GlobalAveragePooling2D) | - NodeOperationMatcher(Dropout) | - NodeOperationMatcher(ZeroPadding2D) | - NodeOperationMatcher(MaxPooling2D) | - NodeOperationMatcher(tf.split) | - NodeOperationMatcher(tf.cast) | - NodeOperationMatcher(tf.unstack) | - NodeOperationMatcher(tf.__operators__.getitem) | - NodeOperationMatcher(tf.strided_slice) | - NodeOperationMatcher(Reshape) | - NodeOperationMatcher(tf.reshape) | - NodeOperationMatcher(Permute) | - NodeOperationMatcher(tf.transpose) | - NodeOperationMatcher(Flatten) | - NodeOperationMatcher(tf.gather) | - NodeOperationMatcher(tf.compat.v1.gather)) - - return linear_node, bypass_node + NodeOperationMatcher(DepthwiseConv2D) | \ + NodeOperationMatcher(Conv2DTranspose) + return linear_node def keras_compute_activation_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation) -> Graph: """ @@ -71,7 +50,7 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. 
@@ -79,7 +58,7 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph, Graph with activation bias correction term for each relevant node. """ graph = compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=fw_impl, activation_bias_correction_node_matchers= diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 2b5e3bae8..5ec26a66d 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -216,21 +216,21 @@ def shift_negative_correction(self, def compute_activation_bias_correction(self, graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo): """ Compute activation bias correction on a graph. Args: graph: Graph to apply activation bias correction on. - core_config: QuantizationConfig of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: FrameworkInfo object with information about the specific framework's model. Returns: Graph after activation bias correction computing. """ return pytorch_compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=self) diff --git a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py index f2d17e4e5..149050cf2 100644 --- a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py @@ -13,11 +13,9 @@ # limitations under the License. # ============================================================================== -from torch import reshape, transpose, flatten, permute -from torch.nn import Conv2d, Linear, Dropout, ZeroPad2d, AdaptiveAvgPool2d -from torch.nn.functional import avg_pool2d, pad +from torch.nn import Conv2d, Linear, ConvTranspose2d -from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo @@ -29,24 +27,12 @@ def activation_bias_correction_node_matchers(): # Match linear layers where we can add a correction. - linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) - - # Match bypass layers what don't affect the quantization process. 
- bypass_node = NodeOperationMatcher(reshape) | \ - NodeOperationMatcher(avg_pool2d) | \ - NodeOperationMatcher(transpose) | \ - NodeOperationMatcher(Dropout) | \ - NodeOperationMatcher(flatten) | \ - NodeOperationMatcher(ZeroPad2d) | \ - NodeOperationMatcher(pad) | \ - NodeOperationMatcher(AdaptiveAvgPool2d) | \ - NodeOperationMatcher(permute) - - return linear_node, bypass_node + linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) | NodeOperationMatcher(ConvTranspose2d) + return linear_node def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation) -> Graph: """ @@ -54,7 +40,7 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. @@ -62,7 +48,7 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, Graph with activation bias correction term for each relevant node. """ graph = compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=fw_impl, activation_bias_correction_node_matchers= diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index b111219bb..65ef60176 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -169,7 +169,7 @@ def core_runner(in_model: Any, ###################################### if core_config.quantization_config.activation_bias_correction: tg = fw_impl.compute_activation_bias_correction(graph=tg, - core_config=core_config, + quant_config=core_config.quantization_config, fw_info=fw_info) # Edit the graph again after finalizing the configurations. diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py index 4c98fefd4..0de409258 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.keras.constants import KERNEL_SIZE from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest import tensorflow as tf @@ -27,171 +28,58 @@ This test checks the Activation Bias Correction feature. """ + class BaseActivationBiasCorrectionTest(BaseKerasFeatureNetworkTest): """ This test checks the Activation Bias Correction feature. 
""" - def __init__(self, unit_test): + def __init__(self, unit_test, + prev_layer, + bypass_layer_list, + linear_layer, + activation_bias_correction_threshold=0.0): super().__init__(unit_test) + self.prev_layer = prev_layer + self.bypass_layer_list = bypass_layer_list + self.linear_layer = linear_layer + self.activation_bias_correction_threshold = activation_bias_correction_threshold def get_quantization_config(self): return QuantizationConfig(weights_bias_correction=False, weights_second_moment_correction=False, shift_negative_activation_correction=False, - activation_bias_correction=True) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('gelu')(inputs) - x = layers.Dropout(0.5)(x) - x = layers.Dropout(0.5)(x) - outputs = layers.Dense(30, use_bias=False)(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_dense_layers = get_layers_from_model_by_type(float_model, layers.Dense) - quantized_dense_layers = get_layers_from_model_by_type(quantized_model, layers.Dense) - - bias = float_dense_layers[-1].bias - bias_after_activation_bias_correction = quantized_dense_layers[-1].layer.bias - - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionConvTest(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature with Conv2D layer. - """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A small value set to the activation bias correction threshold only to activate the threshold - # filtering without changing the bias correction values. - return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True, - activation_bias_correction_threshold=1e-6) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('swish')(inputs) - x = layers.ZeroPadding2D(2)(x) - outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True)(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) - quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - - bias = float_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias - - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionDWConvTest(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature with DepthWiseConv2D layer. 
- """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = tf.nn.gelu(inputs) - x = layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x) - outputs = layers.DepthwiseConv2D(kernel_size=1, use_bias=True, bias_initializer='glorot_uniform', - depth_multiplier=1)(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_dw_conv_layers = get_layers_from_model_by_type(float_model, layers.DepthwiseConv2D) - quantized_dw_conv_layers = get_layers_from_model_by_type(quantized_model, layers.DepthwiseConv2D) - - bias = float_dw_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_dw_conv_layers[-1].layer.bias - - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionReshapeConvTest(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature. - """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A large value is assigned to the activation bias correction threshold to enable threshold filtering, - # which adjusts the bias correction values to zero. - return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, activation_bias_correction=True, - activation_bias_correction_threshold=1e9) + activation_bias_correction_threshold=self.activation_bias_correction_threshold) def create_networks(self): inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('swish')(inputs) - x = layers.Flatten()(x) - x = layers.Reshape(target_shape=(8, 8, 3))(x) - outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True, bias_initializer='ones')(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) - quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - - bias = float_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias - - self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected no change in the bias value in " - f"case of activation_bias_correction_threshold 1e9.") + x = self.prev_layer(inputs) + for bypass_layer in self.bypass_layer_list: + x = bypass_layer(x) -class BaseActivationBiasCorrectionConv8Test(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature. - """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A large value is assigned to the activation bias correction threshold to enable threshold filtering, - # which adjusts the bias correction values to zero. 
- return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('swish')(inputs) - x = layers.Flatten()(x) - x = layers.Reshape(target_shape=(8, 8, 3))(x) - outputs = layers.Conv2D(filters=3, kernel_size=2, use_bias=True, bias_initializer='ones')(x) + outputs = self.linear_layer(x) return keras.Model(inputs=inputs, outputs=outputs) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) - quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - - bias = float_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias - - self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected no change in the bias value in " - f"case of conv2d with kernel 2.") + float_linear_layers = get_layers_from_model_by_type(float_model, type(self.linear_layer)) + quantized_linear_layers = get_layers_from_model_by_type(quantized_model, type(self.linear_layer)) + + bias = float_linear_layers[-1].bias + bias_after_activation_bias_correction = quantized_linear_layers[-1].layer.bias + + if getattr(float_linear_layers[-1], KERNEL_SIZE, None) in [None, 1, (1, 1)]: + if self.activation_bias_correction_threshold > 1e8: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias " + f"value in case of activation_bias_correction_threshold " + f"{self.activation_bias_correction_threshold}.") + else: + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias " + f"value.") + else: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value " + f"in case of conv with kernel 2.") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index e48973f99..59336b057 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -29,8 +29,7 @@ from model_compression_toolkit.gptq import RoundingType from model_compression_toolkit.target_platform_capabilities import constants as C from tests.keras_tests.feature_networks_tests.feature_networks.activation_bias_correction_test import \ - BaseActivationBiasCorrectionTest, BaseActivationBiasCorrectionConvTest, BaseActivationBiasCorrectionDWConvTest, \ - BaseActivationBiasCorrectionReshapeConvTest, BaseActivationBiasCorrectionConv8Test + BaseActivationBiasCorrectionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \ ActivationDecompositionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_relu_bound_to_power_of_2_test import \ @@ -66,7 +65,8 @@ RequiresMixedPrecision, RequiresMixedPrecisionWeights from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \ MixedPrecisionBopsBasicTest, MixedPrecisionBopsAllWeightsLayersTest, 
MixedPrecisionWeightsOnlyBopsTest, \ - MixedPrecisionActivationOnlyBopsTest, MixedPrecisionBopsAndWeightsUtilizationTest, MixedPrecisionBopsAndActivationUtilizationTest, \ + MixedPrecisionActivationOnlyBopsTest, MixedPrecisionBopsAndWeightsUtilizationTest, \ + MixedPrecisionBopsAndActivationUtilizationTest, \ MixedPrecisionBopsAndTotalUtilizationTest, MixedPrecisionBopsWeightsActivationUtilizationTest, \ MixedPrecisionBopsMultipleOutEdgesTest from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_tests import \ @@ -143,7 +143,8 @@ MixedPrecisionSearchPartWeightsLayersTest, MixedPrecisionDepthwiseTest, MixedPrecisionSearchLastLayerDistanceTest, \ MixedPrecisionSearchActivationNonConfNodesTest, MixedPrecisionSearchTotalMemoryNonConfNodesTest, \ MixedPrecisionCombinedNMSTest -from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import \ + MatmulToDenseSubstitutionTest from tests.keras_tests.feature_networks_tests.feature_networks.metadata_test import MetadataTest from tests.keras_tests.feature_networks_tests.feature_networks.tpc_test import TpcTest from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \ @@ -153,8 +154,10 @@ AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest -from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import SigMulSubstitutionTest -from tests.keras_tests.feature_networks_tests.feature_networks.conv_func_substitutions_test import ConvFuncSubstitutionsTest +from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import \ + SigMulSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.conv_func_substitutions_test import \ + ConvFuncSubstitutionsTest from model_compression_toolkit.qat.common.qat_config import TrainingMethod layers = tf.keras.layers @@ -170,7 +173,7 @@ def test_remove_identity(self): def test_per_tensor_weight_quantization(self): PerTensorWeightQuantizationTest(self).run_test() - + def test_single_relu_replacement(self): SingleReluReplacementTest(self).run_test() @@ -243,7 +246,7 @@ def test_mixed_precision_weights_only_activation_conf(self): def test_requires_mixed_recision(self): RequiresMixedPrecisionWeights(self, weights_memory=True).run_test() - RequiresMixedPrecision(self,activation_memory=True).run_test() + RequiresMixedPrecision(self, activation_memory=True).run_test() RequiresMixedPrecision(self, total_memory=True).run_test() RequiresMixedPrecision(self, bops=True).run_test() RequiresMixedPrecision(self).run_test() @@ -520,11 +523,25 @@ def test_activation_scaling_relu6(self): ReLUBoundToPOTNetTest(self).run_test() def test_activation_bias_correction(self): - BaseActivationBiasCorrectionTest(self).run_test() - BaseActivationBiasCorrectionConvTest(self).run_test() - BaseActivationBiasCorrectionDWConvTest(self).run_test() - BaseActivationBiasCorrectionReshapeConvTest(self).run_test() - BaseActivationBiasCorrectionConv8Test(self).run_test() + for use_bias in [False, True]: + for linear_layer in [layers.Dense(30, use_bias=use_bias), + layers.Conv2D(filters=3, kernel_size=1, use_bias=use_bias), + layers.DepthwiseConv2D(kernel_size=1, 
use_bias=use_bias, + bias_initializer='glorot_uniform', depth_multiplier=1), + layers.Conv2DTranspose(filters=3, kernel_size=1, use_bias=use_bias), + layers.Conv2D(filters=3, kernel_size=2, use_bias=use_bias)]: + for activation_bias_correction_threshold in [0.0, 1e-6, 1e9]: + for activation in ['gelu', 'swish']: + for bypass_layer in [[layers.ZeroPadding2D(2)], + [layers.Dropout(0.5), layers.Dropout(0.5)], + [layers.MaxPooling2D(pool_size=(2, 2), strides=2)], + [layers.Flatten(), layers.Reshape(target_shape=(8, 8, 3))]]: + BaseActivationBiasCorrectionTest(self, + prev_layer=layers.Activation(activation), + bypass_layer_list=bypass_layer, + linear_layer=linear_layer, + activation_bias_correction_threshold= + activation_bias_correction_threshold).run_test() def test_layer_activation_softmax_shift(self): SoftmaxShiftTest(self, layers.Dense(20, activation='softmax'), None).run_test() @@ -545,9 +562,12 @@ def test_conv2dbn_folding(self): Conv2DBNFoldingTest(self).run_test() def test_bn_forward_folding(self): - BNForwardFoldingTest(self, layers.Conv2D(2, 1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() - BNForwardFoldingTest(self, layers.DepthwiseConv2D(1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() - BNForwardFoldingTest(self, layers.Conv2DTranspose(2, 1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.Conv2D(2, 1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.DepthwiseConv2D(1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.Conv2DTranspose(2, 1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.Conv2D(2, 2), False, is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.DepthwiseConv2D((3, 1)), False, is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.Conv2DTranspose(2, (1, 3)), False, is_dwconv=True).run_test() @@ -592,10 +612,14 @@ def test_const_quantization(self): for qmethod in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC, QuantizationMethod.UNIFORM]: for error_method in [QuantizationErrorMethod.MSE, QuantizationErrorMethod.NOCLIPPING]: ConstQuantizationTest(self, func, c, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, + error_method=error_method).run_test() AdvancedConstQuantizationTest(self).run_test() ConstQuantizationMultiInputTest(self).run_test() @@ -617,7 +641,8 @@ def test_const_representation(self): for func in [layers.Add(), 
layers.Multiply(), layers.Subtract()]: ConstRepresentationTest(self, func, c, is_list_input=True).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True, is_list_input=True).run_test() - ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, is_list_input=True).run_test() + ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, + is_list_input=True).run_test() ConstRepresentationTest(self, func, c, use_kwargs=True, is_list_input=True).run_test() ConstRepresentationMultiInputTest(self).run_test() @@ -707,7 +732,6 @@ def test_gptq(self): # GradientPTQLearnRateZeroConvGroupTest(self).run_test() # GradientPTQWeightsUpdateConvGroupTest(self).run_test() - def test_gptq_conv_group_dilation(self): GradientPTQLearnRateZeroConvGroupDilationTest(self).run_test() GradientPTQWeightsUpdateConvGroupDilationTest(self).run_test() @@ -804,7 +828,8 @@ def test_qat(self): QuantizationAwareTrainingQuantizersTest(self).run_test() QuantizationAwareTrainingQuantizerHolderTest(self).run_test() QATWrappersMixedPrecisionCfgTest(self).run_test() - QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, expected_mp_cfg=[0, 4, 1, 1]).run_test() + QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, + expected_mp_cfg=[0, 4, 1, 1]).run_test() def test_bn_attributes_quantization(self): BNAttributesQuantization(self, quantize_linear=False).run_test() @@ -893,8 +918,8 @@ def test_exceptions_manual_selection(self): # Invalid inputs to API with self.assertRaises(Exception) as context: ManualBitWidthSelectionTest(self, - [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')], - [2, 4]).run_test() + [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')], + [2, 4]).run_test() # Check that the correct exception message was raised self.assertEqual(str(context.exception), "Configuration Error: The number of provided bit_width values 2 must match the number of filters 3, or a single bit_width value should be provided for all filters.") @@ -909,15 +934,15 @@ def test_manual_bit_width_selection(self): ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 4).run_test() ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 2).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], - [2, 4]).run_test() + [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], - [4, 4]).run_test() + [4, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Add)], - [2, 4]).run_test() + [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Conv2D)], - [4, 4]).run_test() + [4, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Dense)], - 4).run_test() + 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('input'), 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('conv1'), 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('fc'), 4).run_test() @@ -926,7 +951,7 @@ def test_manual_bit_width_selection(self): ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 4).run_test() ManualBitWidthSelectionTest(self, [NodeNameFilter('add1'), NodeNameFilter('conv1')], [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeNameFilter('add2'), NodeNameFilter('relu1')], 
4).run_test() - ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')],[4, 2]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')], [4, 2]).run_test() if __name__ == '__main__': diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py index cdd2a6970..beba40d80 100644 --- a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py @@ -18,6 +18,7 @@ from torch.nn import GELU, Hardswish, AdaptiveAvgPool2d, ZeroPad2d, Linear, Conv2d import model_compression_toolkit as mct +from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest """ @@ -31,10 +32,11 @@ class ActivationBiasCorrectionNet(torch.nn.Module): """ def __init__(self, + prev_layer, linear_layer, bypass_layers): super(ActivationBiasCorrectionNet, self).__init__() - self.activation_layer = GELU() + self.activation_layer = prev_layer self.linear_layer = linear_layer self.bypass_layers = torch.nn.ModuleList(bypass_layers) @@ -46,7 +48,6 @@ def forward(self, x): x = self.linear_layer(x) return x - class ActivationBiasCorrectionPadNet(torch.nn.Module): """ This is the network to test the Activation Bias Correction feature with pooling/padding layers as a bypass layers. @@ -86,74 +87,39 @@ def forward(self, x): class BaseActivationBiasCorrectionTest(BasePytorchFeatureNetworkTest): - def __init__(self, unit_test): + def __init__(self, unit_test, + model, + activation_bias_correction_threshold=0.0): super().__init__(unit_test) + self.model = model + self.activation_bias_correction_threshold = activation_bias_correction_threshold def get_quantization_config(self): return mct.core.QuantizationConfig(weights_bias_correction=False, weights_second_moment_correction=False, - activation_bias_correction=True) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - bias = float_model.linear_layer.bias.cpu().detach().numpy() - bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionNetTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test, linear_layer, bypass_layers): - super().__init__(unit_test) - self.linear_layer = linear_layer - self.bypass_layers = bypass_layers - - def create_networks(self): - return ActivationBiasCorrectionNet(linear_layer=self.linear_layer, - bypass_layers=self.bypass_layers) - - -class BaseActivationBiasCorrectionPadNetTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A small value set to the activation bias correction threshold only to activate the threshold - # filtering without changing the bias correction values. 
- return mct.core.QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True, - activation_bias_correction_threshold=1e-6) - - def create_networks(self): - return ActivationBiasCorrectionPadNet() - - -class BaseActivationBiasCorrectionBigThrTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A large value is assigned to the activation bias correction threshold to enable threshold filtering, - # which adjusts the bias correction values to zero. - return mct.core.QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, + shift_negative_activation_correction=False, activation_bias_correction=True, - activation_bias_correction_threshold=1e9) + activation_bias_correction_threshold= + self.activation_bias_correction_threshold) def create_networks(self): - return ActivationBiasCorrectionPadNet() + return self.model def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): bias = float_model.linear_layer.bias.cpu().detach().numpy() bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() - self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected no change in the bias value in " - f"case of activation_bias_correction_threshold 1e9.") - -class BaseActivationBiasCorrectionReshapeNetTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def create_networks(self): - return ActivationBiasCorrectionReshapeNet() + if getattr(float_model.linear_layer, KERNEL_SIZE, None) in [None, 1, (1, 1)]: + if self.activation_bias_correction_threshold > 1e8: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias " + f"value in case of activation_bias_correction_threshold " + f"{self.activation_bias_correction_threshold}.") + else: + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias " + f"value.") + else: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value " + f"in case of conv with kernel different than 1 or (1, 1).") diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index d1fe8833a..b5c02f350 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -30,9 +30,9 @@ from model_compression_toolkit.trainable_infrastructure import TrainingMethod from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest -from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import \ - BaseActivationBiasCorrectionPadNetTest, BaseActivationBiasCorrectionNetTest, \ - BaseActivationBiasCorrectionReshapeNetTest, BaseActivationBiasCorrectionBigThrTest +from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import ( + BaseActivationBiasCorrectionTest, ActivationBiasCorrectionNet, 
ActivationBiasCorrectionPadNet, + ActivationBiasCorrectionReshapeNet) from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest from tests.pytorch_tests.model_tests.feature_models.add_same_test import AddSameNetTest from tests.pytorch_tests.model_tests.feature_models.bn_attributes_quantization_test import BNAttributesQuantization @@ -445,16 +445,30 @@ def test_activation_bias_correction_net(self): """ This test checks the activation bias correction feature. """ - BaseActivationBiasCorrectionPadNetTest(self).run_test() - BaseActivationBiasCorrectionReshapeNetTest(self).run_test() - BaseActivationBiasCorrectionBigThrTest(self).run_test() - for linear_layer in [nn.Linear(8, 20), - nn.Conv2d(3, 20, 1), - nn.Conv2d(3, 8, (1, 1))]: - for bypass_layers in [[nn.Dropout(0.5)], [nn.Dropout(0.5), nn.Dropout(0.5)]]: - BaseActivationBiasCorrectionNetTest(self, - linear_layer=linear_layer, - bypass_layers=bypass_layers).run_test() + model_list = [ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Linear(8, 20), + bypass_layers=[nn.Dropout(0.5), nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 20, 1), + bypass_layers=[nn.Dropout(0.5), nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 8, (1, 1)), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.Hardswish(), + linear_layer=nn.ConvTranspose2d(3, 8, 1), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 8, 3), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionPadNet(), + ActivationBiasCorrectionReshapeNet()] + + for activation_bias_correction_threshold in [0.0, 1e-6, 1e9]: + for model in model_list: + BaseActivationBiasCorrectionTest(self, + model=model, + activation_bias_correction_threshold= + activation_bias_correction_threshold).run_test() def test_split_concat_net(self): """ From 26a850d50a005d48fea8ceb247a38b0b5732b492 Mon Sep 17 00:00:00 2001 From: ariell Date: Wed, 6 Nov 2024 15:25:27 +0200 Subject: [PATCH 5/5] Fix comments --- ...compute_activation_bias_correction_of_graph.py | 15 +++++++++------ .../activation_bias_correction_test.py | 8 +++++++- .../activation_bias_correction_test.py | 7 +++++++ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py index 765226c6d..bf753709b 100644 --- a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -91,8 +91,10 @@ def compute_activation_bias_correction(graph: Graph, Graph with activation bias correction term for each node. """ - # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). - # This feature supports only Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1). + # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). This feature supports only + # Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1). + # For Dense/Linear layers, which lack a 'kernel_size' attribute, the result will be None, and no restriction + # applies in that case. 
if linear_node.framework_attr.get(kernel_size) not in [None, 1, (1, 1)]: # If the kernel size is not 1 or (1, 1), return the current graph unmodified return graph @@ -136,17 +138,18 @@ def compute_activation_bias_correction(graph: Graph, # size matching the number of output channels. if kernel is not None: - # Get the axes that are not the output channel + # Get the axes that are not the output channel. output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type) axis_not_output_channel = list(range(len(kernel.shape))) axis_not_output_channel.remove(output_channel_index) - # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters + # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters. if output_channel_index == input_channel_index: - axis_not_output_channel.remove(3) # 3 is the depth multiplier index + axis_not_output_channel.remove(3) # 3 is the depth multiplier index. activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel)) - linear_node.final_activation_quantization_cfg.activation_bias_correction_term = activation_bias_correction_term.flatten() + linear_node.final_activation_quantization_cfg.activation_bias_correction_term = ( + activation_bias_correction_term.flatten()) return graph diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py index 0de409258..e48e828ed 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -69,12 +69,18 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= bias = float_linear_layers[-1].bias bias_after_activation_bias_correction = quantized_linear_layers[-1].layer.bias + y = float_model.predict(input_x) + y_hat = quantized_model.predict(input_x) + + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + if getattr(float_linear_layers[-1], KERNEL_SIZE, None) in [None, 1, (1, 1)]: if self.activation_bias_correction_threshold > 1e8: self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), msg=f"Error in activation bias correction: expected no change in the bias " f"value in case of activation_bias_correction_threshold " f"{self.activation_bias_correction_threshold}.") + else: self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), msg=f"Error in activation bias correction: expected a change in the bias " @@ -82,4 +88,4 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= else: self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), msg=f"Error in activation bias correction: expected no change in the bias value " - f"in case of conv with kernel 2.") + f"in case of conv with kernel different than 1 or (1, 1).") diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py index beba40d80..460aad783 100644 --- a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py @@ -19,6 +19,7 @@ import 
model_compression_toolkit as mct from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest """ @@ -109,6 +110,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= bias = float_model.linear_layer.bias.cpu().detach().numpy() bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() + set_model(float_model) + y = float_model(to_torch_tensor(input_x[0])) + y_hat = quantized_model(to_torch_tensor(input_x[0])) + + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + if getattr(float_model.linear_layer, KERNEL_SIZE, None) in [None, 1, (1, 1)]: if self.activation_bias_correction_threshold > 1e8: self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction),
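
Reviewer note (illustrative, not part of the patch): the correction term computed in compute_activation_bias_correction_of_graph.py can be read as a small standalone calculation. The sketch below restates the hunk above with hypothetical inputs — the function name is made up, mean_diff stands for the per-output-channel gap between the float and quantized activation means of the preceding activation node, and the two channel indices stand for what fw_info.kernel_channels_mapping returns — so the depthwise special case and the final flatten() are easier to follow.

import numpy as np

def activation_bias_correction_term_sketch(mean_diff: np.ndarray,
                                           kernel: np.ndarray,
                                           output_channel_index: int,
                                           input_channel_index: int) -> np.ndarray:
    # Sum the kernel over every axis except the output-channel axis, so the
    # result has one entry per output channel. The feature is only applied to
    # Dense/Linear layers and to convolutions with kernel size 1 or (1, 1),
    # matching the framework_attr check in the patch.
    axis_not_output_channel = list(range(kernel.ndim))
    axis_not_output_channel.remove(output_channel_index)

    # TensorFlow depthwise conv special case: the input and output channel
    # indices coincide and axis 3 holds the depth multiplier, so it is
    # excluded from the summation as well.
    if output_channel_index == input_channel_index:
        axis_not_output_channel.remove(3)

    # Per-output-channel correction term; in the patch this is stored on the
    # node's final activation quantization config and folded into the layer
    # bias later by apply_activation_bias_correction_to_graph.
    term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
    return term.flatten()

Whether the term is added to or subtracted from the bias, and how the activation_bias_correction_threshold filters small corrections, is handled outside the hunks shown here; the tests above only assert the observable effect — with thresholds 0.0 and 1e-6 the bias of a Dense/1x1-conv layer changes, while a 1e9 threshold or a kernel other than 1/(1, 1) leaves it untouched.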