From 262db8afc5245646dedfe5e77e223ee028adc2db Mon Sep 17 00:00:00 2001
From: Elad Cohen <78862769+elad-c@users.noreply.github.com>
Date: Tue, 10 Dec 2024 17:13:01 +0200
Subject: [PATCH] Just some cosmetics (#1291)

---
 .../core/common/fusion/graph_fuser.py         |  2 ++
 .../memory_graph/compute_graph_max_cut.py     |  1 +
 .../mixed_precision/bit_width_setter.py       | 24 +++++++------------
 .../mixed_precision_candidates_filter.py      |  2 +-
 .../resource_utilization_data.py              |  1 -
 .../ru_aggregation_methods.py                 | 24 +++++++++----------
 .../search_methods/linear_programming.py      |  4 ++--
 .../solution_refinement_procedure.py          |  2 +-
 model_compression_toolkit/core/runner.py      |  3 +--
 .../feature_models/manual_bit_selection.py    |  6 ++---
 .../mixed_precision_activation_test.py        |  3 ++-
 11 files changed, 33 insertions(+), 39 deletions(-)

diff --git a/model_compression_toolkit/core/common/fusion/graph_fuser.py b/model_compression_toolkit/core/common/fusion/graph_fuser.py
index a3df7c26f..3dac5a009 100644
--- a/model_compression_toolkit/core/common/fusion/graph_fuser.py
+++ b/model_compression_toolkit/core/common/fusion/graph_fuser.py
@@ -26,6 +26,8 @@ class FusedLayerType:
     """
     def __init__(self):
         self.__name__ = 'FusedLayer'
+
+
 class GraphFuser:
 
     def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
diff --git a/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py b/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py
index 8271e6c40..6ce792c7f 100644
--- a/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py
+++ b/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py
@@ -24,6 +24,7 @@
 SchedulerInfo = namedtuple('SchedulerInfo', [OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING])
 
+
 def compute_graph_max_cut(memory_graph: MemoryGraph,
                           n_iter: int = 50,
                           astar_n_iter: int = 500,
diff --git a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
index 76b3fc8d1..b1c0ea4ca 100644
--- a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
@@ -48,9 +48,8 @@ def set_bit_widths(mixed_precision_enable: bool,
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
             if node_name in sorted_nodes_names:  # only configurable nodes are in this list
                 node_index_in_graph = sorted_nodes_names.index(node_name)
-                _set_node_final_qc(bit_widths_config,
+                _set_node_final_qc(bit_widths_config[node_index_in_graph],
                                    node,
-                                   node_index_in_graph,
                                    graph.fw_info)
             else:
                 if node.is_activation_quantization_enabled():
@@ -83,8 +82,7 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
-                               bit_width_cfg: List[int],
-                               node_index_in_graph: int,
+                               node_bit_width_cfg: int,
                                fw_info) -> Any:
     """
     Get the node's quantization configuration that
@@ -93,8 +91,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
     Args:
         node: Node to get its quantization configuration candidate.
-        bit_width_cfg: Configuration which determines the node's desired bit width.
-        node_index_in_graph: Index of the node in the bit_width_cfg.
+        node_bit_width_cfg: Configuration which determines the node's desired bit width.
         fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     Returns:
@@ -104,24 +101,21 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
     kernel_attr = fw_info.get_kernel_op_attributes(node.type)
     if node.is_activation_quantization_enabled():
-        bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
-        qc = node.candidates_quantization_cfg[bit_index_in_cfg]
+        qc = node.candidates_quantization_cfg[node_bit_width_cfg]
         return qc
 
     elif kernel_attr is not None:
         if node.is_weights_quantization_enabled(kernel_attr[0]):
-            bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
-            qc = node.candidates_quantization_cfg[bit_index_in_cfg]
+            qc = node.candidates_quantization_cfg[node_bit_width_cfg]
             return qc
 
     Logger.critical(f"Quantization configuration for node '{node.name}' not found in candidate configurations.")  # pragma: no cover
 
 
-def _set_node_final_qc(bit_width_cfg: List[int],
+def _set_node_final_qc(node_bit_width_cfg: int,
                        node: BaseNode,
-                       node_index_in_graph: int,
                        fw_info):
     """
     Get the node's quantization configuration that
@@ -130,15 +124,13 @@ def _set_node_final_qc(bit_width_cfg: List[int],
     If the node quantization config was not found, raise an exception.
 
     Args:
-        bit_width_cfg: Configuration which determines the node's desired bit width.
+        node_bit_width_cfg: Configuration which determines the node's desired bit width.
         node: Node to set its node quantization configuration.
-        node_index_in_graph: Index of the node in the bit_width_cfg.
         fw_info: Information relevant to a specific framework about how layers should be quantized.
     """
 
     node_qc = _get_node_qc_by_bit_widths(node,
-                                         bit_width_cfg,
-                                         node_index_in_graph,
+                                         node_bit_width_cfg,
                                          fw_info)
 
     if node_qc is None:
diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
index 9b239e30e..3abde76b7 100644
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
@@ -30,7 +30,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     such that only a single candidate would remain, with the bitwidth equal to the one defined in the matching layer's
     base config in the TPC.
 
-    Note" This function modifies the graph inplace!
+    Note: This function modifies the graph inplace!
 
     Args:
         graph: A graph representation of the model to be quantized.
diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
index b826b2c19..a0a3ede22 100644
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
@@ -57,7 +57,6 @@ def compute_resource_utilization_data(in_model: Any,
     Returns:
         ResourceUtilization: An object encapsulating the calculated resource utilization computations.
- """ core_config = _create_core_config_for_ru(core_config) # We assume that the resource_utilization_data API is used to compute the model resource utilization for diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py index 2a75e51bc..123ae4404 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py @@ -33,9 +33,10 @@ def sum_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[A Returns: A list with an lpSum object for lp problem definition with the vector's sum. """ - if not set_constraints: - return [0] if len(ru_vector) == 0 else [sum(ru_vector)] - return [lpSum(ru_vector)] + if set_constraints: + return [lpSum(ru_vector)] + return [0] if len(ru_vector) == 0 else [sum(ru_vector)] + def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[float]: @@ -53,9 +54,10 @@ def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[f in the linear programming problem formalization. """ - if not set_constraints: - return [0] if len(ru_vector) == 0 else [max(ru_vector)] - return [ru for ru in ru_vector] + if set_constraints: + return [ru for ru in ru_vector] + return [0] if len(ru_vector) == 0 else [max(ru_vector)] + def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float]: @@ -74,16 +76,14 @@ def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float] in the linear programming problem formalization. """ - if not set_constraints: + if set_constraints: + weights_ru = lpSum([ru[0] for ru in ru_tensor]) + return [weights_ru + activation_ru for _, activation_ru in ru_tensor] + else: weights_ru = sum([ru[0] for ru in ru_tensor]) activation_ru = max([ru[1] for ru in ru_tensor]) return [weights_ru + activation_ru] - weights_ru = lpSum([ru[0] for ru in ru_tensor]) - total_ru = [weights_ru + activation_ru for _, activation_ru in ru_tensor] - - return total_ru - class MpRuAggregation(Enum): """ diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py index 81871065c..cada1e4e8 100644 --- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py @@ -73,7 +73,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager, assert lp_problem.status == LpStatusOptimal, Logger.critical( "No solution was found during solving the LP problem") - Logger.info(LpStatus[lp_problem.status]) + Logger.info(f"ILP status: {LpStatus[lp_problem.status]}") # Take the bitwidth index only if its corresponding indicator is one. 
     config = np.asarray(
@@ -82,7 +82,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
             in layer_to_indicator_vars_mapping.values()]
     ).flatten()
 
-    if target_resource_utilization.bops < np.inf:
+    if target_resource_utilization.bops_restricted():
         return search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(config)
     else:
         return config
diff --git a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
index 397074b22..a9a1f9d6e 100644
--- a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
+++ b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
@@ -47,7 +47,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
     """
 
     # Refinement is not supported for BOPs utilization for now...
-    if target_resource_utilization.bops < np.inf:
+    if target_resource_utilization.bops_restricted():
        Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
        return mp_solution
diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py
index 65ef60176..1948f28c2 100644
--- a/model_compression_toolkit/core/runner.py
+++ b/model_compression_toolkit/core/runner.py
@@ -151,8 +151,7 @@ def core_runner(in_model: Any,
             f'Mixed Precision has overwrite bit-width configuration{core_config.mixed_precision_config.configuration_overwrite}')
         bit_widths_config = core_config.mixed_precision_config.configuration_overwrite
 
-    if (target_resource_utilization.activation_memory < np.inf or
-            target_resource_utilization.total_memory < np.inf):
+    if target_resource_utilization.activation_restricted() or target_resource_utilization.total_mem_restricted():
         Logger.warning(
             f"Running mixed precision for activation compression, please note this feature is experimental and is "
             f"subject to future changes. If you encounter an issue, please open an issue in our GitHub "
diff --git a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py
index 8d2207974..33bf89ca7 100644
--- a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py
+++ b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py
@@ -77,7 +77,8 @@ class BaseManualBitWidthSelectionTest(MixedPrecisionActivationBaseTest):
     def create_feature_network(self, input_shape):
         return NetForBitSelection(input_shape)
 
-    def get_mp_core_config(self):
+    @staticmethod
+    def get_mp_core_config():
         qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
                                          relu_bound_to_power_of_2=False, weights_bias_correction=True,
                                          input_scaling=False, activation_channel_equalization=False)
@@ -92,6 +93,7 @@ def get_core_configs(self):
         core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
         return {"mixed_precision_activation_model": core_config}
 
+
 class ManualBitWidthByLayerTypeTest(BaseManualBitWidthSelectionTest):
     """
     This test check the manual bit width configuration.
@@ -159,10 +161,8 @@ def __init__(self, unit_test, filters, bit_widths):
         for filter, bit_width in zip(filters, bit_widths):
             self.layer_names.update({filter.node_name: bit_width})
 
-        super().__init__(unit_test)
-
     def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
         # in the compare we need bit_widths to be a list
         bit_widths = [self.bit_widths] if not isinstance(self.bit_widths, list) else self.bit_widths
diff --git a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py
index 881bfda44..e4e387796 100644
--- a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py
@@ -69,7 +69,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
         raise NotImplementedError
 
     def verify_config(self, result_config, expected_config):
-        self.unit_test.assertTrue(all(result_config == expected_config))
+        self.unit_test.assertTrue(all(result_config == expected_config),
+                                  f"Configuration mismatch: expected {expected_config} but got {result_config}.")
 
 
 class MixedPrecisionActivationSearch8Bit(MixedPrecisionActivationBaseTest):
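
--
Editor's note, not part of the patch: the recurring change above replaces raw comparisons like
target_resource_utilization.bops < np.inf with named predicates (bops_restricted(),
activation_restricted(), total_mem_restricted()). The patch does not show how these predicates are
defined, so the following is a minimal, hypothetical sketch of what they presumably encapsulate,
assuming ResourceUtilization stores each target as a float that defaults to np.inf when
unrestricted (consistent with the comparisons being removed). The real class in MCT may differ.

    # Hypothetical sketch -- illustrative only, not the actual MCT implementation.
    from dataclasses import dataclass

    import numpy as np


    @dataclass
    class ResourceUtilization:
        # np.inf acts as the "unrestricted" sentinel for every target.
        weights_memory: float = np.inf
        activation_memory: float = np.inf
        total_memory: float = np.inf
        bops: float = np.inf

        def bops_restricted(self) -> bool:
            # A finite BOPs target means the user constrained bit-operations.
            return self.bops < np.inf

        def activation_restricted(self) -> bool:
            return self.activation_memory < np.inf

        def total_mem_restricted(self) -> bool:
            return self.total_memory < np.inf

Under this reading, call sites such as "if target_resource_utilization.bops_restricted():" state
intent directly instead of leaking the np.inf sentinel into every consumer, which is the point of
these cosmetic renames.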