generated from ImperialCollegeLondon/pip-tools-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
-Add raster interpolation and tests
- Loading branch information
Dobson
committed
Jan 19, 2024
1 parent
492bac3
commit e3e7e89
Showing
2 changed files
with
146 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Created 2024-01-20. | ||
@author: Barnaby Dobson | ||
""" | ||
import numpy as np | ||
import rasterio as rst | ||
from scipy.interpolate import RegularGridInterpolator | ||
|
||
|
||
def interp_wrap(xy: tuple[float,float], | ||
interp: RegularGridInterpolator, | ||
grid: np.ndarray, | ||
values: list[float]) -> float: | ||
"""Wrap the interpolation function to handle NaNs. | ||
Picks the nearest non NaN grid point if the interpolated value is NaN, | ||
otherwise returns the interpolated value. | ||
Args: | ||
xy (tuple): Coordinate of interest | ||
interp (RegularGridInterpolator): The interpolator object. | ||
grid (np.ndarray): List of xy coordinates of the grid points. | ||
values (list): The list of values at each point in the grid. | ||
Returns: | ||
float: The interpolated value. | ||
""" | ||
# Call the interpolator | ||
val = float(interp(xy)) | ||
# If the value is NaN, we need to pick nearest non nan grid point | ||
if np.isnan(val): | ||
# Get the distances to all grid points | ||
distances = np.linalg.norm(grid - xy, axis=1) | ||
# Get the indices of the grid points sorted by distance | ||
indices = np.argsort(distances) | ||
# Iterate over the grid points in order of increasing distance | ||
for index in indices: | ||
# If the value at this grid point is not NaN, return it | ||
if not np.isnan(values[index]): | ||
return values[index] | ||
else: | ||
return val | ||
|
||
raise ValueError("No non NaN values found in grid.") | ||
|
||
def interpolate_points_on_raster(x: list[float], | ||
y: list[float], | ||
elevation_fid: str) -> list[float ]: | ||
"""Interpolate points on a raster. | ||
Args: | ||
x (list): X coordinates. | ||
y (list): Y coordinates. | ||
elevation_fid (str): Filepath to elevation raster. | ||
Returns: | ||
elevation (float): Elevation at point. | ||
""" | ||
with rst.open(elevation_fid) as src: | ||
# Read the raster data | ||
data = src.read(1).astype(float) # Assuming it's a single-band raster | ||
data[data == src.nodata] = None | ||
|
||
# Get the raster's coordinates | ||
x = np.linspace(src.bounds.left, src.bounds.right, src.width) | ||
y = np.linspace(src.bounds.bottom, src.bounds.top, src.height) | ||
|
||
# Define grid | ||
xx, yy = np.meshgrid(x, y) | ||
grid = np.vstack([xx.ravel(), yy.ravel()]).T | ||
values = data.ravel() | ||
|
||
# Define interpolator | ||
interp = RegularGridInterpolator((y,x), | ||
np.flipud(data), | ||
method='linear', | ||
bounds_error=False, | ||
fill_value=None) | ||
# Interpolate for x,y | ||
return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Created on Tue Oct 18 10:35:51 2022. | ||
@author: Barney | ||
""" | ||
|
||
# import pytest | ||
from unittest.mock import MagicMock, patch | ||
|
||
import numpy as np | ||
from scipy.interpolate import RegularGridInterpolator | ||
|
||
from swmmanywhere import geospatial_operations as go | ||
|
||
|
||
def test_interp_wrap(): | ||
"""Test the interp_wrap function.""" | ||
# Define a simple grid and values | ||
x = np.linspace(0, 1, 5) | ||
y = np.linspace(0, 1, 5) | ||
xx, yy = np.meshgrid(x, y) | ||
grid = np.vstack([xx.ravel(), yy.ravel()]).T | ||
values = np.linspace(0, 1, 25) | ||
values_grid = values.reshape(5, 5) | ||
|
||
# Define an interpolator | ||
interp = RegularGridInterpolator((x,y), | ||
values_grid) | ||
|
||
# Test the function at a point inside the grid | ||
yx = (0.875, 0.875) | ||
result = go.interp_wrap(yx, interp, grid, values) | ||
assert result == 0.875 | ||
|
||
# Test the function on a nan point | ||
values_grid[1][1] = np.nan | ||
yx = (0.251, 0.25) | ||
result = go.interp_wrap(yx, interp, grid, values) | ||
assert result == values_grid[1][2] | ||
|
||
@patch('rasterio.open') | ||
def test_interpolate_points_on_raster(mock_rst_open): | ||
"""Test the interpolate_points_on_raster function.""" | ||
# Mock the raster file | ||
mock_src = MagicMock() | ||
mock_src.read.return_value = np.array([[1, 2], [3, 4]]) | ||
mock_src.bounds = MagicMock() | ||
mock_src.bounds.left = 0 | ||
mock_src.bounds.right = 1 | ||
mock_src.bounds.bottom = 0 | ||
mock_src.bounds.top = 1 | ||
mock_src.width = 2 | ||
mock_src.height = 2 | ||
mock_src.nodata = None | ||
mock_rst_open.return_value.__enter__.return_value = mock_src | ||
|
||
# Define the x and y coordinates | ||
x = [0.25, 0.75] | ||
y = [0.25, 0.75] | ||
|
||
# Call the function | ||
result = go.interpolate_points_on_raster(x, y, 'fake_path') | ||
|
||
# [3,2] feels unintuitive but it's because rasters measure from the top | ||
assert result == [3.0, 2.0] |