Skip to content

Commit

Permalink
MaxCut initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Dec 25, 2024
1 parent 8ddbc8b commit 4071211
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_functions_mapping import RuFunctions
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_aggregation_methods import MpRuAggregation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.ru_methods import MpRuMetric, calc_graph_cuts
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import Cut
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation

Expand Down Expand Up @@ -81,9 +82,9 @@ def __init__(self,
original_graph=self.original_graph)

@property
def cuts(self):
def cuts(self) -> List[Cut]:
"""
Calcualtes graph cuts. Written as property so it will only be calculkated once and
Calculates graph cuts. Written as property, so it will only be calculated once and
only if cuts are needed.
"""
Expand Down Expand Up @@ -121,9 +122,9 @@ def get_sensitivity_metric(self) -> Callable:

return self.sensitivity_evaluator.compute_metric

def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg):
def _calc_ru_fn(self, ru_target, ru_fn, mp_cfg) -> np.ndarray:
"""
Computes a resource utilization for a certain mp configuration
Computes a resource utilization for a certain mixed precision configuration.
The method computes a resource utilization vector for specific target resource utilization.
Returns: resource utilization value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def activation_maxcut_size_utilization(mp_cfg: List[int],
(not used in this method).
fw_impl: FrameworkImplementation object with specific framework methods implementation(not used in this method).
cuts: a list of graph cuts (optional. if not provided calculated locally).
TODO maxcut: refactor - need to remove the cuts so all metric functions signatures are the same.
Returns: A vector of node's cut memory sizes.
Note that the vector is not necessarily of the same length as the given config.
Expand Down

0 comments on commit 4071211

Please sign in to comment.