Skip to content

Commit

Permalink
Update metric_utilities.py
Browse files Browse the repository at this point in the history
proper subgraph
values->to_numpy
  • Loading branch information
barneydobson committed Mar 7, 2024
1 parent a99950c commit 8465bb9
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions swmmanywhere/metric_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,41 @@ def align_calc_nse(synthetic_results: pd.DataFrame,
how='outer').sort_values(by='date')

# Interpolate to time in real data
df['value_syn'] = df.set_index('date').value_syn.interpolate().values
df['value_syn'] = df.set_index('date').value_syn.interpolate().to_numpy()
df = df.dropna(subset=['value_real'])

# Calculate NSE
return nse(df.value_real, df.value_syn)

def create_subgraph(G: nx.Graph,
nodes: list) -> nx.Graph:
"""Create a subgraph.
Create a subgraph of G based on the nodes list. Taken from networkx
documentation: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.subgraph.html
Args:
G (nx.Graph): The original graph.
nodes (list): The list of nodes to include in the subgraph.
Returns:
nx.Graph: The subgraph.
"""
# Create a subgraph SG based on a (possibly multigraph) G
SG = G.__class__()
SG.add_nodes_from((n, G.nodes[n]) for n in nodes)
if SG.is_multigraph():
SG.add_edges_from((n, nbr, key, d)
for n, nbrs in G.adj.items() if n in nodes
for nbr, keydict in nbrs.items() if nbr in nodes
for key, d in keydict.items())
else:
SG.add_edges_from((n, nbr, d)
for n, nbrs in G.adj.items() if n in nodes
for nbr, d in nbrs.items() if nbr in nodes)
SG.graph.update(G.graph)
return SG

def nse(y: np.ndarray,
yhat: np.ndarray) -> float:
"""Calculate Nash-Sutcliffe efficiency (NSE)."""
Expand Down Expand Up @@ -133,7 +162,8 @@ def best_outlet_match(synthetic_G: nx.Graph,
# Subselect the matching graph
outlet_nodes = [n for n, d in synthetic_G.nodes(data=True)
if d['outlet'] == outlet]
return synthetic_G.subgraph(outlet_nodes), outlet
sg = create_subgraph(synthetic_G,outlet_nodes)
return sg, outlet

def dominant_outlet(G: nx.DiGraph,
results: pd.DataFrame) -> tuple[nx.DiGraph,int]:
Expand Down Expand Up @@ -165,7 +195,7 @@ def dominant_outlet(G: nx.DiGraph,
if d['id'] == max_outlet_arc][0]

# Subselect the matching graph
sg = G.subgraph(nx.ancestors(G, max_outlet) | {max_outlet})
sg = create_subgraph(G, nx.ancestors(G, max_outlet) | {max_outlet})
return sg, max_outlet

@metrics.register
Expand Down

0 comments on commit 8465bb9

Please sign in to comment.