Skip to content

Commit

Permalink
Just some cosmetics (#1291)
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c authored Dec 10, 2024
1 parent 4a16071 commit 262db8a
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 39 deletions.
2 changes: 2 additions & 0 deletions model_compression_toolkit/core/common/fusion/graph_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class FusedLayerType:
"""
def __init__(self):
self.__name__ = 'FusedLayer'


class GraphFuser:

def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

SchedulerInfo = namedtuple('SchedulerInfo', [OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING])


def compute_graph_max_cut(memory_graph: MemoryGraph,
n_iter: int = 50,
astar_n_iter: int = 500,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def set_bit_widths(mixed_precision_enable: bool,
node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
if node_name in sorted_nodes_names: # only configurable nodes are in this list
node_index_in_graph = sorted_nodes_names.index(node_name)
_set_node_final_qc(bit_widths_config,
_set_node_final_qc(bit_widths_config[node_index_in_graph],
node,
node_index_in_graph,
graph.fw_info)
else:
if node.is_activation_quantization_enabled():
Expand Down Expand Up @@ -83,8 +82,7 @@ def set_bit_widths(mixed_precision_enable: bool,


def _get_node_qc_by_bit_widths(node: BaseNode,
bit_width_cfg: List[int],
node_index_in_graph: int,
node_bit_width_cfg: int,
fw_info) -> Any:
"""
Get the node's quantization configuration that
Expand All @@ -93,8 +91,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
Args:
node: Node to get its quantization configuration candidate.
bit_width_cfg: Configuration which determines the node's desired bit width.
node_index_in_graph: Index of the node in the bit_width_cfg.
node_bit_width_cfg: Configuration which determines the node's desired bit width.
fw_info: Information relevant to a specific framework about how layers should be quantized.
Returns:
Expand All @@ -104,24 +101,21 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
kernel_attr = fw_info.get_kernel_op_attributes(node.type)

if node.is_activation_quantization_enabled():
bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
qc = node.candidates_quantization_cfg[bit_index_in_cfg]
qc = node.candidates_quantization_cfg[node_bit_width_cfg]

return qc

elif kernel_attr is not None:
if node.is_weights_quantization_enabled(kernel_attr[0]):
bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
qc = node.candidates_quantization_cfg[bit_index_in_cfg]
qc = node.candidates_quantization_cfg[node_bit_width_cfg]

return qc

Logger.critical(f"Quantization configuration for node '{node.name}' not found in candidate configurations.") # pragma: no cover


def _set_node_final_qc(bit_width_cfg: List[int],
def _set_node_final_qc(node_bit_width_cfg: int,
node: BaseNode,
node_index_in_graph: int,
fw_info):
"""
Get the node's quantization configuration that
Expand All @@ -130,15 +124,13 @@ def _set_node_final_qc(bit_width_cfg: List[int],
If the node quantization config was not found, raise an exception.
Args:
bit_width_cfg: Configuration which determines the node's desired bit width.
node_bit_width_cfg: Configuration which determines the node's desired bit width.
node: Node to set its node quantization configuration.
node_index_in_graph: Index of the node in the bit_width_cfg.
fw_info: Information relevant to a specific framework about how layers should be quantized.
"""
node_qc = _get_node_qc_by_bit_widths(node,
bit_width_cfg,
node_index_in_graph,
node_bit_width_cfg,
fw_info)

if node_qc is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def filter_candidates_for_mixed_precision(graph: Graph,
such that only a single candidate would remain, with the bitwidth equal to the one defined in the matching layer's
base config in the TPC.
Note" This function modifies the graph inplace!
Note: This function modifies the graph inplace!
Args:
graph: A graph representation of the model to be quantized.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def compute_resource_utilization_data(in_model: Any,
Returns:
ResourceUtilization: An object encapsulating the calculated resource utilization computations.
"""
core_config = _create_core_config_for_ru(core_config)
# We assume that the resource_utilization_data API is used to compute the model resource utilization for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def sum_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[A
Returns: A list with an lpSum object for lp problem definition with the vector's sum.
"""
if not set_constraints:
return [0] if len(ru_vector) == 0 else [sum(ru_vector)]
return [lpSum(ru_vector)]
if set_constraints:
return [lpSum(ru_vector)]
return [0] if len(ru_vector) == 0 else [sum(ru_vector)]



def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[float]:
Expand All @@ -53,9 +54,10 @@ def max_ru_values(ru_vector: np.ndarray, set_constraints: bool = True) -> List[f
in the linear programming problem formalization.
"""
if not set_constraints:
return [0] if len(ru_vector) == 0 else [max(ru_vector)]
return [ru for ru in ru_vector]
if set_constraints:
return [ru for ru in ru_vector]
return [0] if len(ru_vector) == 0 else [max(ru_vector)]



def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float]:
Expand All @@ -74,16 +76,14 @@ def total_ru(ru_tensor: np.ndarray, set_constraints: bool = True) -> List[float]
in the linear programming problem formalization.
"""
if not set_constraints:
if set_constraints:
weights_ru = lpSum([ru[0] for ru in ru_tensor])
return [weights_ru + activation_ru for _, activation_ru in ru_tensor]
else:
weights_ru = sum([ru[0] for ru in ru_tensor])
activation_ru = max([ru[1] for ru in ru_tensor])
return [weights_ru + activation_ru]

weights_ru = lpSum([ru[0] for ru in ru_tensor])
total_ru = [weights_ru + activation_ru for _, activation_ru in ru_tensor]

return total_ru


class MpRuAggregation(Enum):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,

assert lp_problem.status == LpStatusOptimal, Logger.critical(
"No solution was found during solving the LP problem")
Logger.info(LpStatus[lp_problem.status])
Logger.info(f"ILP status: {LpStatus[lp_problem.status]}")

# Take the bitwidth index only if its corresponding indicator is one.
config = np.asarray(
Expand All @@ -82,7 +82,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
in layer_to_indicator_vars_mapping.values()]
).flatten()

if target_resource_utilization.bops < np.inf:
if target_resource_utilization.bops_restricted():
return search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(config)
else:
return config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
"""
# Refinement is not supported for BOPs utilization for now...
if target_resource_utilization.bops < np.inf:
if target_resource_utilization.bops_restricted():
Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
return mp_solution

Expand Down
3 changes: 1 addition & 2 deletions model_compression_toolkit/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def core_runner(in_model: Any,
f'Mixed Precision has overwrite bit-width configuration{core_config.mixed_precision_config.configuration_overwrite}')
bit_widths_config = core_config.mixed_precision_config.configuration_overwrite

if (target_resource_utilization.activation_memory < np.inf or
target_resource_utilization.total_memory < np.inf):
if target_resource_utilization.activation_restricted() or target_resource_utilization.total_mem_restricted():
Logger.warning(
f"Running mixed precision for activation compression, please note this feature is experimental and is "
f"subject to future changes. If you encounter an issue, please open an issue in our GitHub "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class BaseManualBitWidthSelectionTest(MixedPrecisionActivationBaseTest):
def create_feature_network(self, input_shape):
return NetForBitSelection(input_shape)

def get_mp_core_config(self):
@staticmethod
def get_mp_core_config():
qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
relu_bound_to_power_of_2=False, weights_bias_correction=True,
input_scaling=False, activation_channel_equalization=False)
Expand All @@ -92,6 +93,7 @@ def get_core_configs(self):
core_config.bit_width_config.set_manual_activation_bit_width(self.filters, self.bit_widths)
return {"mixed_precision_activation_model": core_config}


class ManualBitWidthByLayerTypeTest(BaseManualBitWidthSelectionTest):
"""
This test check the manual bit width configuration.
Expand Down Expand Up @@ -159,10 +161,8 @@ def __init__(self, unit_test, filters, bit_widths):
for filter, bit_width in zip(filters, bit_widths):
self.layer_names.update({filter.node_name: bit_width})


super().__init__(unit_test)


def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
# in the compare we need bit_widths to be a list
bit_widths = [self.bit_widths] if not isinstance(self.bit_widths, list) else self.bit_widths
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info:
raise NotImplementedError

def verify_config(self, result_config, expected_config):
self.unit_test.assertTrue(all(result_config == expected_config))
self.unit_test.assertTrue(all(result_config == expected_config),
f"Configuration mismatch: expected {expected_config} but got {result_config}.")


class MixedPrecisionActivationSearch8Bit(MixedPrecisionActivationBaseTest):
Expand Down

0 comments on commit 262db8a

Please sign in to comment.