Fix tests and node matchers
lapid92 committed Nov 6, 2024
1 parent 1d40f9b commit 2967f17
Showing 13 changed files with 257 additions and 419 deletions.
58 changes: 29 additions & 29 deletions model_compression_toolkit/core/common/framework_implementation.py
@@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self,
Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover

@abstractmethod
@@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray:
Returns:
Numpy array converted from the input tensor.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s to_numpy method.') # pragma: no cover

@abstractmethod
@@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any:
Returns:
Framework's tensor converted from the input Numpy array.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s to_tensor method.') # pragma: no cover

@abstractmethod
@@ -106,7 +106,7 @@ def model_reader(self,
Returns:
Graph representing the input model.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_reader method.') # pragma: no cover

@abstractmethod
@@ -131,7 +131,7 @@ def model_builder(self,
Returns:
A tuple with the model and additional relevant supporting objects.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_builder method.') # pragma: no cover

@abstractmethod
@@ -148,7 +148,7 @@ def run_model_inference(self,
Returns:
The frameworks model's output.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s run_model_inference method.') # pragma: no cover

@abstractmethod
@@ -167,26 +167,26 @@ def shift_negative_correction(self,
Returns:
Graph after SNC.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover

@abstractmethod
def compute_activation_bias_correction(self,
graph: Graph,
core_config: CoreConfig,
quant_config: QuantizationConfig,
fw_info: FrameworkInfo) -> Graph:
"""
Compute activation bias correction on a graph.
Args:
graph: Graph to apply activation bias correction on.
core_config: QuantizationConfig of how the model should be quantized.
quant_config: QuantizationConfig of how the model should be quantized.
fw_info: FrameworkInfo object with information about the specific framework's model.
Returns:
Graph after activation bias correction computing.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s compute_activation_bias_correction method.') # pragma: no cover

@abstractmethod
@@ -203,7 +203,7 @@ def get_substitutions_channel_equalization(self,
Returns:
A list of the framework substitutions used after we collect statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover

@abstractmethod
@@ -213,7 +213,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
Returns: A list of the framework substitutions used to prepare the graph.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover

@abstractmethod
@@ -227,23 +227,23 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization
Returns: A list of the framework substitutions used before we collect statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover

@abstractmethod
def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: linear collapsing substitution
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: conv2d add const collapsing substitution
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover

@abstractmethod
@@ -258,15 +258,15 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf
Returns:
A list of the framework substitutions used for statistics correction.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover

@abstractmethod
def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
"""
Returns: A list of the framework substitutions used for residual collapsing
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover


@@ -282,7 +282,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio
Returns:
A list of the framework substitutions used after we collect statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover

@abstractmethod
@@ -291,7 +291,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B
Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
f'method.') # pragma: no cover

@@ -307,7 +307,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
Returns:
A list of the framework substitutions used after we apply second moment statistics.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_after_second_moment_correction '
f'method.') # pragma: no cover

@@ -335,7 +335,7 @@ def get_sensitivity_evaluator(self,
A function that computes the metric.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover

def get_node_prior_info(self, node: BaseNode,
@@ -353,7 +353,7 @@ def get_node_prior_info(self, node: BaseNode,
NodePriorInfo with information about the node.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_prior_info method.') # pragma: no cover

def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
@@ -364,7 +364,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
Returns: True if the node should be considered an interest point, False otherwise.
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover

def get_mp_node_distance_fn(self, n: BaseNode,
@@ -383,7 +383,7 @@ def get_mp_node_distance_fn(self, n: BaseNode,
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover


@@ -400,7 +400,7 @@ def is_output_node_compatible_for_hessian_score_computation(self,
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover

@abstractmethod
@@ -417,7 +417,7 @@ def get_node_mac_operations(self,
Returns: The MAC count of the operation
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_mac_operations method.') # pragma: no cover

@abstractmethod
@@ -438,7 +438,7 @@ def apply_second_moment_correction(self,
Returns:
A Graph after second moment correction.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s apply_second_moment_correction method.') # pragma: no cover

@abstractmethod
@@ -455,7 +455,7 @@ def sensitivity_eval_inference(self,
Returns:
The output of the model inference on the given input.
"""
- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover

def get_inferable_quantizers(self, node: BaseNode):
@@ -471,7 +471,7 @@ def get_inferable_quantizers(self, node: BaseNode):
"""

- raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_inferable_quantizers method.') # pragma: no cover

@staticmethod
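Every hunk in this file applies the same one-word fix: the NotImplementedError message raised by each unimplemented FrameworkImplementation method now reads "has to implement" (the subject is a single class name) instead of "have to implement". A minimal sketch of the pattern, assuming a hypothetical NumpyImplementation backend for illustration; the real class defines many more abstract methods than shown here:

```python
from abc import ABC, abstractmethod
from typing import Any

import numpy as np


class FrameworkImplementation(ABC):
    """Contract that each framework backend implements (abridged sketch)."""

    @abstractmethod
    def to_numpy(self, tensor: Any) -> np.ndarray:
        """Convert a framework tensor to a numpy array."""
        # The concrete class name renders into the message, e.g.
        # "NumpyImplementation has to implement the framework's to_numpy method."
        raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
                                  f'framework\'s to_numpy method.')  # pragma: no cover


class NumpyImplementation(FrameworkImplementation):
    """Hypothetical backend whose tensors are already numpy arrays."""

    def to_numpy(self, tensor: Any) -> np.ndarray:
        return np.asarray(tensor)


print(NumpyImplementation().to_numpy([1, 2]))  # [1 2]
```

The next file in the diff extends the node's activation quantization config with the new activation_bias_correction_term attribute.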
@@ -95,7 +95,9 @@ def __init__(self,
self.activation_error_method = qc.activation_error_method
self.activation_n_bits = op_cfg.activation_n_bits
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
+ self.activation_bias_correction_term = None
self.enable_activation_quantization = op_cfg.enable_activation_quantization
self.quantization_preserving = op_cfg.quantization_preserving
self.signedness = op_cfg.signedness
self.activation_channel_equalization = qc.activation_channel_equalization
self.input_scaling = qc.input_scaling
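Initializing activation_bias_correction_term to None for every node config, rather than attaching the attribute only when a correction is computed, is what lets the matcher in the next file test the value itself instead of probing with hasattr. A condensed sketch of that lifecycle, using a simplified stand-in for the config class in this hunk:

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class ActivationQuantizationCfg:
    """Simplified stand-in for the node activation quantization config."""
    activation_n_bits: int = 8
    # Always present from construction; stays None until a correction
    # term is computed during model preparation.
    activation_bias_correction_term: Optional[np.ndarray] = None


cfg = ActivationQuantizationCfg()
assert cfg.activation_bias_correction_term is None  # nothing computed yet
cfg.activation_bias_correction_term = np.array([0.02, -0.01])  # set later
```

The last file shown applies that term to each node's bias.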
@@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================

- import copy
-
from model_compression_toolkit.core import CoreConfig, QuantizationConfig
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -26,7 +24,7 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
core_config: CoreConfig,
fw_impl: FrameworkImplementation) -> Graph:
"""
- Get a graph, where each node has a final activation quantization configuration (with a activation bias
+ Get a graph, where each node has a final activation quantization configuration (with an activation bias
correction term in it), and apply the activation bias correction for each node in the graph.
Args:
@@ -38,12 +36,11 @@ def apply_activation_bias_correction_to_graph(graph: Graph,
Graph with activation bias correction apply to it's nodes.
"""

- graph = copy.deepcopy(graph)
for n in graph.nodes:
# Activation bias correction is only relevant for nodes with kernel op
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
- hasattr(n.final_activation_quantization_cfg, 'activation_bias_correction_term'):
+ n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
# If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
# calculated during model preparation, and is used now in the node's bias term.
_apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
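Because the attribute now exists on every config, a hasattr probe would match every node; the rewritten condition instead gates on whether a term was actually computed. A condensed sketch of the three-part gate, with simplified stand-in objects:

```python
from types import SimpleNamespace


def needs_activation_bias_correction(node, quant_config, kernel_attr) -> bool:
    """Sketch of the gate in the loop above (not the exact MCT helper)."""
    cfg = node.final_activation_quantization_cfg
    return (quant_config.activation_bias_correction             # feature enabled
            and kernel_attr is not None                         # node has a kernel op
            and cfg.activation_bias_correction_term is not None)  # term was computed


# Toy stand-ins: a node whose term was never computed is skipped.
qc = SimpleNamespace(activation_bias_correction=True)
node = SimpleNamespace(final_activation_quantization_cfg=SimpleNamespace(
    activation_bias_correction_term=None))
print(needs_activation_bias_correction(node, qc, kernel_attr='kernel'))  # False
```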
@@ -66,15 +63,19 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
correction = node.final_activation_quantization_cfg.activation_bias_correction_term
bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights

- if bias is not None: # If the layer has bias, we subtract the correction from original bias
- node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
- else:
- # If the layer has no bias, we consider it as if it has and its value is 0 and add a "dummy" attribute
- # configuration with disabled quantization.
+ if bias is None:
+ # If the layer has no bias, we set the bias as -correction.
node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)
- node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node.

+ # Mark the use_bias attribute of the node.
+ node.framework_attr[fw_impl.constants.USE_BIAS] = True
+
+ # Configure the quantization of the bias as disabled.
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
WeightsAttrQuantizationConfig(
qc,
AttributeQuantizationConfig(
enable_weights_quantization=False)))
+ else:
+ # If the layer has bias, we subtract the correction from original bias
+ node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
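The rewritten helper puts the bias-free branch first: a layer without a bias is treated as if its bias were zero, so the corrected bias is simply -correction, and the newly created bias is registered with use_bias marked and its quantization disabled. The arithmetic itself, as a toy numpy sketch:

```python
from typing import Optional

import numpy as np


def corrected_bias(bias: Optional[np.ndarray], correction: np.ndarray) -> np.ndarray:
    """A layer with no bias behaves as bias == 0, so the result is -correction."""
    return -correction if bias is None else bias - correction


print(corrected_bias(np.array([0.5, -0.2]), np.array([0.1, 0.1])))  # [ 0.4 -0.3]
print(corrected_bias(None, np.array([0.1, 0.1])))                   # [-0.1 -0.1]
```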