Skip to content

Commit

Permalink
Bug fix and test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed May 2, 2024
1 parent 947d4f0 commit b9f3937
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 5 additions & 3 deletions swmmanywhere/shortest_path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,15 @@ def tarjans_pq(G: nx.MultiDiGraph,
if len(parent) != n - 1:
raise ValueError("Graph is not connected or has multiple roots.")

new_graph = G.copy()
for u,v in G.edges():
new_graph.remove_edge(u,v)
new_graph = nx.MultiDiGraph()

for u,v in mst_edges:
d= G_.get_edge_data(u,v)[0]
new_graph.add_edge(u,v,**d)

for u, d in G_.nodes(data=True):
new_graph.nodes[u].update(d)

nx.set_node_attributes(new_graph, outlets, 'outlet')
new_graph = nx.relabel_nodes(new_graph, node_mapping)

Expand Down
9 changes: 8 additions & 1 deletion tests/test_graph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,18 @@ def test_identify_outlets_and_derive_topology():
G_ = gu.derive_topology(G_,params)
assert len(G_.edges) == 22
assert len(set([d['outlet'] for u,d in G_.nodes(data=True)])) == 2
for u,d in G_.nodes(data=True):
assert 'x' in d.keys()
assert 'y' in d.keys()


# Test outlet derivation parameters
G_ = G.copy()
params.outlet_length = 600
G_ = gu.identify_outlets(G_, params)
outlets = [(u,v,d) for u,v,d in G_.edges(data=True) if d['edge_type'] == 'outlet']
assert len(outlets) == 1

def test_identify_outlets_and_derive_topology_withtopo():
"""Test the identify_outlets and derive_topology functions."""
G, _ = load_street_network()
Expand Down Expand Up @@ -323,6 +327,9 @@ def test_identify_outlets_and_derive_topology_withtopo():
G_ = gu.identify_outlets(G_, params)
G_ = gu.derive_topology(G_, params)
assert len(set([d['outlet'] for u,d in G_.nodes(data=True)])) == 1
for u,d in G_.nodes(data=True):
assert 'x' in d.keys()
assert 'y' in d.keys()

def test_pipe_by_pipe():
"""Test the pipe_by_pipe function."""
Expand Down

0 comments on commit b9f3937

Please sign in to comment.