Skip to content

Commit

Permalink
Replace max tensor with max cut (#1295)
Browse files Browse the repository at this point in the history
Replace MaxTensor with MaxCut for activation mixed precision (Experimental).
  • Loading branch information
elad-c authored and liord committed Dec 31, 2024
1 parent 20fa29f commit 5502435
Show file tree
Hide file tree
Showing 26 changed files with 330 additions and 107 deletions.
15 changes: 8 additions & 7 deletions model_compression_toolkit/core/common/fusion/graph_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
The fusion process involves:
1. Creating new fused nodes to represent these groups.
2. Updating the graph structure to replace the original nodes with fused nodes.
3. Maintaining mapping mapping of original node names to their fused node names.
3. Maintaining mapping of original node names to their fused node names.
Args:
graph: Graph to sue its nodes.
graph: Graph to fuse its nodes.
Returns:
Mapping of original node names to their fused node names
Expand All @@ -54,7 +54,8 @@ def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
fused_nodes_mapping[node.name] = new_fused_node.name
return fused_nodes_mapping

def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:
@staticmethod
def _create_fused_node(nodes: List[BaseNode]) -> BaseNode:
"""
Create a new node that represents the fusion of the given nodes.
Expand All @@ -79,10 +80,10 @@ def _create_fused_node(self, nodes: List[BaseNode]) -> BaseNode:

return fused_node

def _replace_nodes_with_fused_node(self,
graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
@staticmethod
def _replace_nodes_with_fused_node(graph: Graph,
nodes_to_fuse: List[BaseNode],
fused_node: BaseNode):
"""
Replace the specified nodes in the graph with a new fused node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
estimate = (u_bound + l_bound) / 2
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate_factor=estimate, iter_limit=astar_n_iter)
if schedule is None:
return last_result
l_bound = estimate
else:
u_bound = min(estimate, max_cut_size)
last_result = (schedule, max_cut_size, cuts)

next_u_bound = min(estimate, max_cut_size)
last_result = (schedule, max_cut_size, cuts)

if l_bound * (1 + eps) >= next_u_bound:
return last_result
if l_bound * (1 + eps) >= u_bound:
return last_result

it += 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
cut_route = routes[next_cut]

if next_cut == self.target_cut:
# TODO maxcut: Why do we filter the cuts (cut_route) but not the max cut size (cut_sost).
# This is a mismatch between max_cut and max(cuts).
# Also, unfiltered cut_route seems perfect, including input and output tensor sizes of current op.
return self._remove_dummys_from_path(cut_route[0].op_order), cut_cost,\
list(set([self._remove_dummys_from_cut(self.clean_memory_for_next_step(c)) for c in cut_route]))

Expand All @@ -178,7 +181,8 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
cost = self.accumulate(cut_cost, c.memory_size())
if c not in open_list:
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)
elif self.ordering(cost, costs[c]):
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
elif self.ordering(cost, costs[c]): # pragma: no cover
# If we already saw this cut during the search with a larger cost, then we want to update the order
# of the schedule in the cut
# Remove call - removes the cut with the same memory elements but different ordering from open
Expand All @@ -187,7 +191,8 @@ def solve(self, estimate_factor: float, iter_limit: int = 500) -> Tuple[List[Bas
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)

# Halt or No Solution
return None, 0, None
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return None, 0, None # pragma: no cover

@staticmethod
def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: List[Cut],
Expand Down Expand Up @@ -223,8 +228,7 @@ def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], rout
"""
ordered_cuts_list = sorted(open_list,
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), len(routes[c])),
reverse=False)
key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate_factor)), -len(routes[c])))

assert len(ordered_cuts_list) > 0
return ordered_cuts_list[0]
Expand Down Expand Up @@ -349,7 +353,8 @@ def ordering(cost_1, cost_2) -> bool:
Returns: True if the first cost is smaller than the second one, else otherwise.
"""
return cost_1 < cost_2
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
return cost_1 < cost_2 # pragma: no cover

def estimate(self, cut: Cut, estimate_factor: float) -> float:
"""
Expand Down Expand Up @@ -377,9 +382,10 @@ def get_init_estimate_factor(memory_graph: MemoryGraph) -> float:
Returns: An initial estimate value.
"""
l_bound = memory_graph.memory_lbound_single_op
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound
return (u_bound + l_bound) / 2
# TODO maxcut: this isn't covered in the coverage test. check if needed and remove no cover
l_bound = memory_graph.memory_lbound_single_op # pragma: no cover
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound # pragma: no cover
return (u_bound + l_bound) / 2 # pragma: no cover

