diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index 3e63c561b..ad7758d46 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self, Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover @abstractmethod @@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray: Returns: Numpy array converted from the input tensor. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s to_numpy method.') # pragma: no cover @abstractmethod @@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any: Returns: Framework's tensor converted from the input Numpy array. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s to_tensor method.') # pragma: no cover @abstractmethod @@ -106,7 +106,7 @@ def model_reader(self, Returns: Graph representing the input model. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s model_reader method.') # pragma: no cover @abstractmethod @@ -131,7 +131,7 @@ def model_builder(self, Returns: A tuple with the model and additional relevant supporting objects. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s model_builder method.') # pragma: no cover @abstractmethod @@ -148,7 +148,7 @@ def run_model_inference(self, Returns: The frameworks model's output. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s run_model_inference method.') # pragma: no cover @abstractmethod @@ -167,26 +167,26 @@ def shift_negative_correction(self, Returns: Graph after SNC. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s apply_shift_negative_correction method.') # pragma: no cover @abstractmethod def compute_activation_bias_correction(self, graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo) -> Graph: """ Compute activation bias correction on a graph. Args: graph: Graph to apply activation bias correction on. - core_config: QuantizationConfig of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: FrameworkInfo object with information about the specific framework's model. Returns: Graph after activation bias correction computing. 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s compute_activation_bias_correction method.') # pragma: no cover @abstractmethod @@ -203,7 +203,7 @@ def get_substitutions_channel_equalization(self, Returns: A list of the framework substitutions used after we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover @abstractmethod @@ -213,7 +213,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List Returns: A list of the framework substitutions used to prepare the graph. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover @abstractmethod @@ -227,7 +227,7 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization Returns: A list of the framework substitutions used before we collect statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover @abstractmethod @@ -235,7 +235,7 @@ def get_linear_collapsing_substitution(self) -> common.BaseSubstitution: """ Returns: linear collapsing substitution """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover @abstractmethod @@ -243,7 +243,7 @@ def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution: """ Returns: conv2d add const collapsing substitution """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover @abstractmethod @@ -258,7 +258,7 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf Returns: A list of the framework substitutions used for statistics correction. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover @abstractmethod @@ -266,7 +266,7 @@ def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]: """ Returns: A list of the framework substitutions used for residual collapsing """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover @@ -282,7 +282,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio Returns: A list of the framework substitutions used after we collect statistics. 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover @abstractmethod @@ -291,7 +291,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_virtual_weights_activation_coupling ' f'method.') # pragma: no cover @@ -307,7 +307,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz Returns: A list of the framework substitutions used after we apply second moment statistics. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_substitutions_after_second_moment_correction ' f'method.') # pragma: no cover @@ -335,7 +335,7 @@ def get_sensitivity_evaluator(self, A function that computes the metric. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover def get_node_prior_info(self, node: BaseNode, @@ -353,7 +353,7 @@ def get_node_prior_info(self, node: BaseNode, NodePriorInfo with information about the node. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_node_prior_info method.') # pragma: no cover def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool: @@ -364,7 +364,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool Returns: True if the node should be considered an interest point, False otherwise. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover def get_mp_node_distance_fn(self, n: BaseNode, @@ -383,7 +383,7 @@ def get_mp_node_distance_fn(self, n: BaseNode, Returns: A distance function between two tensors and a axis on which the distance is computed (if exists). 
""" - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover @@ -400,7 +400,7 @@ def is_output_node_compatible_for_hessian_score_computation(self, """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover @abstractmethod @@ -417,7 +417,7 @@ def get_node_mac_operations(self, Returns: The MAC count of the operation """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_node_mac_operations method.') # pragma: no cover @abstractmethod @@ -438,7 +438,7 @@ def apply_second_moment_correction(self, Returns: A Graph after second moment correction. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s apply_second_moment_correction method.') # pragma: no cover @abstractmethod @@ -455,7 +455,7 @@ def sensitivity_eval_inference(self, Returns: The output of the model inference on the given input. """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s sensitivity_eval_inference method.') # pragma: no cover def get_inferable_quantizers(self, node: BaseNode): @@ -471,7 +471,7 @@ def get_inferable_quantizers(self, node: BaseNode): """ - raise NotImplementedError(f'{self.__class__.__name__} have to implement the ' + raise NotImplementedError(f'{self.__class__.__name__} has to implement the ' f'framework\'s get_inferable_quantizers method.') # pragma: no cover @staticmethod diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index 0c68fb7b4..a790cbc77 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -95,7 +95,9 @@ def __init__(self, self.activation_error_method = qc.activation_error_method self.activation_n_bits = op_cfg.activation_n_bits self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2 + self.activation_bias_correction_term = None self.enable_activation_quantization = op_cfg.enable_activation_quantization + self.quantization_preserving = op_cfg.quantization_preserving self.signedness = op_cfg.signedness self.activation_channel_equalization = qc.activation_channel_equalization self.input_scaling = qc.input_scaling diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py index 83fcc14ef..293e3dcce 100644 --- a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py @@ -13,8 +13,6 @@ # limitations under the License. 
# ============================================================================== -import copy - from model_compression_toolkit.core import CoreConfig, QuantizationConfig from model_compression_toolkit.core.common import BaseNode, Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation @@ -26,7 +24,7 @@ def apply_activation_bias_correction_to_graph(graph: Graph, core_config: CoreConfig, fw_impl: FrameworkImplementation) -> Graph: """ - Get a graph, where each node has a final activation quantization configuration (with a activation bias + Get a graph, where each node has a final activation quantization configuration (with an activation bias correction term in it), and apply the activation bias correction for each node in the graph. Args: @@ -38,12 +36,11 @@ def apply_activation_bias_correction_to_graph(graph: Graph, Graph with activation bias correction apply to it's nodes. """ - graph = copy.deepcopy(graph) for n in graph.nodes: # Activation bias correction is only relevant for nodes with kernel op kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0] if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \ - hasattr(n.final_activation_quantization_cfg, 'activation_bias_correction_term'): + n.final_activation_quantization_cfg.activation_bias_correction_term is not None: # If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was # calculated during model preparation, and is used now in the node's bias term. _apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config) @@ -66,15 +63,19 @@ def _apply_activation_bias_correction_to_node(node: BaseNode, correction = node.final_activation_quantization_cfg.activation_bias_correction_term bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights - if bias is not None: # If the layer has bias, we subtract the correction from original bias - node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) - else: - # If the layer has no bias, we consider it as if it has and its value is 0 and add a "dummy" attribute - # configuration with disabled quantization. + if bias is None: + # If the layer has no bias, we set the bias as -correction. node.set_weights_by_keys(fw_impl.constants.BIAS, - correction) - node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node. + + # Mark the use_bias attribute of the node. + node.framework_attr[fw_impl.constants.USE_BIAS] = True + + # Configure the quantization of the bias as disabled. 
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, WeightsAttrQuantizationConfig( qc, AttributeQuantizationConfig( enable_weights_quantization=False))) + else: + # If the layer has bias, we subtract the correction from original bias + node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py index c828250cb..765226c6d 100644 --- a/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py @@ -12,82 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import List, Tuple, Any, Callable - import numpy as np +from typing import Any, Callable -from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common import BaseNode, Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo -from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher -def get_next_nodes_to_correct(node: BaseNode, - graph: Graph, - linear_node_types: NodeOperationMatcher, - bypass_node_types: NodeOperationMatcher, - bypass_nodes: List = None) -> Tuple[Any, Any]: +def get_previous_node_with_activation_quantization(linear_node: BaseNode, + graph: Graph) -> Any: """ - Search for the previous node which is not a bypass node of a given node. Go over the previous nodes of the node - and recursively search for a node. + Search recursively for the previous node with activation quantization. Args: - node: Node to search for its previous node. + linear_node: Node to search for its previous node. graph: Graph the node is in. - linear_node_types: Types of linear nodes to consider. - bypass_node_types: Types of nodes for bypassing to consider. - bypass_nodes: a list of bypass nodes found while running this function - Returns: The previous node (if found) and a list of bypass nodes (if any), or Nones if it were not found or there - are multiple incoming edges to one of nodes during the search (which means, the substitution can not be applied). + Returns: + The previous node (if found) or None if it was not found or there are multiple incoming edges to one of + nodes during the search (which means, the substitution can not be applied). 
""" - prev_nodes = graph.get_prev_nodes(node) + prev_nodes = graph.get_prev_nodes(linear_node) if len(prev_nodes) != 1: - return None, None # pragma: no cover + return None # pragma: no cover prev_node = prev_nodes[0] - # If the previous node is not a bypass type, return it as the valid node along with any bypass nodes - if not bypass_node_types.apply(prev_node): - return prev_node, bypass_nodes + activation_quantization_config = prev_node.final_activation_quantization_cfg - # If the previous node is a bypass node type, add it to the bypass_nodes list and continue searching - if bypass_node_types.apply(prev_node): - if bypass_nodes: - bypass_nodes.append(prev_node) - else: - bypass_nodes = [prev_node] - return get_next_nodes_to_correct(node=prev_node, - graph=graph, - linear_node_types=linear_node_types, - bypass_node_types=bypass_node_types, - bypass_nodes=bypass_nodes) - return None, None # pragma: no cover + # Search for node with activation quantization + if (activation_quantization_config.enable_activation_quantization and + not activation_quantization_config.quantization_preserving): + return prev_node + else: + return get_previous_node_with_activation_quantization(prev_node, graph) def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray: """ Calculate the centers of bins given their edges. - Parameters: - bin_edges: Array of bin edges. + Args: + bin_edges: Array of bin edges. Returns: - np.ndarray: Array of bin centers. + np.ndarray: Array of bin centers. """ - # Ensure bin_edges is a numpy array - bin_edges = np.array(bin_edges, dtype=np.float32) - # Calculate the centers by averaging continuous bin edges bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 return bin_centers def compute_activation_bias_correction(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, linear_node: BaseNode, @@ -99,7 +80,7 @@ def compute_activation_bias_correction(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. linear_node: Node to compute the activation bias correction for. @@ -110,19 +91,16 @@ def compute_activation_bias_correction(graph: Graph, Graph with activation bias correction term for each node. """ - # Check if 'kernel_size' is a key in the framework-specific attributes of the linear_node, if it is then the - # linear_node is a convolution - if kernel_size in linear_node.framework_attr.keys(): - # Retrieve the value of 'kernel_size' and check if it is not 1 or (1, 1). This feature supports only kernel - # size of 1 or (1, 1). - if linear_node.framework_attr.get(kernel_size) not in [1, (1, 1)]: - # If the kernel size is not 1 or (1, 1), return the current graph unmodified - return graph + # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). + # This feature supports only Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1). 
+ if linear_node.framework_attr.get(kernel_size) not in [None, 1, (1, 1)]: + # If the kernel size is not 1 or (1, 1), return the current graph unmodified + return graph prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg # Check if the previous node's has activation quantization configuration and if the previous node have the - # histogram collector + # histogram collector. if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'): return graph # pragma: no cover @@ -143,25 +121,27 @@ def compute_activation_bias_correction(graph: Graph, # Compute the difference between the mean quantized center and the mean float center mean_diff = mean_quant_centers - mean_float_centers - # Check if activation bias correction is enabled based on the configured threshold - if core_config.quantization_config.activation_bias_correction_threshold > 0: - - # Calculate the normalized bias as a percentage of the float center norm - float_centers_norm1 = np.abs(mean_float_centers) - normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1 + # Calculate the normalized bias as a percentage of the float center norm + float_centers_norm1 = np.abs(mean_float_centers) + normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1 - # If the normalized bias is below the activation bias correction threshold, return the unmodified graph - if normalized_bias < core_config.quantization_config.activation_bias_correction_threshold: - return graph + # If the normalized bias is below the activation bias correction threshold, return the graph unmodified. + # By default, the threshold is set to 0, allowing all nodes to proceed in this case. + if normalized_bias < quant_config.activation_bias_correction_threshold: + return graph - # The correction term is a function of the layer type. kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0]) + # Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output + # size matching the number of output channels. if kernel is not None: + + # Get the axes that are not the output channel output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type) axis_not_output_channel = list(range(len(kernel.shape))) axis_not_output_channel.remove(output_channel_index) + # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters if output_channel_index == input_channel_index: axis_not_output_channel.remove(3) # 3 is the depth multiplier index @@ -171,7 +151,7 @@ def compute_activation_bias_correction(graph: Graph, def compute_activation_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation, activation_bias_correction_node_matchers: Callable, @@ -181,7 +161,7 @@ def compute_activation_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. activation_bias_correction_node_matchers: Function to match the layers for activation bias correction. 
@@ -191,17 +171,14 @@ def compute_activation_bias_correction_of_graph(graph: Graph, Returns: Graph with activation bias correction term for each relevant node. """ - linear_node_types, bypass_node_types = activation_bias_correction_node_matchers() + linear_node_types = activation_bias_correction_node_matchers() for n in graph.nodes: if linear_node_types.apply(n): - prev_node, _ = get_next_nodes_to_correct(node=n, - graph=graph, - linear_node_types=linear_node_types, - bypass_node_types=bypass_node_types) + prev_node = get_previous_node_with_activation_quantization(n, graph) if prev_node is not None: graph = compute_activation_bias_correction(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=fw_impl, linear_node=n, diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 024c0bd54..85fab7637 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -222,21 +222,21 @@ def shift_negative_correction(self, def compute_activation_bias_correction(self, graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo): """ Compute activation bias correction on a graph. Args: graph: Graph to apply activation bias correction on. - core_config: QuantizationConfig of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: FrameworkInfo object with information about the specific framework's model. Returns: Graph after activation bias correction computing. """ return keras_compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=self) diff --git a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py index 7dcdc6683..ce1ac9c23 100644 --- a/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py @@ -19,13 +19,11 @@ from model_compression_toolkit.core.keras.constants import KERNEL_SIZE if version.parse(tf.__version__) >= version.parse("2.13"): - from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ - MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D + from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose else: - from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \ - MaxPooling2D, Flatten, Cropping2D, Permute, GlobalAveragePooling2D + from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose -from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common import Graph @@ -38,32 +36,13 @@ def activation_bias_correction_node_matchers(): # Match linear layers where we can add a correction. 
linear_node = NodeOperationMatcher(Conv2D) | \ NodeOperationMatcher(Dense) | \ - NodeOperationMatcher(DepthwiseConv2D) - - # Match bypass layers what don't affect the quantization process. - bypass_node = (NodeOperationMatcher(Cropping2D) | - NodeOperationMatcher(GlobalAveragePooling2D) | - NodeOperationMatcher(Dropout) | - NodeOperationMatcher(ZeroPadding2D) | - NodeOperationMatcher(MaxPooling2D) | - NodeOperationMatcher(tf.split) | - NodeOperationMatcher(tf.cast) | - NodeOperationMatcher(tf.unstack) | - NodeOperationMatcher(tf.__operators__.getitem) | - NodeOperationMatcher(tf.strided_slice) | - NodeOperationMatcher(Reshape) | - NodeOperationMatcher(tf.reshape) | - NodeOperationMatcher(Permute) | - NodeOperationMatcher(tf.transpose) | - NodeOperationMatcher(Flatten) | - NodeOperationMatcher(tf.gather) | - NodeOperationMatcher(tf.compat.v1.gather)) - - return linear_node, bypass_node + NodeOperationMatcher(DepthwiseConv2D) | \ + NodeOperationMatcher(Conv2DTranspose) + return linear_node def keras_compute_activation_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation) -> Graph: """ @@ -71,7 +50,7 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. @@ -79,7 +58,7 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph, Graph with activation bias correction term for each relevant node. """ graph = compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=fw_impl, activation_bias_correction_node_matchers= diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 2b5e3bae8..5ec26a66d 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -216,21 +216,21 @@ def shift_negative_correction(self, def compute_activation_bias_correction(self, graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo): """ Compute activation bias correction on a graph. Args: graph: Graph to apply activation bias correction on. - core_config: QuantizationConfig of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: FrameworkInfo object with information about the specific framework's model. Returns: Graph after activation bias correction computing. 
""" return pytorch_compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=self) diff --git a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py index f2d17e4e5..149050cf2 100644 --- a/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py @@ -13,11 +13,9 @@ # limitations under the License. # ============================================================================== -from torch import reshape, transpose, flatten, permute -from torch.nn import Conv2d, Linear, Dropout, ZeroPad2d, AdaptiveAvgPool2d -from torch.nn.functional import avg_pool2d, pad +from torch.nn import Conv2d, Linear, ConvTranspose2d -from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo @@ -29,24 +27,12 @@ def activation_bias_correction_node_matchers(): # Match linear layers where we can add a correction. - linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) - - # Match bypass layers what don't affect the quantization process. - bypass_node = NodeOperationMatcher(reshape) | \ - NodeOperationMatcher(avg_pool2d) | \ - NodeOperationMatcher(transpose) | \ - NodeOperationMatcher(Dropout) | \ - NodeOperationMatcher(flatten) | \ - NodeOperationMatcher(ZeroPad2d) | \ - NodeOperationMatcher(pad) | \ - NodeOperationMatcher(AdaptiveAvgPool2d) | \ - NodeOperationMatcher(permute) - - return linear_node, bypass_node + linear_node = NodeOperationMatcher(Linear) | NodeOperationMatcher(Conv2d) | NodeOperationMatcher(ConvTranspose2d) + return linear_node def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, + quant_config: QuantizationConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation) -> Graph: """ @@ -54,7 +40,7 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the activation bias correction. - core_config: Configuration object containing parameters of how the model should be quantized. + quant_config: QuantizationConfig of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. @@ -62,7 +48,7 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph, Graph with activation bias correction term for each relevant node. 
""" graph = compute_activation_bias_correction_of_graph(graph=graph, - core_config=core_config, + quant_config=quant_config, fw_info=fw_info, fw_impl=fw_impl, activation_bias_correction_node_matchers= diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index b111219bb..65ef60176 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -169,7 +169,7 @@ def core_runner(in_model: Any, ###################################### if core_config.quantization_config.activation_bias_correction: tg = fw_impl.compute_activation_bias_correction(graph=tg, - core_config=core_config, + quant_config=core_config.quantization_config, fw_info=fw_info) # Edit the graph again after finalizing the configurations. diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py index 4c98fefd4..0de409258 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.core.keras.constants import KERNEL_SIZE from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest import tensorflow as tf @@ -27,171 +28,58 @@ This test checks the Activation Bias Correction feature. """ + class BaseActivationBiasCorrectionTest(BaseKerasFeatureNetworkTest): """ This test checks the Activation Bias Correction feature. """ - def __init__(self, unit_test): + def __init__(self, unit_test, + prev_layer, + bypass_layer_list, + linear_layer, + activation_bias_correction_threshold=0.0): super().__init__(unit_test) + self.prev_layer = prev_layer + self.bypass_layer_list = bypass_layer_list + self.linear_layer = linear_layer + self.activation_bias_correction_threshold = activation_bias_correction_threshold def get_quantization_config(self): return QuantizationConfig(weights_bias_correction=False, weights_second_moment_correction=False, shift_negative_activation_correction=False, - activation_bias_correction=True) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('gelu')(inputs) - x = layers.Dropout(0.5)(x) - x = layers.Dropout(0.5)(x) - outputs = layers.Dense(30, use_bias=False)(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_dense_layers = get_layers_from_model_by_type(float_model, layers.Dense) - quantized_dense_layers = get_layers_from_model_by_type(quantized_model, layers.Dense) - - bias = float_dense_layers[-1].bias - bias_after_activation_bias_correction = quantized_dense_layers[-1].layer.bias - - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionConvTest(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature with Conv2D layer. 
- """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A small value set to the activation bias correction threshold only to activate the threshold - # filtering without changing the bias correction values. - return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True, - activation_bias_correction_threshold=1e-6) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('swish')(inputs) - x = layers.ZeroPadding2D(2)(x) - outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True)(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) - quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - - bias = float_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias - - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionDWConvTest(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature with DepthWiseConv2D layer. - """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = tf.nn.gelu(inputs) - x = layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x) - outputs = layers.DepthwiseConv2D(kernel_size=1, use_bias=True, bias_initializer='glorot_uniform', - depth_multiplier=1)(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_dw_conv_layers = get_layers_from_model_by_type(float_model, layers.DepthwiseConv2D) - quantized_dw_conv_layers = get_layers_from_model_by_type(quantized_model, layers.DepthwiseConv2D) - - bias = float_dw_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_dw_conv_layers[-1].layer.bias - - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionReshapeConvTest(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature. - """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A large value is assigned to the activation bias correction threshold to enable threshold filtering, - # which adjusts the bias correction values to zero. 
- return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, activation_bias_correction=True, - activation_bias_correction_threshold=1e9) + activation_bias_correction_threshold=self.activation_bias_correction_threshold) def create_networks(self): inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('swish')(inputs) - x = layers.Flatten()(x) - x = layers.Reshape(target_shape=(8, 8, 3))(x) - outputs = layers.Conv2D(filters=3, kernel_size=1, use_bias=True, bias_initializer='ones')(x) - return keras.Model(inputs=inputs, outputs=outputs) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) - quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - - bias = float_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias - - self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected no change in the bias value in " - f"case of activation_bias_correction_threshold 1e9.") + x = self.prev_layer(inputs) + for bypass_layer in self.bypass_layer_list: + x = bypass_layer(x) -class BaseActivationBiasCorrectionConv8Test(BaseKerasFeatureNetworkTest): - """ - This test checks the Activation Bias Correction feature. - """ - - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A large value is assigned to the activation bias correction threshold to enable threshold filtering, - # which adjusts the bias correction values to zero. - return QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True) - - def create_networks(self): - inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) - x = layers.Activation('swish')(inputs) - x = layers.Flatten()(x) - x = layers.Reshape(target_shape=(8, 8, 3))(x) - outputs = layers.Conv2D(filters=3, kernel_size=2, use_bias=True, bias_initializer='ones')(x) + outputs = self.linear_layer(x) return keras.Model(inputs=inputs, outputs=outputs) def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - float_conv_layers = get_layers_from_model_by_type(float_model, layers.Conv2D) - quantized_conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - - bias = float_conv_layers[-1].bias - bias_after_activation_bias_correction = quantized_conv_layers[-1].layer.bias - - self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected no change in the bias value in " - f"case of conv2d with kernel 2.") + float_linear_layers = get_layers_from_model_by_type(float_model, type(self.linear_layer)) + quantized_linear_layers = get_layers_from_model_by_type(quantized_model, type(self.linear_layer)) + + bias = float_linear_layers[-1].bias + bias_after_activation_bias_correction = quantized_linear_layers[-1].layer.bias + + if getattr(float_linear_layers[-1], KERNEL_SIZE, None) in [None, 1, (1, 1)]: + if self.activation_bias_correction_threshold > 1e8: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias " + f"value in case of activation_bias_correction_threshold " + 
f"{self.activation_bias_correction_threshold}.") + else: + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias " + f"value.") + else: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value " + f"in case of conv with kernel 2.") diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index e48973f99..59336b057 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -29,8 +29,7 @@ from model_compression_toolkit.gptq import RoundingType from model_compression_toolkit.target_platform_capabilities import constants as C from tests.keras_tests.feature_networks_tests.feature_networks.activation_bias_correction_test import \ - BaseActivationBiasCorrectionTest, BaseActivationBiasCorrectionConvTest, BaseActivationBiasCorrectionDWConvTest, \ - BaseActivationBiasCorrectionReshapeConvTest, BaseActivationBiasCorrectionConv8Test + BaseActivationBiasCorrectionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \ ActivationDecompositionTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_relu_bound_to_power_of_2_test import \ @@ -66,7 +65,8 @@ RequiresMixedPrecision, RequiresMixedPrecisionWeights from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_bops_test import \ MixedPrecisionBopsBasicTest, MixedPrecisionBopsAllWeightsLayersTest, MixedPrecisionWeightsOnlyBopsTest, \ - MixedPrecisionActivationOnlyBopsTest, MixedPrecisionBopsAndWeightsUtilizationTest, MixedPrecisionBopsAndActivationUtilizationTest, \ + MixedPrecisionActivationOnlyBopsTest, MixedPrecisionBopsAndWeightsUtilizationTest, \ + MixedPrecisionBopsAndActivationUtilizationTest, \ MixedPrecisionBopsAndTotalUtilizationTest, MixedPrecisionBopsWeightsActivationUtilizationTest, \ MixedPrecisionBopsMultipleOutEdgesTest from tests.keras_tests.feature_networks_tests.feature_networks.mixed_precision_tests import \ @@ -143,7 +143,8 @@ MixedPrecisionSearchPartWeightsLayersTest, MixedPrecisionDepthwiseTest, MixedPrecisionSearchLastLayerDistanceTest, \ MixedPrecisionSearchActivationNonConfNodesTest, MixedPrecisionSearchTotalMemoryNonConfNodesTest, \ MixedPrecisionCombinedNMSTest -from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import \ + MatmulToDenseSubstitutionTest from tests.keras_tests.feature_networks_tests.feature_networks.metadata_test import MetadataTest from tests.keras_tests.feature_networks_tests.feature_networks.tpc_test import TpcTest from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \ @@ -153,8 +154,10 @@ AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest from tests.keras_tests.feature_networks_tests.feature_networks.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest -from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import SigMulSubstitutionTest -from 
tests.keras_tests.feature_networks_tests.feature_networks.conv_func_substitutions_test import ConvFuncSubstitutionsTest +from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import \ + SigMulSubstitutionTest +from tests.keras_tests.feature_networks_tests.feature_networks.conv_func_substitutions_test import \ + ConvFuncSubstitutionsTest from model_compression_toolkit.qat.common.qat_config import TrainingMethod layers = tf.keras.layers @@ -170,7 +173,7 @@ def test_remove_identity(self): def test_per_tensor_weight_quantization(self): PerTensorWeightQuantizationTest(self).run_test() - + def test_single_relu_replacement(self): SingleReluReplacementTest(self).run_test() @@ -243,7 +246,7 @@ def test_mixed_precision_weights_only_activation_conf(self): def test_requires_mixed_recision(self): RequiresMixedPrecisionWeights(self, weights_memory=True).run_test() - RequiresMixedPrecision(self,activation_memory=True).run_test() + RequiresMixedPrecision(self, activation_memory=True).run_test() RequiresMixedPrecision(self, total_memory=True).run_test() RequiresMixedPrecision(self, bops=True).run_test() RequiresMixedPrecision(self).run_test() @@ -520,11 +523,25 @@ def test_activation_scaling_relu6(self): ReLUBoundToPOTNetTest(self).run_test() def test_activation_bias_correction(self): - BaseActivationBiasCorrectionTest(self).run_test() - BaseActivationBiasCorrectionConvTest(self).run_test() - BaseActivationBiasCorrectionDWConvTest(self).run_test() - BaseActivationBiasCorrectionReshapeConvTest(self).run_test() - BaseActivationBiasCorrectionConv8Test(self).run_test() + for use_bias in [False, True]: + for linear_layer in [layers.Dense(30, use_bias=use_bias), + layers.Conv2D(filters=3, kernel_size=1, use_bias=use_bias), + layers.DepthwiseConv2D(kernel_size=1, use_bias=use_bias, + bias_initializer='glorot_uniform', depth_multiplier=1), + layers.Conv2DTranspose(filters=3, kernel_size=1, use_bias=use_bias), + layers.Conv2D(filters=3, kernel_size=2, use_bias=use_bias)]: + for activation_bias_correction_threshold in [0.0, 1e-6, 1e9]: + for activation in ['gelu', 'swish']: + for bypass_layer in [[layers.ZeroPadding2D(2)], + [layers.Dropout(0.5), layers.Dropout(0.5)], + [layers.MaxPooling2D(pool_size=(2, 2), strides=2)], + [layers.Flatten(), layers.Reshape(target_shape=(8, 8, 3))]]: + BaseActivationBiasCorrectionTest(self, + prev_layer=layers.Activation(activation), + bypass_layer_list=bypass_layer, + linear_layer=linear_layer, + activation_bias_correction_threshold= + activation_bias_correction_threshold).run_test() def test_layer_activation_softmax_shift(self): SoftmaxShiftTest(self, layers.Dense(20, activation='softmax'), None).run_test() @@ -545,9 +562,12 @@ def test_conv2dbn_folding(self): Conv2DBNFoldingTest(self).run_test() def test_bn_forward_folding(self): - BNForwardFoldingTest(self, layers.Conv2D(2, 1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() - BNForwardFoldingTest(self, layers.DepthwiseConv2D(1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() - BNForwardFoldingTest(self, layers.Conv2DTranspose(2, 1, bias_initializer='glorot_uniform'), True, is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.Conv2D(2, 1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.DepthwiseConv2D(1, bias_initializer='glorot_uniform'), True, + is_dwconv=True).run_test() + BNForwardFoldingTest(self, layers.Conv2DTranspose(2, 1, bias_initializer='glorot_uniform'), True, + 
is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.Conv2D(2, 2), False, is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.DepthwiseConv2D((3, 1)), False, is_dwconv=True).run_test() BNForwardFoldingTest(self, layers.Conv2DTranspose(2, (1, 3)), False, is_dwconv=True).run_test() @@ -592,10 +612,14 @@ def test_const_quantization(self): for qmethod in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC, QuantizationMethod.UNIFORM]: for error_method in [QuantizationErrorMethod.MSE, QuantizationErrorMethod.NOCLIPPING]: ConstQuantizationTest(self, func, c, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod, error_method=error_method).run_test() - ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod, + error_method=error_method).run_test() + ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod, + error_method=error_method).run_test() AdvancedConstQuantizationTest(self).run_test() ConstQuantizationMultiInputTest(self).run_test() @@ -617,7 +641,8 @@ def test_const_representation(self): for func in [layers.Add(), layers.Multiply(), layers.Subtract()]: ConstRepresentationTest(self, func, c, is_list_input=True).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True, is_list_input=True).run_test() - ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, is_list_input=True).run_test() + ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwargs=True, + is_list_input=True).run_test() ConstRepresentationTest(self, func, c, use_kwargs=True, is_list_input=True).run_test() ConstRepresentationMultiInputTest(self).run_test() @@ -707,7 +732,6 @@ def test_gptq(self): # GradientPTQLearnRateZeroConvGroupTest(self).run_test() # GradientPTQWeightsUpdateConvGroupTest(self).run_test() - def test_gptq_conv_group_dilation(self): GradientPTQLearnRateZeroConvGroupDilationTest(self).run_test() GradientPTQWeightsUpdateConvGroupDilationTest(self).run_test() @@ -804,7 +828,8 @@ def test_qat(self): QuantizationAwareTrainingQuantizersTest(self).run_test() QuantizationAwareTrainingQuantizerHolderTest(self).run_test() QATWrappersMixedPrecisionCfgTest(self).run_test() - QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, expected_mp_cfg=[0, 4, 1, 1]).run_test() + QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8, + expected_mp_cfg=[0, 4, 1, 1]).run_test() def test_bn_attributes_quantization(self): BNAttributesQuantization(self, quantize_linear=False).run_test() @@ -893,8 +918,8 @@ def test_exceptions_manual_selection(self): # Invalid inputs to API with self.assertRaises(Exception) as context: ManualBitWidthSelectionTest(self, - [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')], - [2, 
4]).run_test() + [NodeNameFilter('relu1'), NodeNameFilter('add1'), NodeNameFilter('add2')], + [2, 4]).run_test() # Check that the correct exception message was raised self.assertEqual(str(context.exception), "Configuration Error: The number of provided bit_width values 2 must match the number of filters 3, or a single bit_width value should be provided for all filters.") @@ -909,15 +934,15 @@ def test_manual_bit_width_selection(self): ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 4).run_test() ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 2).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], - [2, 4]).run_test() + [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Dense)], - [4, 4]).run_test() + [4, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Conv2D), NodeTypeFilter(layers.Add)], - [2, 4]).run_test() + [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Conv2D)], - [4, 4]).run_test() + [4, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeTypeFilter(layers.Dense)], - 4).run_test() + 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('input'), 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('conv1'), 4).run_test() ManualBitWidthSelectionTest(self, NodeNameFilter('fc'), 4).run_test() @@ -926,7 +951,7 @@ def test_manual_bit_width_selection(self): ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 4).run_test() ManualBitWidthSelectionTest(self, [NodeNameFilter('add1'), NodeNameFilter('conv1')], [2, 4]).run_test() ManualBitWidthSelectionTest(self, [NodeNameFilter('add2'), NodeNameFilter('relu1')], 4).run_test() - ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')],[4, 2]).run_test() + ManualBitWidthSelectionTest(self, [NodeTypeFilter(layers.Add), NodeNameFilter('add2')], [4, 2]).run_test() if __name__ == '__main__': diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py index cdd2a6970..beba40d80 100644 --- a/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py @@ -18,6 +18,7 @@ from torch.nn import GELU, Hardswish, AdaptiveAvgPool2d, ZeroPad2d, Linear, Conv2d import model_compression_toolkit as mct +from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest """ @@ -31,10 +32,11 @@ class ActivationBiasCorrectionNet(torch.nn.Module): """ def __init__(self, + prev_layer, linear_layer, bypass_layers): super(ActivationBiasCorrectionNet, self).__init__() - self.activation_layer = GELU() + self.activation_layer = prev_layer self.linear_layer = linear_layer self.bypass_layers = torch.nn.ModuleList(bypass_layers) @@ -46,7 +48,6 @@ def forward(self, x): x = self.linear_layer(x) return x - class ActivationBiasCorrectionPadNet(torch.nn.Module): """ This is the network to test the Activation Bias Correction feature with pooling/padding layers as a bypass layers. 
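For orientation, a minimal sketch (not part of the patch) of how the parametrized PyTorch test above is expected to be driven by the runner. The channel counts, the choice of ZeroPad2d as the bypass layer, and the small threshold value are illustrative assumptions; `self` stands for the unit-test instance, and all layer classes are the ones already imported in this test file.

model = ActivationBiasCorrectionNet(prev_layer=GELU(),
                                    linear_layer=Conv2d(3, 8, kernel_size=1),
                                    bypass_layers=[ZeroPad2d(2)])
BaseActivationBiasCorrectionTest(self,
                                 model=model,
                                 activation_bias_correction_threshold=1e-6).run_test()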
@@ -86,74 +87,39 @@ def forward(self, x): class BaseActivationBiasCorrectionTest(BasePytorchFeatureNetworkTest): - def __init__(self, unit_test): + def __init__(self, unit_test, + model, + activation_bias_correction_threshold=0.0): super().__init__(unit_test) + self.model = model + self.activation_bias_correction_threshold = activation_bias_correction_threshold def get_quantization_config(self): return mct.core.QuantizationConfig(weights_bias_correction=False, weights_second_moment_correction=False, - activation_bias_correction=True) - - def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): - bias = float_model.linear_layer.bias.cpu().detach().numpy() - bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() - self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected a change in the bias value.") - - -class BaseActivationBiasCorrectionNetTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test, linear_layer, bypass_layers): - super().__init__(unit_test) - self.linear_layer = linear_layer - self.bypass_layers = bypass_layers - - def create_networks(self): - return ActivationBiasCorrectionNet(linear_layer=self.linear_layer, - bypass_layers=self.bypass_layers) - - -class BaseActivationBiasCorrectionPadNetTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A small value set to the activation bias correction threshold only to activate the threshold - # filtering without changing the bias correction values. - return mct.core.QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, - activation_bias_correction=True, - activation_bias_correction_threshold=1e-6) - - def create_networks(self): - return ActivationBiasCorrectionPadNet() - - -class BaseActivationBiasCorrectionBigThrTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def get_quantization_config(self): - # A large value is assigned to the activation bias correction threshold to enable threshold filtering, - # which adjusts the bias correction values to zero. 
- return mct.core.QuantizationConfig(weights_bias_correction=False, - weights_second_moment_correction=False, + shift_negative_activation_correction=False, activation_bias_correction=True, - activation_bias_correction_threshold=1e9) + activation_bias_correction_threshold= + self.activation_bias_correction_threshold) def create_networks(self): - return ActivationBiasCorrectionPadNet() + return self.model def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): bias = float_model.linear_layer.bias.cpu().detach().numpy() bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy() - self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), - msg=f"Error in activation bias correction: expected no change in the bias value in " - f"case of activation_bias_correction_threshold 1e9.") - -class BaseActivationBiasCorrectionReshapeNetTest(BaseActivationBiasCorrectionTest): - def __init__(self, unit_test): - super().__init__(unit_test) - - def create_networks(self): - return ActivationBiasCorrectionReshapeNet() + if getattr(float_model.linear_layer, KERNEL_SIZE, None) in [None, 1, (1, 1)]: + if self.activation_bias_correction_threshold > 1e8: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias " + f"value in case of activation_bias_correction_threshold " + f"{self.activation_bias_correction_threshold}.") + else: + self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected a change in the bias " + f"value.") + else: + self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction), + msg=f"Error in activation bias correction: expected no change in the bias value " + f"in case of conv with kernel different than 1 or (1, 1).") diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index d1fe8833a..b5c02f350 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -30,9 +30,9 @@ from model_compression_toolkit.trainable_infrastructure import TrainingMethod from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \ Activation16BitMixedPrecisionTest -from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import \ - BaseActivationBiasCorrectionPadNetTest, BaseActivationBiasCorrectionNetTest, \ - BaseActivationBiasCorrectionReshapeNetTest, BaseActivationBiasCorrectionBigThrTest +from tests.pytorch_tests.model_tests.feature_models.activation_bias_correction_test import ( + BaseActivationBiasCorrectionTest, ActivationBiasCorrectionNet, ActivationBiasCorrectionPadNet, + ActivationBiasCorrectionReshapeNet) from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest from tests.pytorch_tests.model_tests.feature_models.add_same_test import AddSameNetTest from tests.pytorch_tests.model_tests.feature_models.bn_attributes_quantization_test import BNAttributesQuantization @@ -445,16 +445,30 @@ def test_activation_bias_correction_net(self): """ This test checks the activation bias correction feature. 
""" - BaseActivationBiasCorrectionPadNetTest(self).run_test() - BaseActivationBiasCorrectionReshapeNetTest(self).run_test() - BaseActivationBiasCorrectionBigThrTest(self).run_test() - for linear_layer in [nn.Linear(8, 20), - nn.Conv2d(3, 20, 1), - nn.Conv2d(3, 8, (1, 1))]: - for bypass_layers in [[nn.Dropout(0.5)], [nn.Dropout(0.5), nn.Dropout(0.5)]]: - BaseActivationBiasCorrectionNetTest(self, - linear_layer=linear_layer, - bypass_layers=bypass_layers).run_test() + model_list = [ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Linear(8, 20), + bypass_layers=[nn.Dropout(0.5), nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 20, 1), + bypass_layers=[nn.Dropout(0.5), nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 8, (1, 1)), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.Hardswish(), + linear_layer=nn.ConvTranspose2d(3, 8, 1), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionNet(prev_layer=nn.GELU(), + linear_layer=nn.Conv2d(3, 8, 3), + bypass_layers=[nn.Dropout(0.5)]), + ActivationBiasCorrectionPadNet(), + ActivationBiasCorrectionReshapeNet()] + + for activation_bias_correction_threshold in [0.0, 1e-6, 1e9]: + for model in model_list: + BaseActivationBiasCorrectionTest(self, + model=model, + activation_bias_correction_threshold= + activation_bias_correction_threshold).run_test() def test_split_concat_net(self): """