Add feature Activation Bias Correction
lapid92 committed Nov 4, 2024
1 parent fb3f75d commit 5905dfa
Showing 14 changed files with 886 additions and 8 deletions.
@@ -170,6 +170,25 @@ def shift_negative_correction(self,
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover

@abstractmethod
def compute_activation_bias_correction(self,
graph: Graph,
core_config: CoreConfig,
fw_info: FrameworkInfo) -> Graph:
"""
Compute activation bias correction on a graph.
Args:
graph: Graph to apply activation bias correction on.
core_config: CoreConfig containing parameters of how the model should be quantized.
fw_info: FrameworkInfo object with information about the specific framework's model.
Returns:
Graph after the activation bias correction has been computed.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s compute_activation_bias_correction method.') # pragma: no cover

@abstractmethod
def get_substitutions_channel_equalization(self,
quant_config: QuantizationConfig,
@@ -454,7 +473,7 @@ def get_inferable_quantizers(self, node: BaseNode):

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_inferable_quantizers method.') # pragma: no cover

@staticmethod
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
"""
@@ -70,6 +70,8 @@ class QuantizationConfig:
weights_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE
relu_bound_to_power_of_2: bool = False
weights_bias_correction: bool = True
activation_bias_correction: bool = False
activation_bias_correction_threshold: float = 0.0
weights_second_moment_correction: bool = False
input_scaling: bool = False
softmax_shift: bool = False
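The two new flags above are ordinary dataclass fields, so enabling the feature only requires passing them through the quantization configuration. A minimal usage sketch (not part of this diff; the threshold value is illustrative, and wiring the config through CoreConfig's quantization_config argument is assumed from the existing MCT API):

from model_compression_toolkit.core import CoreConfig, QuantizationConfig

# Enable activation bias correction and skip layers whose normalized bias
# (in percent) falls below the threshold.
quant_config = QuantizationConfig(activation_bias_correction=True,
                                  activation_bias_correction_threshold=2.0)
core_config = CoreConfig(quantization_config=quant_config)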
@@ -0,0 +1,80 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import copy

from model_compression_toolkit.core import CoreConfig, QuantizationConfig
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig


def apply_activation_bias_correction_to_graph(graph: Graph,
core_config: CoreConfig,
fw_impl: FrameworkImplementation) -> Graph:
"""
Get a graph, where each node has a final activation quantization configuration (with an activation bias
correction term in it), and apply the activation bias correction for each node in the graph.
Args:
graph: Graph to apply activation bias correction to.
core_config: CoreConfig containing parameters of how the model should be quantized.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
Returns:
Graph with activation bias correction applied to its nodes.
"""

graph = copy.deepcopy(graph)
for n in graph.nodes:
# Activation bias correction is only relevant for nodes with a kernel op.
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
hasattr(n.final_activation_quantization_cfg, 'activation_bias_correction_term'):
# If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
# calculated during model preparation and is now folded into the node's bias term.
_apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
return graph


def _apply_activation_bias_correction_to_node(node: BaseNode,
fw_impl: FrameworkImplementation,
qc: QuantizationConfig):
"""
Set a new bias for the node using the activation bias correction term that is stored in the
final activation quantization configuration.
Args:
node: Node to set its corrected bias after activation bias correction.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
qc: QuantizationConfig containing parameters of how the model should be quantized.
"""
correction = node.final_activation_quantization_cfg.activation_bias_correction_term
bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights

if bias is not None: # If the layer has a bias, subtract the correction from the original bias
node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
else:
# If the layer has no bias, treat it as if it had a bias of zero: set the bias to minus the correction and
# add a "dummy" attribute configuration with quantization disabled.
node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node.
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
WeightsAttrQuantizationConfig(
qc,
AttributeQuantizationConfig(
enable_weights_quantization=False)))
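A small numerical sketch of the effect of the bias update above for a linear layer whose input mean is shifted by activation quantization (the quantizer, shapes, and values are illustrative assumptions, not taken from this diff):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=0.1, size=(10000, 4))     # float activations
x_q = np.floor(x * 8) / 8                                # toy quantizer with a systematic negative bias
W = rng.normal(size=(4, 2))                              # (input_channels, output_channels)
b = np.array([0.3, -0.1])

mean_diff = x_q.mean() - x.mean()                        # per-tensor mean shift caused by quantization
correction = mean_diff * W.sum(axis=0)                   # per-output-channel correction term

y_float = x.mean(axis=0) @ W + b                         # expected float output
y_uncorrected = x_q.mean(axis=0) @ W + b                 # quantized input, original bias
y_corrected = x_q.mean(axis=0) @ W + (b - correction)    # quantized input, corrected bias

print(np.abs(y_float - y_uncorrected).max())             # shift introduced by quantization
print(np.abs(y_float - y_corrected).max())               # much smaller residual after correction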
@@ -0,0 +1,209 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import List, Tuple, Any, Callable

import numpy as np

from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher


def get_next_nodes_to_correct(node: BaseNode,
graph: Graph,
linear_node_types: NodeOperationMatcher,
bypass_node_types: NodeOperationMatcher,
bypass_nodes: List = None) -> Tuple[Any, Any]:
"""
Search recursively through the previous nodes of a given node for the first node that is not a bypass node,
collecting any bypass nodes encountered along the way.
Args:
node: Node to search for its previous node.
graph: Graph the node is in.
linear_node_types: Types of linear nodes to consider.
bypass_node_types: Types of nodes for bypassing to consider.
bypass_nodes: a list of bypass nodes found while running this function
Returns: The previous non-bypass node (if found) and the list of bypass nodes encountered (if any), or (None, None)
if no such node was found or a node with multiple incoming edges was reached during the search (in which case the
substitution cannot be applied).
"""

prev_nodes = graph.get_prev_nodes(node)

if len(prev_nodes) != 1:
return None, None

prev_node = prev_nodes[0]

# If the previous node is not a bypass type, return it as the valid node along with any bypass nodes
if not bypass_node_types.apply(prev_node):
return prev_node, bypass_nodes

# If the previous node is a bypass node type, add it to the bypass_nodes list and continue searching
if bypass_node_types.apply(prev_node):
if bypass_nodes:
bypass_nodes.append(prev_node)
else:
bypass_nodes = [prev_node]
return get_next_nodes_to_correct(node=prev_node,
graph=graph,
linear_node_types=linear_node_types,
bypass_node_types=bypass_node_types,
bypass_nodes=bypass_nodes)
return None, None


def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
"""
Calculate the centers of bins given their edges.
Args:
bin_edges: Array of bin edges.
Returns:
np.ndarray: Array of bin centers.
"""
# Ensure bin_edges is a numpy array
bin_edges = np.array(bin_edges, dtype=np.float32)

# Calculate the centers by averaging consecutive bin edges
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
return bin_centers
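A one-line sanity check of calculate_bin_centers (the values are illustrative):

import numpy as np
print(calculate_bin_centers(np.array([0., 1., 2., 3.])))   # [0.5 1.5 2.5]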


def compute_activation_bias_correction(graph: Graph,
core_config: CoreConfig,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
linear_node: BaseNode,
prev_node: BaseNode,
kernel_size: str) -> Graph:
"""
Compute the activation bias correction term, and store it in the final activation
quantization configuration.
Args:
graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
core_config: Configuration object containing parameters of how the model should be quantized.
fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
linear_node: Node to compute the activation bias correction for.
prev_node: Node whose activation quantization introduces the error to be corrected.
kernel_size: The framework specific attribute name of the convolution layer's kernel size.
Returns:
Graph with activation bias correction term for each node.
"""

# Check if 'kernel_size' is a key in the framework-specific attributes of the linear_node; if it is, then the
# linear_node is a convolution.
if kernel_size in linear_node.framework_attr.keys():
# Retrieve the value of 'kernel_size' and check whether it is 1 or (1, 1). This feature only supports kernel
# sizes of 1 or (1, 1).
if linear_node.framework_attr.get(kernel_size) not in [1, (1, 1)]:
# If the kernel size is not 1 or (1, 1), return the current graph unmodified
return graph

prev_node_act_quant_cfg = prev_node.final_activation_quantization_cfg

# Check if the previous node has an activation quantization configuration and a histogram collector.
if prev_node_act_quant_cfg is None or not hasattr(graph.get_out_stats_collector(prev_node), 'hc'):
return graph

float_bins, float_count = graph.get_out_stats_collector(prev_node).hc.get_histogram()

# Calculate the centers of the float bins
float_centers = calculate_bin_centers(float_bins)

# Quantize the bin edges and calculate the centers of the quantized bins
quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
quant_bins = fw_impl.to_numpy(quant_bins)
quant_centers = calculate_bin_centers(quant_bins)

# Calculate the means of both the float and the quantized bin centers, weighted by the bin counts
mean_float_centers = np.sum(float_centers * float_count) / np.sum(float_count)
mean_quant_centers = np.sum(quant_centers * float_count) / np.sum(float_count)

# Compute the difference between the mean quantized center and the mean float center
mean_diff = mean_quant_centers - mean_float_centers

# Check if activation bias correction is enabled based on the configured threshold
if core_config.quantization_config.activation_bias_correction_threshold > 0:

# Calculate the normalized bias as a percentage of the float center norm
float_centers_norm1 = np.abs(mean_float_centers)
normalized_bias = 100 * np.abs(mean_diff) / float_centers_norm1

# If the normalized bias is below the activation bias correction threshold, return the unmodified graph
if normalized_bias < core_config.quantization_config.activation_bias_correction_threshold:
return graph

# The correction term is a function of the layer type.
kernel = linear_node.get_weights_by_keys(fw_info.kernel_ops_attributes_mapping.get(linear_node.type)[0])

if kernel is not None:
output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
axis_not_output_channel = list(range(len(kernel.shape)))
axis_not_output_channel.remove(output_channel_index)

if output_channel_index == input_channel_index:
axis_not_output_channel.remove(3) # 3 is the depth multiplier index

activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
linear_node.final_activation_quantization_cfg.activation_bias_correction_term = activation_bias_correction_term.flatten()
return graph
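A hedged, self-contained sketch of the computation above, deriving mean_diff from a histogram with a toy quantizer and folding it into a per-output-channel correction term (the quantizer, histogram, and kernel are illustrative assumptions, not part of this diff):

import numpy as np

bin_edges = np.linspace(0.0, 1.0, 9, dtype=np.float32)          # float histogram edges
counts = np.array([1, 2, 4, 8, 6, 3, 1, 1], dtype=np.float32)   # per-bin counts

float_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
quant_edges = np.round(bin_edges * 2) / 2                        # toy quantizer with levels {0, 0.5, 1}
quant_centers = (quant_edges[:-1] + quant_edges[1:]) / 2.0

mean_float = np.sum(float_centers * counts) / np.sum(counts)
mean_quant = np.sum(quant_centers * counts) / np.sum(counts)
mean_diff = mean_quant - mean_float                              # mean shift introduced by quantization

kernel = np.full((3, 2), 0.5)                                    # (input_channels, output_channels)
correction = mean_diff * kernel.sum(axis=0)                      # later subtracted from the layer's bias
print(mean_diff, correction)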


def compute_activation_bias_correction_of_graph(graph: Graph,
core_config: CoreConfig,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
activation_bias_correction_node_matchers: Callable,
kernel_size: str) -> Graph:
"""
Compute the activation bias correction term for the graph.
Args:
graph: Graph with nodes to compute the activation bias correction.
core_config: Configuration object containing parameters of how the model should be quantized.
fw_info: Framework info, such as lists of nodes whose kernels should be quantized.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
kernel_size: The framework specific attribute name of the convolution layer's kernel size.
Returns:
Graph with activation bias correction term for each relevant node.
"""
linear_node_types, bypass_node_types = activation_bias_correction_node_matchers()

for n in graph.nodes:
if linear_node_types.apply(n):
prev_node, _ = get_next_nodes_to_correct(node=n,
graph=graph,
linear_node_types=linear_node_types,
bypass_node_types=bypass_node_types)
graph = compute_activation_bias_correction(graph=graph,
core_config=core_config,
fw_info=fw_info,
fw_impl=fw_impl,
linear_node=n,
prev_node=prev_node,
kernel_size=kernel_size)
return graph
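The matcher factory is supplied by each framework implementation. As a hedged illustration only (the layer choices and the use of the | operator to compose matchers are assumptions about the framework-specific code, not part of this diff), a Keras-style factory could look roughly like:

from tensorflow.keras.layers import Conv2D, Dense, DepthwiseConv2D, Dropout, Flatten, Reshape

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher


def activation_bias_correction_node_matchers():
    # Layers whose bias can absorb the correction term.
    linear_node_types = NodeOperationMatcher(Dense) | \
                        NodeOperationMatcher(Conv2D) | \
                        NodeOperationMatcher(DepthwiseConv2D)
    # Shape- and identity-like layers the search is allowed to skip over.
    bypass_node_types = NodeOperationMatcher(Dropout) | \
                        NodeOperationMatcher(Flatten) | \
                        NodeOperationMatcher(Reshape)
    return linear_node_types, bypass_node_types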
@@ -18,6 +18,8 @@
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.statistics_correction.apply_activation_bias_correction_to_graph import \
apply_activation_bias_correction_to_graph
from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction_to_graph import \
apply_bias_correction_to_graph
from model_compression_toolkit.core.common.statistics_correction.apply_second_moment_correction_to_graph import \
@@ -73,7 +75,7 @@ def apply_statistics_correction(transformed_graph,
fw_impl: FrameworkImplementation,
tb_w: TensorboardWriter = None, ) -> Graph:
"""
- Apply statistics moment correction on graph.
+ Apply statistics correction on graph.
Args:
transformed_graph: Graph to apply statistics correction.
representative_data_gen (Callable): Dataset used for calibration.
@@ -84,7 +86,7 @@ def apply_statistics_correction(transformed_graph,
tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
Returns:
- Graph after statistics correction correction.
+ Graph after statistics correction.
"""

#############################################
@@ -104,4 +106,14 @@ def apply_statistics_correction(transformed_graph,
if tb_w is not None:
tb_w.add_graph(transformed_graph, 'after_statistics_correction')

#############################################
# Apply Activation Bias Correction
#############################################
if core_config.quantization_config.activation_bias_correction:
transformed_graph = apply_activation_bias_correction_to_graph(graph=transformed_graph,
core_config=core_config,
fw_impl=fw_impl)
if tb_w is not None:
tb_w.add_graph(transformed_graph, 'after_activation_bias_correction')

return transformed_graph
