From 073adf3986f8e8ad3dabe481baccbcf8de8caa21 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 5 Apr 2024 10:35:26 +0100 Subject: [PATCH] Update graph_utilities.py make proper graphfcn --- swmmanywhere/graph_utilities.py | 50 ++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index e697677e..8ccecc79 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -222,31 +222,37 @@ def __call__(self, return G @register_graphfcn -def remove_parallels(G: nx.MultiDiGraph, weight = 'length') -> nx.DiGraph: - """Remove parallel edges from a street network. +class remove_isolated_nodes(BaseGraphFunction): + """remove_isolated_nodes class.""" - Retain the edge with the smallest weight (i.e., length). + def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: + """Remove parallel edges from a street network. - Args: - G (nx.MultiDiGraph): A graph. - weight (str): The edge attribute to use as the weight. - - Returns: - G (nx.DiGraph): The graph with parallel edges removed. + Retain the edge with the smallest weight (i.e., length). - Author: - Taher Chegini - """ - graph = ox.get_digraph(G) - _, _, attr_list = next(iter(graph.edges(data=True))) # type: ignore - attr_list = cast("dict[str, Any]", attr_list) - if weight not in attr_list: - raise ValueError(f"{weight} not in edge attributes.") - attr = nx.get_node_attributes(graph, weight) - parallels = (e for e in attr if e[::-1] in attr) - graph.remove_edges_from({e if attr[e] > attr[e[::-1]] - else e[::-1] for e in parallels}) - return graph + Args: + G (nx.MultiDiGraph): A graph. + **kwargs: Additional keyword arguments are ignored. + + Returns: + G (nx.DiGraph): The graph with parallel edges removed. + + Author: + Taher Chegini + """ + # Set the attribute (weight) used to determine which parallel edge to + # retain. Could make this a parameter in parameters.py if needed. + weight = 'length' + graph = ox.get_digraph(G) + _, _, attr_list = next(iter(graph.edges(data=True))) # type: ignore + attr_list = cast("dict[str, Any]", attr_list) + if weight not in attr_list: + raise ValueError(f"{weight} not in edge attributes.") + attr = nx.get_node_attributes(graph, weight) + parallels = (e for e in attr if e[::-1] in attr) + graph.remove_edges_from({e if attr[e] > attr[e[::-1]] + else e[::-1] for e in parallels}) + return graph @register_graphfcn class format_osmnx_lanes(BaseGraphFunction,