diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index d4ba74a2..56ee69ab 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -371,8 +371,22 @@ def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: Returns: G (nx.Graph): A graph """ + # Convert to directed G_new = G.copy() G_new = nx.MultiDiGraph(G_new) + + # MultiDiGraph adds edges in both directions, but rivers (and geometries) + # are only in one direction. So we remove the reverse edges and add them + # back in with the correct geometry. + # This assumes that 'id' is of format 'start-end' (see assign_id) + arcs_to_remove = [(u,v) for u,v,d in G_new.edges(data=True) + if f'{u}-{v}' != d.get('id')] + + # Remove the reverse edges + for u, v in arcs_to_remove: + G_new.remove_edge(u, v) + + # Add in reversed edges for streets only and with geometry for u, v, data in G.edges(data=True): include = data.get('edge_type', True) if isinstance(include, str):