Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric format #59

Merged
merged 5 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions swmmanywhere/metric_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
"""Created 2023-12-20.

@author: Barnaby Dobson
"""
from inspect import signature
from typing import Callable

import geopandas as gpd
import networkx as nx
import numpy as np
import pandas as pd
from scipy import stats


class MetricRegistry(dict):
    """Registry of metric functions, keyed by function name.

    Metrics are plain functions registered via the :meth:`register`
    decorator and retrievable either by item access (``registry['name']``)
    or attribute access (``registry.name``).
    """

    def register(self, func: Callable) -> Callable:
        """Register a metric function after validating its signature.

        Args:
            func: The metric function. Every parameter (other than a
                ``kwargs`` catch-all) must be one of the allowed names
                and carry the matching type annotation.

        Returns:
            The function, unchanged (usable as a decorator).

        Raises:
            ValueError: If a metric of the same name is already registered,
                a parameter name is not allowed, or a parameter's annotation
                does not match the expected type.
        """
        if func.__name__ in self:
            raise ValueError(f"{func.__name__} already in the metric registry!")

        allowable_params = {"synthetic_results": pd.DataFrame,
                            "real_results": pd.DataFrame,
                            "synthetic_subs": gpd.GeoDataFrame,
                            "real_subs": gpd.GeoDataFrame,
                            "synthetic_G": nx.Graph,
                            "real_G": nx.Graph}

        sig = signature(func)
        for param, obj in sig.parameters.items():
            if param == 'kwargs':
                continue
            if param not in allowable_params:
                raise ValueError(f"{param} of {func.__name__} not allowed.")
            if obj.annotation != allowable_params[param]:
                # Report the offending annotation itself; the original
                # printed obj.__class__, which is always inspect.Parameter
                # and never the actual (wrong) annotation.
                raise ValueError(
                    f"{param} of {func.__name__} should be of type "
                    f"{allowable_params[param]}, not {obj.annotation}.")
        self[func.__name__] = func
        return func

    def __getattr__(self, name):
        """Get a registered metric by attribute access."""
        try:
            return self[name]
        except KeyError as exc:
            # Chain the KeyError so the original lookup failure is visible.
            raise AttributeError(f"{name} NOT in the metric registry!") from exc


# Module-level registry that metric functions decorate themselves into.
metrics = MetricRegistry()

def extract_var(df: pd.DataFrame,
                var: str) -> pd.DataFrame:
    """Extract rows for one variable and add an elapsed-time column.

    Args:
        df: Results dataframe with at least ``variable`` and ``date``
            (datetime) columns.
        var: Value of ``variable`` to select.

    Returns:
        A copy of the matching rows with an added ``duration`` column:
        seconds elapsed since the earliest ``date`` in the selection.
    """
    # .copy() so the new column is written to an independent frame rather
    # than a view of the caller's data (avoids SettingWithCopyWarning and
    # any silent mutation of the input).
    df_ = df.loc[df.variable == var].copy()
    df_['duration'] = (df_.date - df_.date.min()).dt.total_seconds()
    return df_

@metrics.register
def bias_flood_depth(
        synthetic_results: pd.DataFrame,
        real_results: pd.DataFrame,
        synthetic_subs: gpd.GeoDataFrame,
        real_subs: gpd.GeoDataFrame,
        **kwargs) -> float:
    """Relative bias of total flood depth (synthetic vs. real).

    Integrates per-object flooding over time, normalises the summed
    volume by total impervious area, and returns the relative error
    of the synthetic total against the real total.
    """
    def _integrate(group):
        # Trapezoidal integration of flood value over elapsed seconds.
        return np.trapz(group.value, group.duration)

    def _normalised_total(results, subs):
        # Total integrated flooding per unit impervious area.
        flooding = extract_var(results, 'flooding')
        per_object = flooding.groupby('object').apply(_integrate)
        return per_object.sum() / subs.impervious_area.sum()

    syn_tot = _normalised_total(synthetic_results, synthetic_subs)
    real_tot = _normalised_total(real_results, real_subs)

    return (syn_tot - real_tot) / real_tot

@metrics.register
def kstest_betweenness(
        synthetic_G: nx.Graph,
        real_G: nx.Graph,
        **kwargs) -> float:
    """KS statistic between betweenness-centrality distributions.

    Computes node betweenness centrality for both graphs and returns the
    two-sample Kolmogorov-Smirnov statistic comparing the two
    distributions (0.0 means identical distributions).

    Args:
        synthetic_G: The synthetic graph.
        real_G: The real (reference) graph.

    Returns:
        The KS test statistic in [0, 1].
    """
    syn_betweenness = nx.betweenness_centrality(synthetic_G)
    real_betweenness = nx.betweenness_centrality(real_G)

    # TODO does it make more sense to use statistic or pvalue?
    return stats.ks_2samp(list(syn_betweenness.values()),
                          list(real_betweenness.values())).statistic
49 changes: 49 additions & 0 deletions tests/test_metric_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from pathlib import Path

import pandas as pd

from swmmanywhere.graph_utilities import load_graph
from swmmanywhere.metric_utilities import metrics as sm


def test_bias_flood_depth():
    """Test the bias_flood_depth metric."""
    # Create synthetic and real data
    synthetic_results = pd.DataFrame({
        'object': ['obj1', 'obj1', 'obj2', 'obj2'],
        'value': [10, 20, 5, 2],
        'variable': 'flooding',
        'date': pd.to_datetime(['2021-01-01 00:00:00', '2021-01-01 00:05:00',
                                '2021-01-01 00:00:00', '2021-01-01 00:05:00'])
    })
    real_results = pd.DataFrame({
        'object': ['obj1', 'obj1', 'obj2', 'obj2'],
        'value': [15, 25, 10, 20],
        'variable': 'flooding',
        'date': pd.to_datetime(['2021-01-01 00:00:00', '2021-01-01 00:05:00',
                                '2021-01-01 00:00:00', '2021-01-01 00:05:00'])
    })
    synthetic_subs = pd.DataFrame({
        'impervious_area': [100, 200],
    })
    real_subs = pd.DataFrame({
        'impervious_area': [150, 250],
    })

    # Run the metric
    val = sm.bias_flood_depth(synthetic_results=synthetic_results,
                              real_results=real_results,
                              synthetic_subs=synthetic_subs,
                              real_subs=real_subs)
    # Tolerance rather than exact equality: the value is the result of
    # float integration/division and exact == is brittle.
    assert abs(val - (-0.29523809523809524)) < 1e-12

def test_kstest_betweenness():
    """Test the kstest_betweenness metric."""
    G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json')
    # Identical graphs must give a KS statistic of exactly 0.
    val = sm.kstest_betweenness(synthetic_G=G, real_G=G)
    assert val == 0.0

    # Perturb the graph: removing a node must change the distribution.
    G_ = G.copy()
    G_.remove_node(list(G.nodes)[0])
    val = sm.kstest_betweenness(synthetic_G=G_, real_G=G)
    # Tolerance rather than exact equality: the KS statistic is a float
    # and exact == is brittle across scipy versions.
    assert abs(val - 0.286231884057971) < 1e-9