diff --git a/dev-requirements.txt b/dev-requirements.txt index f8508132..69df7621 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -63,9 +63,7 @@ colorama==0.4.6 contourpy==1.2.0 # via matplotlib coverage[toml]==7.3.2 - # via - # coverage - # pytest-cov + # via pytest-cov cramjam==2.7.0 # via fastparquet cycler==0.12.1 @@ -229,7 +227,9 @@ pyproj==3.6.1 # pysheds # rioxarray pyproject-hooks==1.0.0 - # via build + # via + # build + # pip-tools pysheds==0.3.5 # via swmmanywhere (pyproject.toml) pyswmm==1.5.1 diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index 712e775a..ed8d511f 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -25,6 +25,7 @@ from pysheds import grid as pgrid from rasterio import features from scipy.interpolate import RegularGridInterpolator +from scipy.spatial import KDTree from shapely import geometry as sgeom from shapely import ops as sops from shapely.errors import GEOSException @@ -614,8 +615,8 @@ def remove_(mp): return remove_zero_area_subareas(mp, removed_subareas) return polys_gdf def derive_rc(polys_gdf: gpd.GeoDataFrame, - G: nx.Graph, - building_footprints: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + building_footprints: gpd.GeoDataFrame, + streetcover: gpd.GeoDataFrame) -> gpd.GeoDataFrame: """Derive the Runoff Coefficient (RC) of each subcatchment. The runoff coefficient is the ratio of impervious area to total area. The @@ -626,10 +627,10 @@ def derive_rc(polys_gdf: gpd.GeoDataFrame, Args: polys_gdf (gpd.GeoDataFrame): A GeoDataFrame containing polygons that represent subcatchments with columns: 'geometry', 'area', and 'id'. - G (nx.Graph): The input graph, with node 'ids' that match polys_gdf and - edges with the 'id', 'width' and 'geometry' property. building_footprints (gpd.GeoDataFrame): A GeoDataFrame containing building footprints with a 'geometry' column. + streetcover (gpd.GeoDataFrame): A GeoDataFrame containing street cover + with a 'geometry' column. Returns: gpd.GeoDataFrame: A GeoDataFrame containing polygons with columns: @@ -638,23 +639,7 @@ def derive_rc(polys_gdf: gpd.GeoDataFrame, polys_gdf = polys_gdf.copy() ## Format as swmm type catchments - - # TODO think harder about lane widths (am I double counting here?) - lines = [ - { - 'geometry': x['geometry'].buffer(x['width'], - cap_style=2, - join_style=2), - 'id': x['id'] - } - for u, v, x in G.edges(data=True) - ] - lines_df = pd.DataFrame(lines) - lines_gdf = gpd.GeoDataFrame(lines_df, - geometry=lines_df.geometry, - crs = polys_gdf.crs) - - result = gpd.overlay(lines_gdf[['geometry']], + result = gpd.overlay(streetcover[['geometry']], building_footprints[['geometry']], how='union') result = gpd.overlay(polys_gdf, result) @@ -786,4 +771,51 @@ def graph_to_geojson(graph: nx.Graph, } with fid.open('w') as output_file: - json.dump(geojson, output_file, indent=2) \ No newline at end of file + json.dump(geojson, output_file, indent=2) + +def merge_points(coordinates: list[tuple[float, float]], + threshold: float)-> dict: + """Merge points that are within a threshold distance. + + Args: + coordinates (list): List of coordinates as tuples. + threshold(float): The threshold distance for merging points. + + Returns: + dict: A dictionary mapping the original point index to the merged point + and new coordinate. + """ + # Create a KDTtree to pair together points within thresholds + tree = KDTree(coordinates) + pairs = tree.query_pairs(threshold) + + # Merge pairs into families of points that are all nearby + families: list = [] + + for pair in pairs: + matched_families = [family for family in families + if pair[0] in family or pair[1] in family] + + if matched_families: + # Merge all matched families and add the current pair + new_family = set(pair) + for family in matched_families: + new_family.update(family) + + # Remove the old families and add the newly formed one + for family in matched_families: + families.remove(family) + families.append(new_family) + else: + # No matching family found, so create a new one + families.append(set(pair)) + + # Create a mapping of the original point to the merged point + mapping = {} + for family in families: + average_point = np.mean([coordinates[i] for i in family], axis=0) + family_head = min(list(family)) + for i in family: + mapping[i] = {'maps_to' : family_head, + 'coordinate' : tuple(average_point)} + return mapping \ No newline at end of file diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index 9a121d0f..7b6e691c 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -18,7 +18,6 @@ import geopandas as gpd import networkx as nx import numpy as np -import osmnx as ox import pandas as pd import shapely from tqdm import tqdm @@ -184,7 +183,8 @@ def iterate_graphfcns(G: nx.Graph, G = graphfcns[function](G, addresses = addresses, **params) logger.info(f"graphfcn: {function} completed.") if verbose: - save_graph(G, addresses.model / f"{function}_graph.json") + save_graph(graphfcns.fix_geometries(G), + addresses.model / f"{function}_graph.json") return G @register_graphfcn @@ -266,36 +266,55 @@ def __call__(self, return G @register_graphfcn -class format_osmnx_lanes(BaseGraphFunction, - required_edge_attributes = ['lanes'], - adds_edge_attributes = ['width']): - """format_osmnx_lanes class.""" +class calculate_streetcover(BaseGraphFunction, + required_edge_attributes = ['lanes'] + ): + """calculate_streetcover class.""" # i.e., in osmnx format, i.e., empty for single lane, an int for a # number of lanes or a list if the edge has multiple carriageways def __call__(self, G: nx.Graph, - subcatchment_derivation: parameters.SubcatchmentDerivation, - **kwargs) -> nx.Graph: + subcatchment_derivation: parameters.SubcatchmentDerivation, + addresses: parameters.FilePaths, + **kwargs) -> nx.Graph: """Format the lanes attribute of each edge and calculates width. Args: G (nx.Graph): A graph subcatchment_derivation (parameters.SubcatchmentDerivation): A SubcatchmentDerivation parameter object + addresses (parameters.FilePaths): A FilePaths parameter object **kwargs: Additional keyword arguments are ignored. Returns: G (nx.Graph): A graph """ G = G.copy() + lines = [] for u, v, data in G.edges(data=True): lanes = data.get('lanes',1) if isinstance(lanes, list): lanes = sum([float(x) for x in lanes]) else: lanes = float(lanes) - data['width'] = lanes * subcatchment_derivation.lane_width + lines.append({'geometry' : data['geometry'].buffer(lanes * + subcatchment_derivation.lane_width, + cap_style=2, + join_style=2), + 'u' : u, + 'v' : v + } + ) + lines_df = pd.DataFrame(lines) + lines_gdf = gpd.GeoDataFrame(lines_df, + geometry=lines_df.geometry, + crs = G.graph['crs']) + if addresses.streetcover.suffix in ('.geoparquet','.parquet'): + lines_gdf.to_parquet(addresses.streetcover) + else: + lines_gdf.to_file(addresses.streetcover, driver='GeoJSON') + return G @register_graphfcn @@ -320,24 +339,54 @@ def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: Returns: G (nx.Graph): A graph """ - #TODO the geometry is left as is currently - should be reversed, however - # in original osmnx geometry there are some incorrectly directed ones - # someone with more patience might check start and end Points to check - # which direction the line should be going in... + # Convert to directed G_new = G.copy() + G_new = nx.MultiDiGraph(G.copy()) + + # MultiDiGraph adds edges in both directions, but rivers (and geometries) + # are only in one direction. So we remove the reverse edges and add them + # back in with the correct geometry. + # This assumes that 'id' is of format 'start-end' (see assign_id) + arcs_to_remove = [(u,v) for u,v,d in G_new.edges(data=True) + if f'{u}-{v}' != d.get('id')] + + # Remove the reverse edges + for u, v in arcs_to_remove: + G_new.remove_edge(u, v) + + # Add in reversed edges for streets only and with geometry for u, v, data in G.edges(data=True): include = data.get('edge_type', True) if isinstance(include, str): include = include == 'street' - if ((v, u) not in G.edges) & include: + if ((v, u) not in G_new.edges) & include: reverse_data = data.copy() reverse_data['id'] = f"{data['id']}.reversed" G_new.add_edge(v, u, **reverse_data) return G_new + +@register_graphfcn +class to_undirected(BaseGraphFunction): + """to_undirected class.""" + + def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: + """Convert the graph to an undirected graph. + + Args: + G (nx.Graph): A graph + **kwargs: Additional keyword arguments are ignored. + + Returns: + G (nx.Graph): An undirected graph + """ + # Don't use osmnx.to_undirected! It enables multigraph if the geometries + # are different, but we have already saved the street cover so don't + # want this! + return G.to_undirected() @register_graphfcn class split_long_edges(BaseGraphFunction, - required_edge_attributes = ['id', 'geometry', 'length']): + required_edge_attributes = ['id', 'geometry']): """split_long_edges class.""" def __call__(self, @@ -347,10 +396,9 @@ def __call__(self, """Split long edges into shorter edges. This function splits long edges into shorter edges. The edges are split - into segments of length 'max_street_length'. The first and last segment - are connected to the original nodes. Intermediate segments are connected - to newly created nodes. The 'geometry' of the original edge must be - a LineString. + into segments of length 'max_street_length'. The 'geometry' of the + original edge must be a LineString. Intended to follow up with call of + `merge_nodes`. Args: G (nx.Graph): A graph @@ -361,104 +409,93 @@ def __call__(self, Returns: graph (nx.Graph): A graph """ - #TODO refactor obviously max_length = subcatchment_derivation.max_street_length - graph = G.copy() - edges_to_remove = [] - edges_to_add = [] - nodes_to_add = [] - maxlabel = max(graph.nodes) + 1 - ll = 0 - - def create_new_edge_data(line, data, id_): - new_line = shapely.LineString(line) - new_data = data.copy() - new_data['id'] = id_ - new_data['length'] = new_line.length - new_data['geometry'] = shapely.LineString([(x[0], x[1]) - for x in new_line.coords]) - return new_data - - for u, v, data in graph.edges(data=True): - line = data['geometry'] - length = data['length'] - if ((u, v) not in edges_to_remove) & ((v, u) not in edges_to_remove): - if length > max_length: - new_points = [shapely.Point(x) - for x in ox.utils_geo.interpolate_points(line, - max_length)] - if len(new_points) > 2: - for ix, (start, end) in enumerate(zip(new_points[:-1], - new_points[1:])): - new_data = create_new_edge_data([start, - end], - data, - f"{data['id']}.{ix}") - if (v,u) in graph.edges: - # Create reversed data - data_r = graph.get_edge_data(v, u).copy()[0] - id_ = f"{data_r['id']}.{ix}" - new_data_r = create_new_edge_data([end, start], - data_r.copy(), - id_) - if ix == 0: - # Create start to first intermediate - edges_to_add.append((u, maxlabel + ll, new_data.copy())) - nodes_to_add.append((maxlabel + ll, - {'x': - new_data['geometry'].coords[-1][0], - 'y': - new_data['geometry'].coords[-1][1]})) - - if (v, u) in graph.edges: - # Create first intermediate to start - edges_to_add.append((maxlabel + ll, - u, - new_data_r.copy())) - - ll += 1 - elif ix == len(new_points) - 2: - # Create last intermediate to end - edges_to_add.append((maxlabel + ll - 1, - v, - new_data.copy())) - if (v, u) in graph.edges: - # Create end to last intermediate - edges_to_add.append((v, - maxlabel + ll - 1, - new_data_r.copy())) - else: - nodes_to_add.append((maxlabel + ll, - {'x': - new_data['geometry'].coords[-1][0], - 'y': - new_data['geometry'].coords[-1][1]})) - # Create N-1 intermediate to N intermediate - edges_to_add.append((maxlabel + ll - 1, - maxlabel + ll, - new_data.copy())) - if (v, u) in graph.edges: - # Create N intermediate to N-1 intermediate - edges_to_add.append((maxlabel + ll, - maxlabel + ll - 1, - new_data_r.copy())) - ll += 1 - edges_to_remove.append((u, v)) - if (v, u) in graph.edges: - edges_to_remove.append((v, u)) - - for u, v in edges_to_remove: - if (u, v) in graph.edges: - graph.remove_edge(u, v) - - for node in nodes_to_add: - graph.add_node(node[0], **node[1]) - - for edge in edges_to_add: - graph.add_edge(edge[0], edge[1], **edge[2]) - - return graph + # Split edges + new_linestrings = shapely.segmentize([d['geometry'] + for u,v,d in G.edges(data=True)], + max_length) + new_nodes = shapely.get_coordinates(new_linestrings) + + + new_edges = {} + for new_linestring, (u,v,d) in zip(new_linestrings, G.edges(data=True)): + # Create an arc for each segment + for start, end in zip(new_linestring.coords[:-1], + new_linestring.coords[1:]): + geom = shapely.LineString([start, end]) + new_edges[(start, end, 0)] = {**d, + 'length' : geom.length + } + + # Create new graph + new_graph = nx.MultiGraph() + new_graph.graph = G.graph.copy() + new_graph.add_edges_from(new_edges) + nx.set_edge_attributes(new_graph, new_edges) + nx.set_node_attributes( + new_graph, + {tuple(node): {'x': node[0], 'y': node[1]} for node in new_nodes} + ) + return nx.relabel_nodes(new_graph, + {node: ix for ix, node in enumerate(new_graph.nodes)} + ) + +@register_graphfcn +class merge_nodes(BaseGraphFunction): + """merge_nodes class.""" + def __call__(self, + G: nx.Graph, + subcatchment_derivation: parameters.SubcatchmentDerivation, + **kwargs) -> nx.Graph: + """Merge nodes that are close together. + + This function merges nodes that are within a certain distance of each + other. The distance is specified in the `node_merge_distance` attribute + of the `subcatchment_derivation` parameter. The merged nodes are given + the same coordinates, and the graph is relabeled with nx.relabel_nodes. + + Args: + G (nx.Graph): A graph + subcatchment_derivation (parameters.SubcatchmentDerivation): A + SubcatchmentDerivation parameter object + **kwargs: Additional keyword arguments are ignored. + + Returns: + G (nx.Graph): A graph + """ + G = G.copy() + + # Identify nodes that are within threshold of each other + mapping = go.merge_points([(d['x'], d['y']) for u,d in G.nodes(data=True)], + subcatchment_derivation.node_merge_distance) + + # Get indexes of node names + node_indices = {ix: node for ix, node in enumerate(G.nodes)} + + # Create a mapping of old node names to new node names + node_names = {} + for ix, node in enumerate(G.nodes): + if ix in mapping: + # If the node is in the mapping, then it is mapped and + # given the new coordinate (all nodes in a mapping family must + # be given the same coordinate because of how relabel_nodes + # works) + node_names[node] = node_indices[mapping[ix]['maps_to']] + G.nodes[node]['x'] = mapping[ix]['coordinate'][0] + G.nodes[node]['y'] = mapping[ix]['coordinate'][1] + else: + node_names[node] = node + + G = nx.relabel_nodes(G, node_names) + + # Relabelling will create selfloops within a mapping family, which + # are removed + self_loops = list(nx.selfloop_edges(G)) + G.remove_edges_from(self_loops) + + return G + @register_graphfcn class fix_geometries(BaseGraphFunction, required_edge_attributes = ['geometry'], @@ -479,10 +516,16 @@ def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: """ G = G.copy() for u, v, data in G.edges(data=True): - start_point_node = (G.nodes[u]['x'], G.nodes[u]['y']) - start_point_edge = data['geometry'].coords[0] + if not data.get('geometry', None): + start_point_edge = (None,None) + end_point_edge = (None,None) + else: + start_point_edge = data['geometry'].coords[0] + end_point_edge = data['geometry'].coords[-1] + + start_point_node = (G.nodes[u]['x'], G.nodes[u]['y']) end_point_node = (G.nodes[v]['x'], G.nodes[v]['y']) - end_point_edge = data['geometry'].coords[-1] + if (start_point_edge == end_point_node) & \ (end_point_edge == start_point_node): data['geometry'] = data['geometry'].reverse() @@ -494,7 +537,7 @@ def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: @register_graphfcn class calculate_contributing_area(BaseGraphFunction, - required_edge_attributes = ['id', 'geometry', 'width'], + required_edge_attributes = ['id', 'geometry'], adds_edge_attributes = ['contributing_area'], adds_node_attributes = ['contributing_area']): """calculate_contributing_area class.""" @@ -507,7 +550,10 @@ def __call__(self, G: nx.Graph, This function calculates the contributing area for each edge. The contributing area is the area of the subcatchment that drains to the - edge. The contributing area is calculated from the elevation data. + edge. The contributing area is calculated from the elevation data. + Runoff coefficient (RC) for each contributing area is also calculated, + the RC is calculated using `addresses.buildings` and + `addresses.streetcover`. Also writes the file 'subcatchments.geojson' to addresses.subcatchments. @@ -543,7 +589,12 @@ def __call__(self, G: nx.Graph, buildings = gpd.read_parquet(addresses.building) else: buildings = gpd.read_file(addresses.building) - subs_rc = go.derive_rc(subs_gdf, G, buildings) + if addresses.streetcover.suffix in ('.geoparquet','.parquet'): + streetcover = gpd.read_parquet(addresses.streetcover) + else: + streetcover = gpd.read_file(addresses.streetcover) + + subs_rc = go.derive_rc(subs_gdf, buildings, streetcover) # Write subs # TODO - could just attach subs to nodes where each node has a list of subs diff --git a/swmmanywhere/parameters.py b/swmmanywhere/parameters.py index 05bbf187..1bf37cf3 100644 --- a/swmmanywhere/parameters.py +++ b/swmmanywhere/parameters.py @@ -51,6 +51,12 @@ class SubcatchmentDerivation(BaseModel): unit = "m", description = "Distance to split streets into segments.") + node_merge_distance: float = Field(default = 10, + ge = 1, + le = 20, # This should probably be less than max_street_length + unit = 'm', + description = "Distance within which to merge street nodes.") + class OutletDerivation(BaseModel): """Parameters for outlet derivation.""" max_river_length: float = Field(default = 30.0, @@ -292,6 +298,9 @@ def _generate_elevation(self): def _generate_building(self): return self._generate_property(f'building.geo{self.extension}', 'download') + def _generate_streetcover(self): + return self._generate_property(f'streetcover.geo{self.extension}', + 'model') def _generate_precipitation(self): return self._generate_property(f'precipitation.{self.extension}', 'download') diff --git a/tests/test_data/demo_config.yml b/tests/test_data/demo_config.yml index ecbb15aa..a130b1e6 100644 --- a/tests/test_data/demo_config.yml +++ b/tests/test_data/demo_config.yml @@ -14,13 +14,16 @@ real: starting_graph: null graphfcn_list: - assign_id - - format_osmnx_lanes - - remove_non_pipe_allowable_links - - double_directed - fix_geometries + - remove_non_pipe_allowable_links + - calculate_streetcover + - to_undirected - split_long_edges + - merge_nodes + - assign_id - calculate_contributing_area - set_elevation + - double_directed - set_surface_slope - set_chahinian_slope - set_chahinian_angle @@ -28,6 +31,7 @@ graphfcn_list: - identify_outlets - derive_topology - pipe_by_pipe + - fix_geometries - assign_id metric_list: - nc_deltacon0 diff --git a/tests/test_geospatial_utilities.py b/tests/test_geospatial_utilities.py index b158a367..5d52decb 100644 --- a/tests/test_geospatial_utilities.py +++ b/tests/test_geospatial_utilities.py @@ -279,20 +279,16 @@ def test_derive_rc(): crs = crs) subs['area'] = subs.geometry.area - subs_rc = go.derive_rc(subs, G, buildings).set_index('id') + subs_rc = go.derive_rc(subs, buildings, buildings).set_index('id') assert subs_rc.loc[6277683849,'impervious_area'] == 0 assert subs_rc.loc[107733,'impervious_area'] > 0 - for u,v,d in G.edges(data=True): - d['width'] = 10 - subs_rc = go.derive_rc(subs, G, buildings).set_index('id') + buildings.geometry = buildings.buffer(50) + subs_rc = go.derive_rc(subs, buildings, buildings).set_index('id') assert subs_rc.loc[6277683849,'impervious_area'] > 0 assert subs_rc.loc[6277683849,'rc'] > 0 assert subs_rc.rc.max() <= 100 - for u,v,d in G.edges(data=True): - d['width'] = 0 - def test_calculate_angle(): """Test the calculate_angle function.""" # Test with points forming a right angle @@ -388,4 +384,15 @@ def test_graph_to_geojson(): assert gdf.shape[0] == len(G.nodes) gdf = gpd.read_file(temp_path / 'graph_edges.geojson') - assert gdf.shape[0] == len(G.edges) \ No newline at end of file + assert gdf.shape[0] == len(G.edges) + +def test_merge_points(): + """Test the merge_points function.""" + G = load_street_network() + mapping = go.merge_points([(d['x'], d['y']) for u,d in G.nodes(data=True)], + 20) + assert set(mapping.keys()) == set([2,3,5,15,16,18,22]) + assert set([x['maps_to'] for x in mapping.values()]) == set([2,5,15]) + assert mapping[15]['maps_to'] == 15 + assert mapping[18]['maps_to'] == 15 + assert almost_equal(mapping[18]['coordinate'][0], 700445.0112082) \ No newline at end of file diff --git a/tests/test_graph_utilities.py b/tests/test_graph_utilities.py index fe597056..62b0b159 100644 --- a/tests/test_graph_utilities.py +++ b/tests/test_graph_utilities.py @@ -52,14 +52,23 @@ def test_double_directed(): for u, v in G.edges(): assert (v,u) in G.edges -def test_format_osmnx_lanes(): - """Test the format_osmnx_lanes function.""" +def test_calculate_streetcover(): + """Test the calculate_streetcover function.""" G, _ = load_street_network() params = parameters.SubcatchmentDerivation() - G = gu.format_osmnx_lanes(G, params) - for u, v, data in G.edges(data=True): - assert 'width' in data.keys() - assert isinstance(data['width'], float) + addresses = parameters.FilePaths(base_dir = None, + project_name = None, + bbox_number = None, + model_number = None, + extension = 'json') + with tempfile.TemporaryDirectory() as temp_dir: + addresses.streetcover = Path(temp_dir) / 'streetcover.geojson' + _ = gu.calculate_streetcover(G, params, addresses) + # TODO test that G hasn't changed? or is that a waste of time? + assert addresses.streetcover.exists() + gdf = gpd.read_file(addresses.streetcover) + assert len(gdf) == len(G.edges) + assert gdf.geometry.area.sum() > 0 def test_split_long_edges(): """Test the split_long_edges function.""" @@ -82,6 +91,7 @@ def test_derive_subcatchments(): model_number = 1) addresses.elevation = Path(__file__).parent / 'test_data' / 'elevation.tif' addresses.building = temp_path / 'building.geojson' + addresses.streetcover = temp_path / 'building.geojson' addresses.subcatchments = temp_path / 'subcatchments.geojson' params = parameters.SubcatchmentDerivation() G, bbox = load_street_network() @@ -275,6 +285,7 @@ def test_iterate_graphfcns(): """Test the iterate_graphfcns function.""" G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json') params = parameters.get_full_parameters() + params['topology_derivation'].omit_edges = ['primary', 'bridge'] with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) addresses = parameters.FilePaths(base_dir = None, @@ -286,12 +297,13 @@ def test_iterate_graphfcns(): addresses.model = temp_path G = iterate_graphfcns(G, ['assign_id', - 'format_osmnx_lanes'], + 'remove_non_pipe_allowable_links'], params, addresses) for u, v, d in G.edges(data=True): assert 'id' in d.keys() - assert 'width' in d.keys() + assert 'primary' not in get_edge_types(G) + assert len(set([d.get('bridge',None) for u,v,d in G.edges(data=True)])) == 1 def test_fix_geometries(): """Test the fix_geometries function.""" @@ -307,4 +319,17 @@ def test_fix_geometries(): # Check that the edge geometry now matches the node coordinates assert G_fixed.get_edge_data(107733, 25472373,0)['geometry'].coords[0] == \ - (G_fixed.nodes[107733]['x'], G_fixed.nodes[107733]['y']) \ No newline at end of file + (G_fixed.nodes[107733]['x'], G_fixed.nodes[107733]['y']) + +def almost_equal(a, b, tol=1e-6): + """Check if two numbers are almost equal.""" + return abs(a-b) < tol + +def test_merge_nodes(): + """Test the merge_nodes function.""" + G, _ = load_street_network() + subcatchment_derivation = parameters.SubcatchmentDerivation( + node_merge_distance = 20) + G_ = gu.merge_nodes(G, subcatchment_derivation) + assert not set([107736,266325461,2623975694,32925453]).intersection(G_.nodes) + assert almost_equal(G_.nodes[25510321]['x'], 700445.0112082) \ No newline at end of file