diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index db89f43f..e697677e 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -13,7 +13,7 @@ from heapq import heappop, heappush from itertools import product from pathlib import Path -from typing import Any, Callable, Dict, Hashable, List, Optional +from typing import Any, Callable, Dict, Hashable, List, Optional, cast import geopandas as gpd import networkx as nx @@ -220,6 +220,33 @@ def __call__(self, for u, v, key in edges_to_remove: G.remove_edge(u, v, key) return G + +@register_graphfcn +def remove_parallels(G: nx.MultiDiGraph, weight = 'length') -> nx.DiGraph: + """Remove parallel edges from a street network. + + Retain the edge with the smallest weight (i.e., length). + + Args: + G (nx.MultiDiGraph): A graph. + weight (str): The edge attribute to use as the weight. + + Returns: + G (nx.DiGraph): The graph with parallel edges removed. + + Author: + Taher Chegini + """ + graph = ox.get_digraph(G) + _, _, attr_list = next(iter(graph.edges(data=True))) # type: ignore + attr_list = cast("dict[str, Any]", attr_list) + if weight not in attr_list: + raise ValueError(f"{weight} not in edge attributes.") + attr = nx.get_node_attributes(graph, weight) + parallels = (e for e in attr if e[::-1] in attr) + graph.remove_edges_from({e if attr[e] > attr[e[::-1]] + else e[::-1] for e in parallels}) + return graph @register_graphfcn class format_osmnx_lanes(BaseGraphFunction,