diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index 962b55c6..ad81ce75 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -121,9 +121,24 @@ def add_graphfcn(self, node_attributes = node_attributes.union(self.adds_node_attributes) return edge_attributes, node_attributes -class GraphFunctionRegistry: +class GraphFunctionRegistry(dict): """Registry object.""" - pass + + def register(self, cls): + """Register a graph function.""" + if cls.__name__ in self: + raise ValueError(f"{cls.__name__} already in the graph functions registry!") + + self[cls.__name__] = cls() + return cls + + def __getattr__(self, name): + """Get a graph function from the graphfcn dict.""" + try: + return self[name] + except KeyError: + raise AttributeError(f"{name} NOT in the graph functions registry!") + graphfcns = GraphFunctionRegistry() @@ -136,7 +151,7 @@ def register_graphfcn(cls) -> Callable: Returns: cls (Callable): The same class """ - setattr(graphfcns, cls.__name__, cls()) + graphfcns.register(cls) return cls def get_osmid_id(data: dict) -> Hashable: