Fix max-cut bounds.
Fix test name typo.
elad-c committed Dec 24, 2024
1 parent f347f51 commit 1104f77
Showing 4 changed files with 22 additions and 10 deletions.
@@ -51,13 +51,13 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
         estimate = (u_bound + l_bound) / 2
         schedule, max_cut_size, cuts = max_cut_astar.solve(estimate_factor=estimate, iter_limit=astar_n_iter)
         if schedule is None:
-            return last_result
-
-        next_u_bound = min(estimate, max_cut_size)
-        last_result = (schedule, max_cut_size, cuts)
-
-        if l_bound * (1 + eps) >= next_u_bound:
-            return last_result
+            l_bound = estimate
+        else:
+            u_bound = min(estimate, max_cut_size)
+            last_result = (schedule, max_cut_size, cuts)
+
+            if l_bound * (1 + eps) >= u_bound:
+                return last_result

         it += 1
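For readers skimming the hunk above: when the A* solver returns no schedule, the candidate estimate is infeasible, so it now becomes the new lower bound instead of aborting the search, while a feasible estimate tightens the upper bound. Below is a minimal, self-contained sketch of that bisection pattern (illustrative only, not the MCT code; bisect_min_feasible, is_feasible and the eps/n_iter defaults are stand-ins for max_cut_astar.solve and the real parameters).

# A self-contained sketch of the bisection pattern (illustrative only, not the MCT code).
# `is_feasible` stands in for max_cut_astar.solve(): it returns an achieved value when the
# candidate estimate is feasible and None when it is not.
def bisect_min_feasible(l_bound, u_bound, is_feasible, eps=0.01, n_iter=30):
    last_result = None
    for _ in range(n_iter):
        estimate = (u_bound + l_bound) / 2
        result = is_feasible(estimate)
        if result is None:
            # Infeasible estimate: the optimum lies above it, so raise the lower bound.
            l_bound = estimate
        else:
            # Feasible estimate: tighten the upper bound and keep the best result so far.
            u_bound = min(estimate, result)
            last_result = result
            if l_bound * (1 + eps) >= u_bound:
                break
    return last_result

# Toy usage: the smallest feasible value is 7, so the search converges to roughly 7.
print(bisect_min_feasible(0.0, 100.0, lambda e: e if e >= 7 else None))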
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from typing import List
+from operator import getitem

 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX
@@ -58,7 +59,16 @@ def __init__(self, model_graph: Graph):
                 # Add memory tensor as current node's output
                 node_to_tensor.append((n, memory_tensor))

-                ot_edges = [oe for oe in out_edges if oe.source_index == i]
+                # TODO maxcut: refactor this code. it handles split->getitem generated by fx.
+                ot_edges = []
+                for oe in out_edges:
+                    if oe.sink_node.type is getitem and len(oe.sink_node.op_call_args) == 1 and isinstance(oe.sink_node.op_call_args[0], int):
+                        source_index = oe.sink_node.op_call_args[0]
+                    else:
+                        source_index = oe.source_index
+                    if source_index == i:
+                        ot_edges.append(oe)

                 for oe in ot_edges:
                     # Add current memory tensor as input to current node's successors
                     tensor_to_node.append((memory_tensor, oe.sink_node))
@@ -75,6 +85,7 @@ def __init__(self, model_graph: Graph):
         inputs_tensors_memory = [sum([t.total_size for t in self.operation_node_children(n)])
                                  for n in nodes if n in model_graph.get_inputs()]

+        # TODO maxcut: why both inputs and outputs of each nodes, while the A* solves for node outputs only???
         nodes_total_memory = [sum([t.total_size for t in self.operation_node_children(n)] +
                                   [t.total_size for t in self.operation_node_parents(n)])
                               for n in nodes if n not in model_graph.get_inputs()]
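The new ot_edges loop above (@@ -58,7 +59,16 @@) accounts for torch.fx traces, where a multi-output op such as split is followed by operator.getitem nodes that carry the selected output index in their call arguments rather than on the edge itself. A small illustrative sketch of that graph shape follows (assumes torch with torch.fx installed; the model and names below are made up, not part of the commit).

# Illustrative sketch (not part of the commit): how torch.fx represents split + getitem.
import operator

import torch
import torch.fx


class SplitModel(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, 2, dim=1)    # a single fx node with multiple outputs
        return parts[0] + 1, parts[1] * 2   # each indexing op is traced as operator.getitem


traced = torch.fx.symbolic_trace(SplitModel())
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target is operator.getitem:
        # The selected split-output index lives in the getitem node's arguments,
        # which is why the MemoryGraph code above reads it from the getitem call args.
        print(node.name, "selects split output", node.args[1])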
@@ -161,7 +161,7 @@ def run_test(self):
         self.verify_results(ru_data, sum_parameters, max_tensor)


-class TestResourceUtilizationDataComplesAllBitwidth(ResourceUtilizationDataBaseTestClass):
+class TestResourceUtilizationDataComplexAllBitwidth(ResourceUtilizationDataBaseTestClass):

     def run_test(self):
         model = ComplexModel()
5 changes: 3 additions & 2 deletions tests/pytorch_tests/function_tests/test_function_runner.py
@@ -21,7 +21,7 @@
     BNLayerInfoCollectionTest, INP2BNInfoCollectionTest
 from tests.pytorch_tests.function_tests.get_gptq_config_test import TestGetGPTQConfig
 from tests.pytorch_tests.function_tests.resource_utilization_data_test import TestResourceUtilizationDataBasicAllBitwidth, \
-    TestResourceUtilizationDataBasicPartialBitwidth, TestResourceUtilizationDataComplexPartialBitwidth, TestResourceUtilizationDataComplesAllBitwidth
+    TestResourceUtilizationDataBasicPartialBitwidth, TestResourceUtilizationDataComplexPartialBitwidth, TestResourceUtilizationDataComplexAllBitwidth
 from tests.pytorch_tests.function_tests.layer_fusing_test import LayerFusingTest1, LayerFusingTest2, LayerFusingTest3, \
     LayerFusingTest4
 from tests.pytorch_tests.function_tests.set_device_test import SetDeviceTest
@@ -100,7 +100,8 @@ def test_ru_data_complex_all(self):
         """
         This test checks the resource utilization data Pytorch API.
         """
-        TestResourceUtilizationDataComplesAllBitwidth(self).run_test()
+        # TODO maxcut: test fails to fund lowest cut (3*224*250 + 3). also need to fix the "max_tensor" of the test Model.
+        TestResourceUtilizationDataComplexAllBitwidth(self).run_test()

     def test_ru_data_complex_partial(self):
         """
