Skip to content

Commit

Permalink
Start of geospatial analysis
Browse files Browse the repository at this point in the history
-Add raster interpolation and tests
  • Loading branch information
Dobson committed Jan 19, 2024
1 parent 492bac3 commit e3e7e89
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
81 changes: 81 additions & 0 deletions swmmanywhere/geospatial_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
"""Created 2024-01-20.
@author: Barnaby Dobson
"""
import numpy as np
import rasterio as rst
from scipy.interpolate import RegularGridInterpolator


def interp_wrap(xy: tuple[float,float],
interp: RegularGridInterpolator,
grid: np.ndarray,
values: list[float]) -> float:
"""Wrap the interpolation function to handle NaNs.
Picks the nearest non NaN grid point if the interpolated value is NaN,
otherwise returns the interpolated value.
Args:
xy (tuple): Coordinate of interest
interp (RegularGridInterpolator): The interpolator object.
grid (np.ndarray): List of xy coordinates of the grid points.
values (list): The list of values at each point in the grid.
Returns:
float: The interpolated value.
"""
# Call the interpolator
val = float(interp(xy))
# If the value is NaN, we need to pick nearest non nan grid point
if np.isnan(val):
# Get the distances to all grid points
distances = np.linalg.norm(grid - xy, axis=1)
# Get the indices of the grid points sorted by distance
indices = np.argsort(distances)
# Iterate over the grid points in order of increasing distance
for index in indices:
# If the value at this grid point is not NaN, return it
if not np.isnan(values[index]):
return values[index]
else:
return val

raise ValueError("No non NaN values found in grid.")

def interpolate_points_on_raster(x: list[float],
y: list[float],
elevation_fid: str) -> list[float ]:
"""Interpolate points on a raster.
Args:
x (list): X coordinates.
y (list): Y coordinates.
elevation_fid (str): Filepath to elevation raster.
Returns:
elevation (float): Elevation at point.
"""
with rst.open(elevation_fid) as src:
# Read the raster data
data = src.read(1).astype(float) # Assuming it's a single-band raster
data[data == src.nodata] = None

# Get the raster's coordinates
x = np.linspace(src.bounds.left, src.bounds.right, src.width)
y = np.linspace(src.bounds.bottom, src.bounds.top, src.height)

# Define grid
xx, yy = np.meshgrid(x, y)
grid = np.vstack([xx.ravel(), yy.ravel()]).T
values = data.ravel()

# Define interpolator
interp = RegularGridInterpolator((y,x),
np.flipud(data),
method='linear',
bounds_error=False,
fill_value=None)
# Interpolate for x,y
return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)]
65 changes: 65 additions & 0 deletions tests/test_geospatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""Created on Tue Oct 18 10:35:51 2022.
@author: Barney
"""

# import pytest
from unittest.mock import MagicMock, patch

import numpy as np
from scipy.interpolate import RegularGridInterpolator

from swmmanywhere import geospatial_operations as go


def test_interp_wrap():
"""Test the interp_wrap function."""
# Define a simple grid and values
x = np.linspace(0, 1, 5)
y = np.linspace(0, 1, 5)
xx, yy = np.meshgrid(x, y)
grid = np.vstack([xx.ravel(), yy.ravel()]).T
values = np.linspace(0, 1, 25)
values_grid = values.reshape(5, 5)

# Define an interpolator
interp = RegularGridInterpolator((x,y),
values_grid)

# Test the function at a point inside the grid
yx = (0.875, 0.875)
result = go.interp_wrap(yx, interp, grid, values)
assert result == 0.875

# Test the function on a nan point
values_grid[1][1] = np.nan
yx = (0.251, 0.25)
result = go.interp_wrap(yx, interp, grid, values)
assert result == values_grid[1][2]

@patch('rasterio.open')
def test_interpolate_points_on_raster(mock_rst_open):
"""Test the interpolate_points_on_raster function."""
# Mock the raster file
mock_src = MagicMock()
mock_src.read.return_value = np.array([[1, 2], [3, 4]])
mock_src.bounds = MagicMock()
mock_src.bounds.left = 0
mock_src.bounds.right = 1
mock_src.bounds.bottom = 0
mock_src.bounds.top = 1
mock_src.width = 2
mock_src.height = 2
mock_src.nodata = None
mock_rst_open.return_value.__enter__.return_value = mock_src

# Define the x and y coordinates
x = [0.25, 0.75]
y = [0.25, 0.75]

# Call the function
result = go.interpolate_points_on_raster(x, y, 'fake_path')

# [3,2] feels unintuitive but it's because rasters measure from the top
assert result == [3.0, 2.0]

0 comments on commit e3e7e89

Please sign in to comment.