diff --git a/swmmanywhere/metric_utilities.py b/swmmanywhere/metric_utilities.py
new file mode 100644
index 00000000..e2f9498d
--- /dev/null
+++ b/swmmanywhere/metric_utilities.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+"""Created 2023-12-20.
+
+@author: Barnaby Dobson
+"""
+from inspect import signature
+from typing import Callable
+
+import geopandas as gpd
+import networkx as nx
+import numpy as np
+import pandas as pd
+from scipy import stats
+
+# ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
+# use whichever one the installed NumPy provides.
+_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
+
+
+class MetricRegistry(dict):
+    """Registry mapping metric names to metric functions.
+
+    Metrics are added with the ``register`` decorator and can be looked up
+    by key (``metrics["name"]``) or by attribute (``metrics.name``).
+    """
+
+    def register(self, func: Callable) -> Callable:
+        """Validate a metric function and add it to the registry.
+
+        Every parameter other than ``kwargs`` must have an allowed name and
+        be annotated with the type expected for that name.
+
+        Args:
+            func (Callable): The metric function to register.
+
+        Returns:
+            Callable: ``func`` itself, so this works as a decorator.
+
+        Raises:
+            ValueError: If the name is already registered, a parameter name
+                is not allowed, or an annotation has the wrong type.
+        """
+        if func.__name__ in self:
+            raise ValueError(f"{func.__name__} already in the metric registry!")
+
+        allowable_params = {"synthetic_results": pd.DataFrame,
+                            "real_results": pd.DataFrame,
+                            "synthetic_subs": gpd.GeoDataFrame,
+                            "real_subs": gpd.GeoDataFrame,
+                            "synthetic_G": nx.Graph,
+                            "real_G": nx.Graph}
+
+        sig = signature(func)
+        for param, obj in sig.parameters.items():
+            if param == 'kwargs':
+                continue
+            if param not in allowable_params:
+                raise ValueError(f"{param} of {func.__name__} not allowed.")
+            if obj.annotation != allowable_params[param]:
+                # Report the annotation that was found: ``obj.__class__`` is
+                # always ``inspect.Parameter`` and tells the user nothing.
+                raise ValueError(
+                    f"{param} of {func.__name__} should be of type "
+                    f"{allowable_params[param]}, not {obj.annotation}.")
+        self[func.__name__] = func
+        return func
+
+    def __getattr__(self, name):
+        """Get a metric from the registry by attribute access.
+
+        Raises:
+            AttributeError: If ``name`` is not a registered metric.
+        """
+        try:
+            return self[name]
+        except KeyError:
+            # The KeyError context adds nothing to the message below.
+            raise AttributeError(
+                f"{name} NOT in the metric registry!") from None
+
+
+metrics = MetricRegistry()
+
+
+def extract_var(df: pd.DataFrame,
+                var: str) -> pd.DataFrame:
+    """Extract the rows of one variable and add an elapsed-time column.
+
+    Args:
+        df (pd.DataFrame): Results with ``variable`` and ``date`` columns.
+        var (str): Value of ``variable`` to keep.
+
+    Returns:
+        pd.DataFrame: The matching rows plus a ``duration`` column holding
+            the seconds elapsed since the earliest matching date.
+    """
+    # Copy before adding a column so the caller's frame is left untouched
+    # and pandas does not emit a SettingWithCopyWarning.
+    df_ = df.loc[df.variable == var].copy()
+    df_['duration'] = (df_.date - df_.date.min()).dt.total_seconds()
+    return df_
+
+
+def _flooding_per_area(results: pd.DataFrame,
+                       subs: gpd.GeoDataFrame) -> float:
+    """Time-integrate flooding per object and divide by impervious area."""
+    flooding = extract_var(results, 'flooding')
+    # Integrate each object separately so that the trapezoidal rule never
+    # spans the time series of two different objects.
+    total = sum(_trapezoid(group['value'], group['duration'])
+                for _, group in flooding.groupby('object'))
+    return total / subs.impervious_area.sum()
+
+
+@metrics.register
+def bias_flood_depth(
+        synthetic_results: pd.DataFrame,
+        real_results: pd.DataFrame,
+        synthetic_subs: gpd.GeoDataFrame,
+        real_subs: gpd.GeoDataFrame,
+        **kwargs) -> float:
+    """Relative bias of synthetic versus real flooding per impervious area.
+
+    Args:
+        synthetic_results (pd.DataFrame): Synthetic results with
+            ``variable``, ``date``, ``object`` and ``value`` columns.
+        real_results (pd.DataFrame): Real results with the same columns.
+        synthetic_subs (gpd.GeoDataFrame): Synthetic subcatchments with an
+            ``impervious_area`` column.
+        real_subs (gpd.GeoDataFrame): Real subcatchments with an
+            ``impervious_area`` column.
+        **kwargs: Unused.
+
+    Returns:
+        float: ``(synthetic - real) / real``, undefined when the real
+            total is zero.
+    """
+    syn_tot = _flooding_per_area(synthetic_results, synthetic_subs)
+    real_tot = _flooding_per_area(real_results, real_subs)
+    return (syn_tot - real_tot) / real_tot
+
+
+@metrics.register
+def kstest_betweenness(
+        synthetic_G: nx.Graph,
+        real_G: nx.Graph,
+        **kwargs) -> float:
+    """Two-sample KS statistic between node betweenness distributions.
+
+    Args:
+        synthetic_G (nx.Graph): The synthetic graph.
+        real_G (nx.Graph): The real graph.
+        **kwargs: Unused.
+
+    Returns:
+        float: The ``statistic`` of ``scipy.stats.ks_2samp`` applied to the
+            betweenness centralities of the two graphs.
+    """
+    syn_betweenness = nx.betweenness_centrality(synthetic_G)
+    real_betweenness = nx.betweenness_centrality(real_G)
+
+    # TODO does it make more sense to use statistic or pvalue?
+    return stats.ks_2samp(list(syn_betweenness.values()),
+                          list(real_betweenness.values())).statistic
diff --git a/tests/test_metric_utilities.py b/tests/test_metric_utilities.py
new file mode 100644
index 00000000..a4739029
--- /dev/null
+++ b/tests/test_metric_utilities.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from swmmanywhere.graph_utilities import load_graph
+from swmmanywhere.metric_utilities import metrics as sm
+
+
+def test_bias_flood_depth():
+    """Test the bias_flood_depth metric."""
+    # Create synthetic and real data
+    synthetic_results = pd.DataFrame({
+        'object': ['obj1', 'obj1','obj2','obj2'],
+        'value': [10, 20, 5, 2],
+        'variable': 'flooding',
+        'date' : pd.to_datetime(['2021-01-01 00:00:00','2021-01-01 00:05:00',
+                                 '2021-01-01 00:00:00','2021-01-01 00:05:00'])
+    })
+    real_results = pd.DataFrame({
+        'object': ['obj1', 'obj1','obj2','obj2'],
+        'value': [15, 25, 10, 20],
+        'variable': 'flooding',
+        'date' : pd.to_datetime(['2021-01-01 00:00:00','2021-01-01 00:05:00',
+                                 '2021-01-01 00:00:00','2021-01-01 00:05:00'])
+    })
+    synthetic_subs = pd.DataFrame({
+        'impervious_area': [100, 200],
+    })
+    real_subs = pd.DataFrame({
+        'impervious_area': [150, 250],
+    })
+
+    # Run the metric
+    val = sm.bias_flood_depth(synthetic_results = synthetic_results,
+                              real_results = real_results,
+                              synthetic_subs = synthetic_subs,
+                              real_subs = real_subs)
+    assert np.isclose(val, -0.29523809523809524)
+
+def test_kstest_betweenness():
+    """Test the kstest_betweenness metric."""
+    G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json')
+    val = sm.kstest_betweenness(synthetic_G = G, real_G = G)
+    assert val == 0.0
+
+    G_ = G.copy()
+    G_.remove_node(list(G.nodes)[0])
+    val = sm.kstest_betweenness(synthetic_G = G_, real_G = G)
+    assert np.isclose(val, 0.286231884057971)