Skip to content

Commit

Permalink
Update graph_utilities.py
Browse files Browse the repository at this point in the history
make proper graphfcn
  • Loading branch information
Dobson committed Apr 5, 2024
1 parent fd27834 commit 073adf3
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions swmmanywhere/graph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,31 +222,37 @@ def __call__(self,
return G

@register_graphfcn
def remove_parallels(G: nx.MultiDiGraph, weight = 'length') -> nx.DiGraph:
"""Remove parallel edges from a street network.
class remove_isolated_nodes(BaseGraphFunction):
"""remove_isolated_nodes class."""

Retain the edge with the smallest weight (i.e., length).
def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph:
"""Remove parallel edges from a street network.
Args:
G (nx.MultiDiGraph): A graph.
weight (str): The edge attribute to use as the weight.
Returns:
G (nx.DiGraph): The graph with parallel edges removed.
Retain the edge with the smallest weight (i.e., length).
Author:
Taher Chegini
"""
graph = ox.get_digraph(G)
_, _, attr_list = next(iter(graph.edges(data=True))) # type: ignore
attr_list = cast("dict[str, Any]", attr_list)
if weight not in attr_list:
raise ValueError(f"{weight} not in edge attributes.")
attr = nx.get_node_attributes(graph, weight)
parallels = (e for e in attr if e[::-1] in attr)
graph.remove_edges_from({e if attr[e] > attr[e[::-1]]
else e[::-1] for e in parallels})
return graph
Args:
G (nx.MultiDiGraph): A graph.
**kwargs: Additional keyword arguments are ignored.
Returns:
G (nx.DiGraph): The graph with parallel edges removed.
Author:
Taher Chegini
"""
# Set the attribute (weight) used to determine which parallel edge to
# retain. Could make this a parameter in parameters.py if needed.
weight = 'length'
graph = ox.get_digraph(G)
_, _, attr_list = next(iter(graph.edges(data=True))) # type: ignore
attr_list = cast("dict[str, Any]", attr_list)
if weight not in attr_list:
raise ValueError(f"{weight} not in edge attributes.")
attr = nx.get_node_attributes(graph, weight)
parallels = (e for e in attr if e[::-1] in attr)
graph.remove_edges_from({e if attr[e] > attr[e[::-1]]
else e[::-1] for e in parallels})
return graph

@register_graphfcn
class format_osmnx_lanes(BaseGraphFunction,
Expand Down

0 comments on commit 073adf3

Please sign in to comment.