diff --git a/swmmanywhere/metric_utilities.py b/swmmanywhere/metric_utilities.py index f1859347..9318fd10 100644 --- a/swmmanywhere/metric_utilities.py +++ b/swmmanywhere/metric_utilities.py @@ -87,12 +87,41 @@ def align_calc_nse(synthetic_results: pd.DataFrame, how='outer').sort_values(by='date') # Interpolate to time in real data - df['value_syn'] = df.set_index('date').value_syn.interpolate().values + df['value_syn'] = df.set_index('date').value_syn.interpolate().to_numpy() df = df.dropna(subset=['value_real']) # Calculate NSE return nse(df.value_real, df.value_syn) +def create_subgraph(G: nx.Graph, + nodes: list) -> nx.Graph: + """Create a subgraph. + + Create a subgraph of G based on the nodes list. Taken from networkx + documentation: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.subgraph.html + + Args: + G (nx.Graph): The original graph. + nodes (list): The list of nodes to include in the subgraph. + + Returns: + nx.Graph: The subgraph. + """ + # Create a subgraph SG based on a (possibly multigraph) G + SG = G.__class__() + SG.add_nodes_from((n, G.nodes[n]) for n in nodes) + if SG.is_multigraph(): + SG.add_edges_from((n, nbr, key, d) + for n, nbrs in G.adj.items() if n in nodes + for nbr, keydict in nbrs.items() if nbr in nodes + for key, d in keydict.items()) + else: + SG.add_edges_from((n, nbr, d) + for n, nbrs in G.adj.items() if n in nodes + for nbr, d in nbrs.items() if nbr in nodes) + SG.graph.update(G.graph) + return SG + def nse(y: np.ndarray, yhat: np.ndarray) -> float: """Calculate Nash-Sutcliffe efficiency (NSE).""" @@ -133,7 +162,8 @@ def best_outlet_match(synthetic_G: nx.Graph, # Subselect the matching graph outlet_nodes = [n for n, d in synthetic_G.nodes(data=True) if d['outlet'] == outlet] - return synthetic_G.subgraph(outlet_nodes), outlet + sg = create_subgraph(synthetic_G,outlet_nodes) + return sg, outlet def dominant_outlet(G: nx.DiGraph, results: pd.DataFrame) -> tuple[nx.DiGraph,int]: @@ -165,7 +195,7 @@ def dominant_outlet(G: nx.DiGraph, if d['id'] == max_outlet_arc][0] # Subselect the matching graph - sg = G.subgraph(nx.ancestors(G, max_outlet) | {max_outlet}) + sg = create_subgraph(G, nx.ancestors(G, max_outlet) | {max_outlet}) return sg, max_outlet @metrics.register