diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index 76ddd917b..3e63c561b 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -170,6 +170,25 @@ def shift_negative_correction(self, raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' f'framework\'s apply_shift_negative_correction method.') # pragma: no cover + @abstractmethod + def compute_activation_bias_correction(self, + graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo) -> Graph: + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + core_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + f'framework\'s compute_activation_bias_correction method.') # pragma: no cover + @abstractmethod def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, @@ -454,7 +473,7 @@ def get_inferable_quantizers(self, node: BaseNode): raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' f'framework\'s get_inferable_quantizers method.') # pragma: no cover - + @staticmethod def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int): """ diff --git a/model_compression_toolkit/core/common/quantization/quantization_config.py b/model_compression_toolkit/core/common/quantization/quantization_config.py index 8af7ee658..a940c4504 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/quantization_config.py @@ -70,6 +70,8 @@ class QuantizationConfig: weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE relu_bound_to_power_of_2: bool = False weights_bias_correction: bool = True + activation_bias_correction: bool = False + activation_bias_correction_threshold: float = 0.0 weights_second_moment_correction: bool = False input_scaling: bool = False softmax_shift: bool = False diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py new file mode 100644 index 000000000..83fcc14ef --- /dev/null +++ b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py @@ -0,0 +1,80 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+import copy
+
+from model_compression_toolkit.core import CoreConfig, QuantizationConfig
+from model_compression_toolkit.core.common import BaseNode, Graph
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig
+
+
+def apply_activation_bias_correction_to_graph(graph: Graph,
+                                              core_config: CoreConfig,
+                                              fw_impl: FrameworkImplementation) -> Graph:
+    """
+    Get a graph, where each node has a final activation quantization configuration (with an activation bias
+    correction term in it), and apply the activation bias correction to each node in the graph.
+
+    Args:
+        graph: Graph to apply activation bias correction to.
+        core_config: CoreConfig containing parameters of how the model should be quantized.
+        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+
+    Returns:
+        Graph with activation bias correction applied to its nodes.
+    """
+
+    graph = copy.deepcopy(graph)
+    for n in graph.nodes:
+        # Activation bias correction is only relevant for nodes with a kernel op.
+        kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
+        if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
+                hasattr(n.final_activation_quantization_cfg, 'activation_bias_correction_term'):
+            # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
+            # calculated during model preparation and is now used to update the node's bias term.
+            _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
+    return graph
+
+
+def _apply_activation_bias_correction_to_node(node: BaseNode,
+                                              fw_impl: FrameworkImplementation,
+                                              qc: QuantizationConfig):
+    """
+    Set a new bias for the node using the activation bias correction term that is stored in the
+    final activation quantization configuration.
+
+    Args:
+        node: Node whose bias is corrected by the activation bias correction.
+        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+        qc: QuantizationConfig containing parameters of how the model should be quantized.
+
+    """
+    correction = node.final_activation_quantization_cfg.activation_bias_correction_term
+    bias = node.get_weights_by_keys(fw_impl.constants.BIAS)  # Get the original bias from the node's weights.
+
+    if bias is not None:  # If the layer has a bias, we subtract the correction from the original bias.
+        node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
+    else:
+        # If the layer has no bias, we treat it as if it had a bias of 0 and add a "dummy" attribute
+        # configuration with disabled quantization.
+        node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
+        node.framework_attr[fw_impl.constants.USE_BIAS] = True  # Mark the use_bias attribute of the node.
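+        # Note: this bias attribute did not exist on the original layer, so a matching weights-attribute
+        # quantization config (with quantization disabled) is registered below so later stages treat it
+        # as a regular, non-quantized weight.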
+        node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
+                                                            WeightsAttrQuantizationConfig(
+                                                                qc,
+                                                                AttributeQuantizationConfig(
+                                                                    enable_weights_quantization=False)))
diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py
new file mode 100644
index 000000000..e8ad63679
--- /dev/null
+++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py
@@ -0,0 +1,209 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import List, Tuple, Any, Callable
+
+import numpy as np
+
+from model_compression_toolkit.core import CoreConfig
+from model_compression_toolkit.core.common import BaseNode, Graph
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+
+
+def get_next_nodes_to_correct(node: BaseNode,
+                              graph: Graph,
+                              linear_node_types: NodeOperationMatcher,
+                              bypass_node_types: NodeOperationMatcher,
+                              bypass_nodes: List = None) -> Tuple[Any, Any]:
+    """
+    Search for the closest preceding node of a given node that is not a bypass node, by going over the node's
+    previous nodes recursively.
+
+    Args:
+        node: Node to search for its previous node.
+        graph: Graph the node is in.
+        linear_node_types: Types of linear nodes to consider.
+        bypass_node_types: Types of nodes to consider as bypass nodes.
+        bypass_nodes: A list of bypass nodes accumulated while running this function.
+
+    Returns: The previous node (if found) and a list of bypass nodes (if any), or Nones if no such node was found
+    or if one of the nodes traversed during the search has multiple incoming edges (in which case the substitution
+    cannot be applied).
+    """
+
+    prev_nodes = graph.get_prev_nodes(node)
+
+    if len(prev_nodes) != 1:
+        return None, None
+
+    prev_node = prev_nodes[0]
+
+    # If the previous node is not a bypass type, return it as the valid node along with any bypass nodes
+    if not bypass_node_types.apply(prev_node):
+        return prev_node, bypass_nodes
+
+    # If the previous node is a bypass node type, add it to the bypass_nodes list and continue searching
+    if bypass_node_types.apply(prev_node):
+        if bypass_nodes:
+            bypass_nodes.append(prev_node)
+        else:
+            bypass_nodes = [prev_node]
+        return get_next_nodes_to_correct(node=prev_node,
+                                         graph=graph,
+                                         linear_node_types=linear_node_types,
+                                         bypass_node_types=bypass_node_types,
+                                         bypass_nodes=bypass_nodes)
+    return None, None
+
+
+def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
+    """
+    Calculate the centers of bins given their edges.
+
+    Args:
+        bin_edges: Array of bin edges.
+
+    Returns:
+        np.ndarray: Array of bin centers.
+    """
+    # Ensure bin_edges is a numpy array
+    bin_edges = np.array(bin_edges, dtype=np.float32)
+
+    # Calculate the centers by averaging consecutive bin edges
+    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
+    return bin_centers
+
+
+def compute_activation_bias_correction(graph: Graph,
+                                       core_config: CoreConfig,
+                                       fw_info: FrameworkInfo,
+                                       fw_impl: FrameworkImplementation,
+                                       linear_node: BaseNode,
+                                       prev_node: BaseNode,
+                                       kernel_size: str) -> Graph:
+    """
+    Compute the activation bias correction term, and store it in the final activation
+    quantization configuration.
+
+    Args:
+        graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
+        core_config: Configuration object containing parameters of how the model should be quantized.
+        fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+        linear_node: Node to compute the activation bias correction for.
+        prev_node: Node whose activation quantization introduces the error being corrected.
+        kernel_size: The framework-specific attribute name of the convolution layer's kernel size.
+
+    Returns:
+        Graph with the activation bias correction term for the node.
+    """
+
+    # Check if 'kernel_size' is a key in the framework-specific attributes of the linear_node; if it is, then the
+    # linear_node is a convolution.
+    if kernel_size in linear_node.framework_attr.keys():
+        # Retrieve the value of 'kernel_size' and check if it is not 1 or (1, 1). This feature supports only kernel
+        # sizes of 1 or (1, 1).
+        if linear_node.framework_attr.get(kernel_size) not in [1, (1, 1)]:
+            # If the kernel size is not 1 or (1, 1), return the current graph unmodified
+            return graph
+
+    prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg
+
+    # Check if the previous node has an activation quantization configuration and a histogram collector
+    if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'):
+        return graph
+
+    float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram()
+
+    # Calculate the centers of the float bins
+    float_centers = calculate_bin_centers(float_bins)
+
+    # Quantize the bin edges and calculate the centers of the quantized bins
+    quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
+    quant_bins = fw_impl.to_numpy(quant_bins)
+    quant_centers = calculate_bin_centers(quant_bins)
+
+    # Calculate the mean of both the float and the quantized bin centers, weighted by the bin counts
+    mean_float_centers = np.sum(float_centers * float_count) / np.sum(float_count)
+    mean_quant_centers = np.sum(quant_centers * float_count) / np.sum(float_count)
+
+    # Compute the difference between the mean quantized center and the mean float center
+    mean_diff = mean_quant_centers - mean_float_centers
+
+    # Apply the threshold filtering only if an activation bias correction threshold was configured
+    if core_config.quantization_config.activation_bias_correction_threshold > 0:
+
+        # Calculate the normalized bias as a percentage of the float center norm
+        float_centers_norm1 = np.abs(mean_float_centers)
+        normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1
+
+        # If the normalized bias is below the activation bias correction threshold, return the unmodified graph
+        if normalized_bias < core_config.quantization_config.activation_bias_correction_threshold:
+            return graph
+
+    # The correction term is a function of the layer type.
+    kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])
+
+    if kernel is not None:
+        output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
+        axis_not_output_channel = list(range(len(kernel.shape)))
+        axis_not_output_channel.remove(output_channel_index)
+
+        if output_channel_index == input_channel_index:
+            axis_not_output_channel.remove(3)  # 3 is the depth multiplier index
+
+        activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
+        linear_node.final_activation_quantization_cfg.activation_bias_correction_term = activation_bias_correction_term.flatten()
+    return graph
+
+
+def compute_activation_bias_correction_of_graph(graph: Graph,
+                                                core_config: CoreConfig,
+                                                fw_info: FrameworkInfo,
+                                                fw_impl: FrameworkImplementation,
+                                                activation_bias_correction_node_matchers: Callable,
+                                                kernel_size: str) -> Graph:
+    """
+    Compute the activation bias correction term for the graph.
+
+    Args:
+        graph: Graph with nodes to compute the activation bias correction.
+        core_config: Configuration object containing parameters of how the model should be quantized.
+        fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+        activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
+        kernel_size: The framework-specific attribute name of the convolution layer's kernel size.
+
+    Returns:
+        Graph with activation bias correction term for each relevant node.
+ """ + linear_node_types, bypass_node_types = activation_bias_correction_node_matchers() + + for n in graph.nodes: + if linear_node_types.apply(n): + prev_node, _ = get_next_nodes_to_correct(node=n, + graph=graph, + linear_node_types=linear_node_types, + bypass_node_types=bypass_node_types) + graph = compute_activation_bias_correction(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=fw_impl, + linear_node=n, + prev_node=prev_node, + kernel_size=kernel_size) + return graph diff --git a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py index e11aa0290..b07f3f3d1 100644 --- a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +++ b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py @@ -18,6 +18,8 @@ from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.quantization.core_config import CoreConfig +from model_compression_toolkit.core.common.statistics_correction.apply_activation_bias_correction_to_graph import \ + apply_activation_bias_correction_to_graph from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \ apply_bias_correction_to_graph from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \ @@ -73,7 +75,7 @@ def apply_statistics_correction(transformed_graph: Graph, fw_impl: FrameworkImplementation, tb_w: TensorboardWriter = None, ) -> Graph: """ - Apply statistics moment correction on graph. + Apply statistics correction on graph. Args: transformed_graph: Graph to apply statistics correction. representative_data_gen (Callable): Dataset used for calibration. @@ -84,7 +86,7 @@ def apply_statistics_correction(transformed_graph: Graph, tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc. Returns: - Graph after statistics correction correction. + Graph after statistics correction. 
""" ############################################# @@ -104,4 +106,14 @@ def apply_statistics_correction(transformed_graph: Graph, if tb_w is not None: tb_w.add_graph(transformed_graph, 'after_statistics_correction') + ############################################# + # Apply Activation Bias Correction + ############################################# + if core_config.quantization_config.activation_bias_correction: + transformed_graph = apply_activation_bias_correction_to_graph(graph=transformed_graph, + core_config=core_config, + fw_impl=fw_impl) + if tb_w is not None: + tb_w.add_graph(transformed_graph, 'after_activation_bias_correction') + return transformed_graph diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 6020b137f..024c0bd54 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -28,6 +28,8 @@ from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \ ActivationHessianScoresCalculatorKeras from model_compression_toolkit.core.keras.hessian.weights_hessian_scores_calculator_keras import WeightsHessianScoresCalculatorKeras +from model_compression_toolkit.core.keras.statistics_correction.keras_compute_activation_bias_correction_of_graph import \ + keras_compute_activation_bias_correction_of_graph from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ get_inferable_quantizers from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \ @@ -84,7 +86,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \ keras_residual_collapsing from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \ - InputScalingWithPad + InputScalingWithPad from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \ ReLUBoundToPowerOfTwo @@ -218,6 +220,25 @@ def shift_negative_correction(self, core_config, fw_info) + def compute_activation_bias_correction(self, + graph: Graph, + core_config: CoreConfig, + fw_info: FrameworkInfo): + """ + Compute activation bias correction on a graph. + + Args: + graph: Graph to apply activation bias correction on. + core_config: QuantizationConfig of how the model should be quantized. + fw_info: FrameworkInfo object with information about the specific framework's model. + + Returns: + Graph after activation bias correction computing. + """ + return keras_compute_activation_bias_correction_of_graph(graph=graph, + core_config=core_config, + fw_info=fw_info, + fw_impl=self) def get_substitutions_channel_equalization(self, quant_config: QuantizationConfig, @@ -309,7 +330,7 @@ def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: """ return keras_op2d_add_const_collapsing() - def get_substitutions_post_statistics_collection(self, + def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[common.BaseSubstitution]: """ Return a list of the framework substitutions used after we collect statistics. 
diff --git a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py
new file mode 100644
index 000000000..7dcdc6683
--- /dev/null
+++ b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py
@@ -0,0 +1,88 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+from packaging import version
+
+from model_compression_toolkit.core.keras.constants import KERNEL_SIZE
+
+if version.parse(tf.__version__) >= version.parse("2.13"):
+    from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
+        MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D
+else:
+    from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
+        MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D
+
+from model_compression_toolkit.core import CoreConfig
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
+    compute_activation_bias_correction_of_graph
+
+
+def activation_bias_correction_node_matchers():
+    # Match linear layers where we can add a correction.
+    linear_node = NodeOperationMatcher(Conv2D) | \
+                  NodeOperationMatcher(Dense) | \
+                  NodeOperationMatcher(DepthwiseConv2D)
+
+    # Match bypass layers that don't affect the quantization process.
+    bypass_node = (NodeOperationMatcher(Cropping2D) |
+                   NodeOperationMatcher(GlobalAveragePooling2D) |
+                   NodeOperationMatcher(Dropout) |
+                   NodeOperationMatcher(ZeroPadding2D) |
+                   NodeOperationMatcher(MaxPooling2D) |
+                   NodeOperationMatcher(tf.split) |
+                   NodeOperationMatcher(tf.cast) |
+                   NodeOperationMatcher(tf.unstack) |
+                   NodeOperationMatcher(tf.__operators__.getitem) |
+                   NodeOperationMatcher(tf.strided_slice) |
+                   NodeOperationMatcher(Reshape) |
+                   NodeOperationMatcher(tf.reshape) |
+                   NodeOperationMatcher(Permute) |
+                   NodeOperationMatcher(tf.transpose) |
+                   NodeOperationMatcher(Flatten) |
+                   NodeOperationMatcher(tf.gather) |
+                   NodeOperationMatcher(tf.compat.v1.gather))
+
+    return linear_node, bypass_node
+
+
+def keras_compute_activation_bias_correction_of_graph(graph: Graph,
+                                                      core_config: CoreConfig,
+                                                      fw_info: FrameworkInfo,
+                                                      fw_impl: FrameworkImplementation) -> Graph:
+    """
+    Compute the activation bias correction term for a graph based on a Keras model.
+
+    Args:
+        graph: Graph with nodes to compute the activation bias correction.
+        core_config: Configuration object containing parameters of how the model should be quantized.
+        fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+
+    Returns:
+        Graph with activation bias correction term for each relevant node.
+    """
+    graph = compute_activation_bias_correction_of_graph(graph=graph,
+                                                        core_config=core_config,
+                                                        fw_info=fw_info,
+                                                        fw_impl=fw_impl,
+                                                        activation_bias_correction_node_matchers=
+                                                        activation_bias_correction_node_matchers,
+                                                        kernel_size=KERNEL_SIZE)
+    return graph
diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
index 38e03dd08..2b5e3bae8 100644
--- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py
+++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -92,6 +92,8 @@
 from model_compression_toolkit.core.pytorch.reader.reader import model_reader
 from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \
     pytorch_apply_second_moment_correction
+from model_compression_toolkit.core.pytorch.statistics_correction.pytorch_compute_activation_bias_correction_of_graph import \
+    pytorch_compute_activation_bias_correction_of_graph
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
 from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \
     get_inferable_quantizers
@@ -212,6 +214,25 @@ def shift_negative_correction(self,
                                                        core_config,
                                                        fw_info)

+    def compute_activation_bias_correction(self,
+                                           graph: Graph,
+                                           core_config: CoreConfig,
+                                           fw_info: FrameworkInfo):
+        """
+        Compute activation bias correction on a graph.
+
+        Args:
+            graph: Graph to apply activation bias correction on.
+            core_config: QuantizationConfig of how the model should be quantized.
+            fw_info: FrameworkInfo object with information about the specific framework's model.
+
+        Returns:
+            Graph after activation bias correction computing.
+        """
+        return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
+                                                                   core_config=core_config,
+                                                                   fw_info=fw_info,
+                                                                   fw_impl=self)

     def get_substitutions_channel_equalization(self,
                                                quant_config: QuantizationConfig,
diff --git a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py
new file mode 100644
index 000000000..f2d17e4e5
--- /dev/null
+++ b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from torch import reshape, transpose, flatten, permute
+from torch.nn import Conv2d, Linear, Dropout, ZeroPad2d, AdaptiveAvgPool2d
+from torch.nn.functional import avg_pool2d, pad
+
+from model_compression_toolkit.core import CoreConfig
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.statistics_correction.compute_activation_bias_correction_of_graph import \
+    compute_activation_bias_correction_of_graph
+from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE
+
+
+def activation_bias_correction_node_matchers():
+    # Match linear layers where we can add a correction.
+    linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d)
+
+    # Match bypass layers that don't affect the quantization process.
+    bypass_node = NodeOperationMatcher(reshape) | \
+                  NodeOperationMatcher(avg_pool2d) | \
+                  NodeOperationMatcher(transpose) | \
+                  NodeOperationMatcher(Dropout) | \
+                  NodeOperationMatcher(flatten) | \
+                  NodeOperationMatcher(ZeroPad2d) | \
+                  NodeOperationMatcher(pad) | \
+                  NodeOperationMatcher(AdaptiveAvgPool2d) | \
+                  NodeOperationMatcher(permute)
+
+    return linear_node, bypass_node
+
+
+def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
+                                                        core_config: CoreConfig,
+                                                        fw_info: FrameworkInfo,
+                                                        fw_impl: FrameworkImplementation) -> Graph:
+    """
+    Compute the activation bias correction term for a graph based on a PyTorch model.
+
+    Args:
+        graph: Graph with nodes to compute the activation bias correction.
+        core_config: Configuration object containing parameters of how the model should be quantized.
+        fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
+        fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+
+    Returns:
+        Graph with activation bias correction term for each relevant node.
+    """
+    graph = compute_activation_bias_correction_of_graph(graph=graph,
+                                                        core_config=core_config,
+                                                        fw_info=fw_info,
+                                                        fw_impl=fw_impl,
+                                                        activation_bias_correction_node_matchers=
+                                                        activation_bias_correction_node_matchers,
+                                                        kernel_size=KERNEL_SIZE)
+    return graph
diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py
index 66b25f080..b111219bb 100644
--- a/model_compression_toolkit/core/runner.py
+++ b/model_compression_toolkit/core/runner.py
@@ -164,6 +164,14 @@ def core_runner(in_model: Any,
                             tg,
                             bit_widths_config)

+    ######################################
+    # Compute Activation Bias Correction
+    ######################################
+    if core_config.quantization_config.activation_bias_correction:
+        tg = fw_impl.compute_activation_bias_correction(graph=tg,
+                                                        core_config=core_config,
+                                                        fw_info=fw_info)
+
     # Edit the graph again after finalizing the configurations.
     # This is since some actions regard the final configuration and should be edited.
edit_network_graph(tg, fw_info, core_config.debug_config.network_editor) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py new file mode 100644 index 000000000..86dc0bff8 --- /dev/null +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from model_compression_toolkit.core import QuantizationConfig +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest + +import tensorflow as tf +import numpy as np + +from tests.keras_tests.utils import get_layers_from_model_by_type + +keras = tf.keras +layers = keras.layers + +""" +This test checks the Activation Bias Correction feature. +""" + +class BaseActivationBiasCorrectionTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('gelu')(inputs) + x = layers.Dropout(0.5)(x) + x = layers.Dropout(0.5)(x) + outputs = layers.Dense(30)(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_dense_layers = get_layers_from_model_by_type(float_model, layers.Dense) + quantized_dense_layers = get_layers_from_model_by_type(quantized_model, layers.Dense) + + bias = float_dense_layers[-1].bias + bias_after_activation_bias_correction = quantized_dense_layers[-1].layer.bias + + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionConvTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature with Conv2D layer. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A small value set to the activation bias correction threshold only to activate the threshold + # filtering without changing the bias correction values. 
+ return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e-6) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('swish')(inputs) + x = layers.ZeroPadding2D(2)(x) + outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True)(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + bias = float_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias + + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionDWConvTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature with DepthWiseConv2D layer. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = tf.nn.gelu(inputs) + x = layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x) + outputs = layers.DepthwiseConv2D(kernel_size=1, use_bias=True, bias_initializer='glorot_uniform', + depth_multiplier=1)(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_dw_conv_layers = get_layers_from_model_by_type(float_model, layers.DepthwiseConv2D) + quantized_dw_conv_layers = get_layers_from_model_by_type(quantized_model, layers.DepthwiseConv2D) + + bias = float_dw_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_dw_conv_layers[-1].layer.bias + + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionReshapeConvTest(BaseKerasFeatureNetworkTest): + """ + This test checks the Activation Bias Correction feature. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A large value is assigned to the activation bias correction threshold to enable threshold filtering, + # which adjusts the bias correction values to zero. 
+ return QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e9) + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = layers.Activation('swish')(inputs) + x = layers.Flatten()(x) + x = layers.Reshape(target_shape=(8, 8, 3))(x) + outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True, bias_initializer='ones')(x) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) + quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) + + bias = float_conv_layers[-1].bias + bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias + + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value in " + f"case of activation_bias_correction_threshold 1e9.") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 45c8cef3e..10ea22c28 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -28,6 +28,9 @@ from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.gptq import RoundingType from model_compression_toolkit.target_platform_capabilities import constants as C +from tests.keras_tests.feature_networks_tests.feature_networks.activation_bias_correction_test import \ + BaseActivationBiasCorrectionTest, BaseActivationBiasCorrectionConvTest, BaseActivationBiasCorrectionDWConvTest, \ + BaseActivationBiasCorrectionReshapeConvTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \ ActivationDecompositionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_relu_bound_to_power_of_2_test import \ @@ -516,6 +519,12 @@ def test_conv2d_bn_concat(self): def test_activation_scaling_relu6(self): ReLUBoundToPOTNetTest(self).run_test() + def test_activation_bias_correction(self): + BaseActivationBiasCorrectionTest(self).run_test() + BaseActivationBiasCorrectionConvTest(self).run_test() + BaseActivationBiasCorrectionDWConvTest(self).run_test() + BaseActivationBiasCorrectionReshapeConvTest(self).run_test() + def test_layer_activation_softmax_shift(self): SoftmaxShiftTest(self, layers.Dense(20, activation='softmax'), None).run_test() diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py new file mode 100644 index 000000000..cdd2a6970 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py @@ -0,0 +1,159 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import torch +from torch.nn import GELU, Hardswish, AdaptiveAvgPool2d, ZeroPad2d, Linear, Conv2d + +import model_compression_toolkit as mct +from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest + +""" +This test checks the Activation Bias Correction feature. +""" + + +class ActivationBiasCorrectionNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature. + """ + + def __init__(self, + linear_layer, + bypass_layers): + super(ActivationBiasCorrectionNet, self).__init__() + self.activation_layer = GELU() + self.linear_layer = linear_layer + self.bypass_layers = torch.nn.ModuleList(bypass_layers) + + def forward(self, x): + x = self.activation_layer(x) + + for bypass_layer in self.bypass_layers: + x = bypass_layer(x) + x = self.linear_layer(x) + return x + + +class ActivationBiasCorrectionPadNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature with pooling/padding layers as a bypass layers. + """ + + def __init__(self): + super(ActivationBiasCorrectionPadNet, self).__init__() + self.activation_layer = Hardswish() + self.pooling_layer = AdaptiveAvgPool2d(output_size=6) + self.padding_layer = ZeroPad2d(padding=2) + self.linear_layer = Linear(10, 10) + + def forward(self, x): + x = self.activation_layer(x) + x = self.pooling_layer(x) + x = self.padding_layer(x) + x = self.linear_layer(x) + return x + + +class ActivationBiasCorrectionReshapeNet(torch.nn.Module): + """ + This is the network to test the Activation Bias Correction feature with reshape layers as a bypass layers. 
+ """ + + def __init__(self): + super(ActivationBiasCorrectionReshapeNet, self).__init__() + self.activation_layer = GELU() + self.linear_layer = Conv2d(in_channels=8, out_channels=1, kernel_size=1) + + def forward(self, x): + x = self.activation_layer(x) + x = x.flatten() + x = x.reshape(8, 2, -1) + x = self.linear_layer(x) + return x + + +class BaseActivationBiasCorrectionTest(BasePytorchFeatureNetworkTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + bias = float_model.linear_layer.bias.cpu().detach().numpy() + bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias value.") + + +class BaseActivationBiasCorrectionNetTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test, linear_layer, bypass_layers): + super().__init__(unit_test) + self.linear_layer = linear_layer + self.bypass_layers = bypass_layers + + def create_networks(self): + return ActivationBiasCorrectionNet(linear_layer=self.linear_layer, + bypass_layers=self.bypass_layers) + + +class BaseActivationBiasCorrectionPadNetTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A small value set to the activation bias correction threshold only to activate the threshold + # filtering without changing the bias correction values. + return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e-6) + + def create_networks(self): + return ActivationBiasCorrectionPadNet() + + +class BaseActivationBiasCorrectionBigThrTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def get_quantization_config(self): + # A large value is assigned to the activation bias correction threshold to enable threshold filtering, + # which adjusts the bias correction values to zero. 
+ return mct.core.QuantizationConfig(weights_bias_correction=False, + weights_second_moment_correction=False, + activation_bias_correction=True, + activation_bias_correction_threshold=1e9) + + def create_networks(self): + return ActivationBiasCorrectionPadNet() + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + bias = float_model.linear_layer.bias.cpu().detach().numpy() + bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value in " + f"case of activation_bias_correction_threshold 1e9.") + + +class BaseActivationBiasCorrectionReshapeNetTest(BaseActivationBiasCorrectionTest): + def __init__(self, unit_test): + super().__init__(unit_test) + + def create_networks(self): + return ActivationBiasCorrectionReshapeNet() diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 44f8193bb..d1fe8833a 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -30,6 +30,9 @@ from model_compression_toolkit.trainable_infrastructure import TrainingMethod from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest +from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import \ + BaseActivationBiasCorrectionPadNetTest, BaseActivationBiasCorrectionNetTest, \ + BaseActivationBiasCorrectionReshapeNetTest, BaseActivationBiasCorrectionBigThrTest from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest from tests.pytorch_tests.model_tests.feature_models.add_same_test import AddSameNetTest from tests.pytorch_tests.model_tests.feature_models.bn_attributes_quantization_test import BNAttributesQuantization @@ -302,13 +305,13 @@ def test_permute_substitution(self): This test checks the permute substitution feature """ PermuteSubstitutionTest(self).run_test() - + def test_reshape_substitution(self): """ This test checks the reshape substitution feature """ ReshapeSubstitutionTest(self).run_test() - + def test_constant_conv_substitution(self): """ This test checks the constant conv substitution feature @@ -438,6 +441,21 @@ def test_shift_negative_activation_net(self): for activation_layer in [torch.nn.Hardswish, torch.nn.GELU]: ShiftNegaviteActivationNetTest(self, activation_layer=activation_layer).run_test(seed=3) + def test_activation_bias_correction_net(self): + """ + This test checks the activation bias correction feature. 
+ """ + BaseActivationBiasCorrectionPadNetTest(self).run_test() + BaseActivationBiasCorrectionReshapeNetTest(self).run_test() + BaseActivationBiasCorrectionBigThrTest(self).run_test() + for linear_layer in [nn.Linear(8, 20), + nn.Conv2d(3, 20, 1), + nn.Conv2d(3, 8, (1, 1))]: + for bypass_layers in [[nn.Dropout(0.5)], [nn.Dropout(0.5), nn.Dropout(0.5)]]: + BaseActivationBiasCorrectionNetTest(self, + linear_layer=linear_layer, + bypass_layers=bypass_layers).run_test() + def test_split_concat_net(self): """ This test checks: @@ -715,7 +733,7 @@ def test_bn_attributes_quantization(self): """ BNAttributesQuantization(self, quantize_linear=False).run_test() BNAttributesQuantization(self, quantize_linear=True).run_test() - + def test_concat_threshold_update(self): ConcatUpdateTest(self).run_test()