diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py new file mode 100644 index 00000000..bde8de0e --- /dev/null +++ b/swmmanywhere/geospatial_operations.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +"""Created 2024-01-20. + +@author: Barnaby Dobson +""" +import numpy as np +import rasterio as rst +from scipy.interpolate import RegularGridInterpolator + + +def interp_wrap(xy: tuple[float,float], + interp: RegularGridInterpolator, + grid: np.ndarray, + values: list[float]) -> float: + """Wrap the interpolation function to handle NaNs. + + Picks the nearest non NaN grid point if the interpolated value is NaN, + otherwise returns the interpolated value. + + Args: + xy (tuple): Coordinate of interest + interp (RegularGridInterpolator): The interpolator object. + grid (np.ndarray): List of xy coordinates of the grid points. + values (list): The list of values at each point in the grid. + + Returns: + float: The interpolated value. + """ + # Call the interpolator + val = float(interp(xy)) + # If the value is NaN, we need to pick nearest non nan grid point + if np.isnan(val): + # Get the distances to all grid points + distances = np.linalg.norm(grid - xy, axis=1) + # Get the indices of the grid points sorted by distance + indices = np.argsort(distances) + # Iterate over the grid points in order of increasing distance + for index in indices: + # If the value at this grid point is not NaN, return it + if not np.isnan(values[index]): + return values[index] + else: + return val + + raise ValueError("No non NaN values found in grid.") + +def interpolate_points_on_raster(x: list[float], + y: list[float], + elevation_fid: str) -> list[float ]: + """Interpolate points on a raster. + + Args: + x (list): X coordinates. + y (list): Y coordinates. + elevation_fid (str): Filepath to elevation raster. + + Returns: + elevation (float): Elevation at point. + """ + with rst.open(elevation_fid) as src: + # Read the raster data + data = src.read(1).astype(float) # Assuming it's a single-band raster + data[data == src.nodata] = None + + # Get the raster's coordinates + x = np.linspace(src.bounds.left, src.bounds.right, src.width) + y = np.linspace(src.bounds.bottom, src.bounds.top, src.height) + + # Define grid + xx, yy = np.meshgrid(x, y) + grid = np.vstack([xx.ravel(), yy.ravel()]).T + values = data.ravel() + + # Define interpolator + interp = RegularGridInterpolator((y,x), + np.flipud(data), + method='linear', + bounds_error=False, + fill_value=None) + # Interpolate for x,y + return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py new file mode 100644 index 00000000..1dade0ce --- /dev/null +++ b/tests/test_geospatial.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""Created on Tue Oct 18 10:35:51 2022. + +@author: Barney +""" + +# import pytest +from unittest.mock import MagicMock, patch + +import numpy as np +from scipy.interpolate import RegularGridInterpolator + +from swmmanywhere import geospatial_operations as go + + +def test_interp_wrap(): + """Test the interp_wrap function.""" + # Define a simple grid and values + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + xx, yy = np.meshgrid(x, y) + grid = np.vstack([xx.ravel(), yy.ravel()]).T + values = np.linspace(0, 1, 25) + values_grid = values.reshape(5, 5) + + # Define an interpolator + interp = RegularGridInterpolator((x,y), + values_grid) + + # Test the function at a point inside the grid + yx = (0.875, 0.875) + result = go.interp_wrap(yx, interp, grid, values) + assert result == 0.875 + + # Test the function on a nan point + values_grid[1][1] = np.nan + yx = (0.251, 0.25) + result = go.interp_wrap(yx, interp, grid, values) + assert result == values_grid[1][2] + +@patch('rasterio.open') +def test_interpolate_points_on_raster(mock_rst_open): + """Test the interpolate_points_on_raster function.""" + # Mock the raster file + mock_src = MagicMock() + mock_src.read.return_value = np.array([[1, 2], [3, 4]]) + mock_src.bounds = MagicMock() + mock_src.bounds.left = 0 + mock_src.bounds.right = 1 + mock_src.bounds.bottom = 0 + mock_src.bounds.top = 1 + mock_src.width = 2 + mock_src.height = 2 + mock_src.nodata = None + mock_rst_open.return_value.__enter__.return_value = mock_src + + # Define the x and y coordinates + x = [0.25, 0.75] + y = [0.25, 0.75] + + # Call the function + result = go.interpolate_points_on_raster(x, y, 'fake_path') + + # [3,2] feels unintuitive but it's because rasters measure from the top + assert result == [3.0, 2.0] \ No newline at end of file