@staticmethod
def _remove_dummys_from_path(path: List[BaseNode]) -> List[BaseNode]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def __init__(self, shape: Tuple[Any], node_name: str, node_output_index: int, in
init_size_to_zero: Whether to initialize the memory tensor size to 0 or not.
"""

self.shape = shape[1:] # remove batch size (first element) from output shape
# remove batch size (first element) from output shape. If the shape is a list then remove the first
# axis. If shape a vector (e.g. output of size) then set the shape minus 1 to ignore the batch value.
if len(shape) == 1:
self.shape = [] if shape[0] is None else [shape[0] - 1]
else:
self.shape = shape[1:]
# The total size of a tensor is considered to be the number of elements in the tensor
self.total_size = self._get_tensor_total_size() if not init_size_to_zero else 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
from typing import List
from operator import getitem

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX
Expand Down Expand Up @@ -45,7 +46,8 @@ def __init__(self, model_graph: Graph):
tensor_to_node = []

for n in nodes:
n_outputs = [n.output_shape] if isinstance(n.output_shape, tuple) else n.output_shape
n_outputs = n.output_shape if isinstance(n.output_shape[0], (tuple, list)) else [n.output_shape]

out_edges = model_graph.out_edges(n, sort_by_attr=EDGE_SOURCE_INDEX)

for i, ot in enumerate(n_outputs):
Expand All @@ -54,7 +56,16 @@ def __init__(self, model_graph: Graph):
# Add memory tensor as current node's output
node_to_tensor.append((n, memory_tensor))

ot_edges = [oe for oe in out_edges if oe.source_index == i]
# TODO maxcut: refactor this code. it handles split->getitem generated by fx.
ot_edges = []
for oe in out_edges:
if oe.sink_node.type is getitem and len(oe.sink_node.op_call_args) == 1 and isinstance(oe.sink_node.op_call_args[0], int):
source_index = oe.sink_node.op_call_args[0]
else:
source_index = oe.source_index
if source_index == i:
ot_edges.append(oe)

for oe in ot_edges:
# Add current memory tensor as input to current node's successors
tensor_to_node.append((memory_tensor, oe.sink_node))
Expand All @@ -71,6 +82,7 @@ def __init__(self, model_graph: Graph):
inputs_tensors_memory = [sum([t.total_size for t in self.operation_node_children(n)])
for n in nodes if n in model_graph.get_inputs()]

# TODO maxcut: why both inputs and outputs of each nodes, while the A* solves for node outputs only???
nodes_total_memory = [sum([t.total_size for t in self.operation_node_children(n)] +
[t.total_size for t in self.operation_node_parents(n)])
for n in nodes if n not in model_graph.get_inputs()]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
VirtualSplitWeightsNode, VirtualSplitActivationNode
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget, ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import RuFunctions
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric, calc_graph_cuts
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import Cut
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation

Expand All @@ -40,7 +42,7 @@ def __init__(self,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
sensitivity_evaluator: SensitivityEvaluation,
ru_functions: Dict[RUTarget, Tuple[MpRuMetric, MpRuAggregation]],
ru_functions: Dict[RUTarget, RuFunctions[MpRuMetric, MpRuAggregation]],
target_resource_utilization: ResourceUtilization,
original_graph: Graph = None):
"""
Expand All @@ -65,8 +67,11 @@ def __init__(self,
self.sensitivity_evaluator = sensitivity_evaluator
self.layer_to_bitwidth_mapping = self.get_search_space()
self.compute_metric_fn = self.get_sensitivity_metric()
self._cuts = None

self.compute_ru_functions = ru_functions
ru_types = [ru_target for ru_target, ru_value in
target_resource_utilization.get_resource_utilization_dict().items() if ru_value < np.inf]
self.compute_ru_functions = {ru_target: ru_fn for ru_target, ru_fn in ru_functions.items() if ru_target in ru_types}
self.target_resource_utilization = target_resource_utilization
self.min_ru_config = self.graph.get_min_candidates_config(fw_info)
self.max_ru_config = self.graph.get_max_candidates_config(fw_info)
Expand All @@ -76,6 +81,17 @@ def __init__(self,
self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.graph,
original_graph=self.original_graph)

@property
def cuts(self) -> List[Cut]:
"""
Calculates graph cuts. Written as property, so it will only be calculated once and
only if cuts are needed.
"""
if self._cuts is None:
self._cuts = calc_graph_cuts(self.original_graph)
return self._cuts

def get_search_space(self) -> Dict[int, List[int]]:
"""
The search space is a mapping from a node's index to a list of integers (possible bitwidths candidates indeces
Expand Down Expand Up @@ -106,6 +122,21 @@ def get_sensitivity_metric(self) -> Callable:

return self.sensitivity_evaluator.compute_metric

def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray:
"""
Computes a resource utilization for a certain mixed precision configuration.
The method computes a resource utilization vector for specific target resource utilization.
Returns: resource utilization value.
"""
# ru_fn is a pair of resource utilization computation method and
# resource utilization aggregation method (in this method we only need the first one)
if ru_target is RUTarget.ACTIVATION:
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
return ru_fn.metric_fn(mp_cfg, self.graph, self.fw_info, self.fw_impl)

def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
"""
Computes a resource utilization vector with the values matching to the minimal mp configuration
Expand All @@ -118,10 +149,10 @@ def compute_min_ru(self) -> Dict[RUTarget, np.ndarray]:
"""
min_ru = {}
for ru_target, ru_fns in self.compute_ru_functions.items():
# ru_fns is a pair of resource utilization computation method and
for ru_target, ru_fn in self.compute_ru_functions.items():
# ru_fns is a pair of resource utilization computation method and
# resource utilization aggregation method (in this method we only need the first one)
min_ru[ru_target] = ru_fns[0](self.min_ru_config, self.graph, self.fw_info, self.fw_impl)
min_ru[ru_target] = self._calc_ru_fn(ru_target, ru_fn, self.min_ru_config)

return min_ru

Expand Down Expand Up @@ -212,7 +243,7 @@ def compute_node_ru_for_candidate(self, conf_node_idx: int, candidate_idx: int,
"""
cfg = self.replace_config_in_index(self.min_ru_config, conf_node_idx, candidate_idx)
return self.compute_ru_functions[target].metric_fn(cfg, self.graph, self.fw_info, self.fw_impl)
return self._calc_ru_fn(target, self.compute_ru_functions[target], cfg)

@staticmethod
def replace_config_in_index(mp_cfg: List[int], idx: int, value: int) -> List[int]:
Expand Down Expand Up @@ -241,13 +272,15 @@ def _non_configurable_nodes_ru(self) -> Dict[RUTarget, np.ndarray]:
"""

non_conf_ru_dict = {}
for target, ru_value in self.target_resource_utilization.get_resource_utilization_dict().items():
for target, ru_fns in self.compute_ru_functions.items():
# Call for the ru method of the given target - empty quantization configuration list is passed since we
# compute for non-configurable nodes
if target == RUTarget.BOPS:
ru_vector = None
elif target == RUTarget.ACTIVATION:
ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
ru_vector = self.compute_ru_functions[target].metric_fn([], self.graph, self.fw_info, self.fw_impl)
ru_vector = ru_fns.metric_fn([], self.graph, self.fw_info, self.fw_impl)

non_conf_ru_dict[target] = ru_vector

Expand All @@ -266,14 +299,15 @@ def compute_resource_utilization_for_config(self, config: List[int]) -> Resource
"""

ru_dict = {}

for ru_target, ru_fns in self.compute_ru_functions.items():
# Passing False to ru methods and aggregations to indicates that the computations
# are not for constraints setting
if ru_target == RUTarget.BOPS:
configurable_nodes_ru_vector = ru_fns[0](config, self.original_graph, self.fw_info, self.fw_impl, False)
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl, False)
elif ru_target == RUTarget.ACTIVATION:
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.graph, self.fw_info, self.fw_impl, self.cuts)
else:
configurable_nodes_ru_vector = ru_fns[0](config, self.original_graph, self.fw_info, self.fw_impl)
configurable_nodes_ru_vector = ru_fns.metric_fn(config, self.original_graph, self.fw_info, self.fw_impl)
non_configurable_nodes_ru_vector = self.non_conf_ru_dict.get(ru_target)
if non_configurable_nodes_ru_vector is None or len(non_configurable_nodes_ru_vector) == 0:
ru_ru = self.compute_ru_functions[ru_target].aggregate_fn(configurable_nodes_ru_vector, False)
Expand Down Expand Up @@ -647,7 +681,7 @@ def get_weights_for_split_activation(self,
# It's ok, need to find the node's configuration
self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
else:
Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{n.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover
Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{weights_node.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover

def update_config_at_original_idx(self, n: BaseNode, origin_cfg_idx: int):
"""
Expand Down
Loading

0 comments on commit 5502435

Please sign in to comment.