Add feature Activation Bias Correction #1256

Merged · 5 commits · Nov 6, 2024
73 changes: 46 additions & 27 deletions model_compression_toolkit/core/common/framework_implementation.py
@@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self,

Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover

@abstractmethod
@@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray:
Returns:
Numpy array converted from the input tensor.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s to_numpy method.') # pragma: no cover

@abstractmethod
@@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any:
Returns:
Framework's tensor converted from the input Numpy array.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s to_tensor method.') # pragma: no cover

@abstractmethod
@@ -106,7 +106,7 @@ def model_reader(self,
Returns:
Graph representing the input model.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_reader method.') # pragma: no cover

@abstractmethod
@@ -131,7 +131,7 @@ def model_builder(self,
Returns:
A tuple with the model and additional relevant supporting objects.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_builder method.') # pragma: no cover

@abstractmethod
@@ -148,7 +148,7 @@ def run_model_inference(self,
Returns:
The frameworks model's output.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s run_model_inference method.') # pragma: no cover

@abstractmethod
@@ -167,9 +167,28 @@ def shift_negative_correction(self,
Returns:
Graph after SNC.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover

+ @abstractmethod
+ def compute_activation_bias_correction(self,
+ graph: Graph,
+ quant_config: QuantizationConfig,
+ fw_info: FrameworkInfo) -> Graph:
+ """
+ Compute activation bias correction on a graph.
+
+ Args:
+ graph: Graph to apply activation bias correction on.
+ quant_config: QuantizationConfig of how the model should be quantized.
+ fw_info: FrameworkInfo object with information about the specific framework's model.
+
+ Returns:
+ Graph after activation bias correction computation.
+ """
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
+ f'framework\'s compute_activation_bias_correction method.') # pragma: no cover
+
@abstractmethod
def get_substitutions_channel_equalization(self,
quant_config: QuantizationConfig,
@@ -184,7 +203,7 @@ def get_substitutions_channel_equalization(self,
Returns:
A list of the framework substitutions used after we collect statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover

@abstractmethod
@@ -194,7 +213,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
Returns: A list of the framework substitutions used to prepare the graph.

"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover

@abstractmethod
@@ -208,23 +227,23 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization
Returns: A list of the framework substitutions used before we collect statistics.

"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover

@abstractmethod
def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: linear collapsing substitution
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: conv2d add const collapsing substitution
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover

@abstractmethod
@@ -239,15 +258,15 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf
Returns:
A list of the framework substitutions used for statistics correction.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover

@abstractmethod
def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
"""
Returns: A list of the framework substitutions used for residual collapsing
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover


@@ -263,7 +282,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio
Returns:
A list of the framework substitutions used after we collect statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover

@abstractmethod
Expand All @@ -272,7 +291,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B
Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
f'method.') # pragma: no cover

@@ -288,7 +307,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
Returns:
A list of the framework substitutions used after we apply second moment statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_after_second_moment_correction '
f'method.') # pragma: no cover

@@ -316,7 +335,7 @@ def get_sensitivity_evaluator(self,
A function that computes the metric.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover

def get_node_prior_info(self, node: BaseNode,
@@ -334,7 +353,7 @@ def get_node_prior_info(self, node: BaseNode,
NodePriorInfo with information about the node.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_prior_info method.') # pragma: no cover

def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
@@ -345,7 +364,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
Returns: True if the node should be considered an interest point, False otherwise.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover

def get_mp_node_distance_fn(self, n: BaseNode,
@@ -364,7 +383,7 @@ def get_mp_node_distance_fn(self, n: BaseNode,
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover


@@ -381,7 +400,7 @@ def is_output_node_compatible_for_hessian_score_computation(self,

"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover

@abstractmethod
@@ -398,7 +417,7 @@ def get_node_mac_operations(self,
Returns: The MAC count of the operation
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_mac_operations method.') # pragma: no cover

@abstractmethod
@@ -419,7 +438,7 @@ def apply_second_moment_correction(self,
Returns:
A Graph after second moment correction.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s apply_second_moment_correction method.') # pragma: no cover

@abstractmethod
@@ -436,7 +455,7 @@ def sensitivity_eval_inference(self,
Returns:
The output of the model inference on the given input.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover

def get_inferable_quantizers(self, node: BaseNode):
@@ -452,9 +471,9 @@ def get_inferable_quantizers(self, node: BaseNode):

"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_inferable_quantizers method.') # pragma: no cover

@staticmethod
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
"""
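Aside from the new compute_activation_bias_correction hook, the changes in this file only fix the grammar of the NotImplementedError messages ("have to" → "has to"). Below is a minimal sketch of how a concrete backend might satisfy the new hook; the class name and body are hypothetical, not part of this PR:

```python
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation


class MyBackendImplementation(FrameworkImplementation):  # hypothetical backend
    def compute_activation_bias_correction(self, graph: Graph,
                                           quant_config: QuantizationConfig,
                                           fw_info) -> Graph:
        # A real backend would estimate a per-node correction term from collected
        # statistics and store it on each node's final activation quantization
        # config; the new file at the end of this diff shows how it is consumed.
        for node in graph.nodes:
            pass  # framework-specific estimation of the correction term
        return graph
```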
@@ -95,7 +95,9 @@ def __init__(self,
self.activation_error_method = qc.activation_error_method
self.activation_n_bits = op_cfg.activation_n_bits
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
+ self.activation_bias_correction_term = None
self.enable_activation_quantization = op_cfg.enable_activation_quantization
self.quantization_preserving = op_cfg.quantization_preserving
self.signedness = op_cfg.signedness
self.activation_channel_equalization = qc.activation_channel_equalization
self.input_scaling = qc.input_scaling
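The new activation_bias_correction_term field starts out as None. A toy sketch of its intended lifecycle, inferred from the rest of this PR (the class below is a simplified stand-in, not the real node quantization config):

```python
# Stand-in for the node's activation quantization config, mirroring this hunk.
class _ActivationQCfg:
    def __init__(self):
        self.activation_bias_correction_term = None  # no correction computed yet

cfg = _ActivationQCfg()

# During model preparation, the framework's compute_activation_bias_correction
# pass (the new abstract method above) is expected to populate the term:
cfg.activation_bias_correction_term = 0.02  # hypothetical scalar term

# The new graph pass at the end of this diff then reads the term and folds it
# into the layer's bias.
assert cfg.activation_bias_correction_term is not None
```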
@@ -84,6 +84,8 @@ class QuantizationConfig:
shift_negative_threshold_recalculation: bool = False
shift_negative_params_search: bool = False
concat_threshold_update: bool = False
+ activation_bias_correction: bool = False
+ activation_bias_correction_threshold: float = 0.0


# Default quantization configuration the library uses.
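A user opts in through the two new QuantizationConfig fields; a minimal usage sketch (the threshold value is illustrative, not a recommended default):

```python
import model_compression_toolkit as mct

qc = mct.core.QuantizationConfig(activation_bias_correction=True,
                                 activation_bias_correction_threshold=1e-6)
core_config = mct.core.CoreConfig(quantization_config=qc)
# core_config is then passed to a quantization entry point, e.g.:
# mct.ptq.keras_post_training_quantization(model, rep_data_gen, core_config=core_config)
```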
@@ -0,0 +1,81 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.core import CoreConfig, QuantizationConfig
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig


def apply_activation_bias_correction_to_graph(graph: Graph,
core_config: CoreConfig,
fw_impl: FrameworkImplementation) -> Graph:
"""
Take a graph in which each node holds a final activation quantization configuration (including an
activation bias correction term), and apply the activation bias correction to each node in the graph.

Args:
graph: Graph to apply activation bias correction to.
core_config: CoreConfig containing parameters of how the model should be quantized.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.

Returns:
Graph with activation bias correction applied to its nodes.
"""

for n in graph.nodes:
# Activation bias correction is only relevant for nodes with a kernel op.
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
# If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
# calculated during model preparation and is now folded into the node's bias term.
_apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
return graph


def _apply_activation_bias_correction_to_node(node: BaseNode,
fw_impl: FrameworkImplementation,
qc: QuantizationConfig):
"""
Set new bias to node using the activation bias correction term that is stored in the
final activation quantization configuration.

Args:
node: Node to update with the corrected bias after activation bias correction.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
qc: QuantizationConfig containing parameters of how the model should be quantized.

"""
correction = node.final_activation_quantization_cfg.activation_bias_correction_term
bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights

if bias is None:
# If the layer has no bias, we set the bias to -correction.
node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)

# Mark the use_bias attribute of the node.
node.framework_attr[fw_impl.constants.USE_BIAS] = True

# Configure the quantization of the bias as disabled.
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
WeightsAttrQuantizationConfig(
qc,
AttributeQuantizationConfig(
enable_weights_quantization=False)))
else:
# If the layer has a bias, we subtract the correction from the original bias.
node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
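A toy illustration of the arithmetic performed by _apply_activation_bias_correction_to_node (the numbers are made up; the real term is precomputed during model preparation):

```python
import numpy as np

correction = np.array([0.02, -0.01, 0.03])  # hypothetical precomputed terms
bias = np.array([0.10, 0.05, 0.00])         # original layer bias

# Layer with a bias (the `else` branch): subtract the correction.
print(bias - correction)  # [ 0.08  0.06 -0.03]

# Layer without a bias (the `if` branch): the new bias is minus the correction,
# use_bias is switched on, and bias quantization is disabled.
print(-correction)        # [-0.02  0.01 -0.03]
```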