Skip to content

Commit

Permalink
Update geospatial
Browse files Browse the repository at this point in the history
-Add reproject raster and tests
-Add get_utm_epsg and tests
  • Loading branch information
Dobson committed Jan 19, 2024
1 parent e3e7e89 commit fd1d880
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
64 changes: 63 additions & 1 deletion swmmanywhere/geospatial_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,34 @@
@author: Barnaby Dobson
"""
from typing import Optional

import numpy as np
import rasterio as rst
from rasterio.warp import Resampling, calculate_default_transform, reproject
from scipy.interpolate import RegularGridInterpolator


def get_utm_epsg(lon: float, lat: float) -> str:
"""Get the formatted UTM EPSG code for a given longitude and latitude.
Args:
lon (float): Longitude in EPSG:4326 (x)
lat (float): Latitude in EPSG:4326 (y)
Returns:
str: Formatted EPSG code for the UTM zone.
Example:
>>> get_utm_epsg(-0.1276, 51.5074)
'EPSG:32630'
"""
# Determine the UTM zone number
zone_number = int((lon + 180) / 6) + 1
# Determine the UTM EPSG code based on the hemisphere
utm_epsg = 32600 + zone_number if lat >= 0 else 32700 + zone_number
return 'EPSG:{0}'.format(utm_epsg)

def interp_wrap(xy: tuple[float,float],
interp: RegularGridInterpolator,
grid: np.ndarray,
Expand Down Expand Up @@ -78,4 +101,43 @@ def interpolate_points_on_raster(x: list[float],
bounds_error=False,
fill_value=None)
# Interpolate for x,y
return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)]
return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)]

def reproject_raster(target_crs: str,
fid: str,
new_fid: Optional[str] = None):
"""Reproject a raster to a new CRS.
Args:
target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630).
fid (str): Filepath to the raster to reproject.
new_fid (str, optional): Filepath to save the reprojected raster.
Defaults to None, which will just use fid with '_reprojected'.
"""
with rst.open(fid) as src:
# Define the transformation parameters for reprojection
transform, width, height = calculate_default_transform(
src.crs, target_crs, src.width, src.height, *src.bounds)

# Create the output raster file
kwargs = src.meta.copy()
kwargs.update({
'crs': target_crs,
'transform': transform,
'width': width,
'height': height
})
if new_fid is None:
new_fid = fid.replace('.tif','_reprojected.tif')

with rst.open(new_fid, 'w', **kwargs) as dst:
# Reproject the data
reproject(
source=rst.band(src, 1),
destination=rst.band(dst, 1),
src_transform=src.transform,
src_crs=src.crs,
dst_transform=transform,
dst_crs=target_crs,
resampling=Resampling.bilinear
)
50 changes: 49 additions & 1 deletion tests/test_geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
"""

# import pytest
import os
from unittest.mock import MagicMock, patch

import numpy as np
import rasterio as rst
from scipy.interpolate import RegularGridInterpolator

from swmmanywhere import geospatial_operations as go
Expand Down Expand Up @@ -62,4 +64,50 @@ def test_interpolate_points_on_raster(mock_rst_open):
result = go.interpolate_points_on_raster(x, y, 'fake_path')

# [3,2] feels unintuitive but it's because rasters measure from the top
assert result == [3.0, 2.0]
assert result == [3.0, 2.0]

def test_get_utm():
"""Test the get_utm_epsg function."""
# Test a northern hemisphere point
crs = go.get_utm_epsg(-1, 51)
assert crs == 'EPSG:32630'

# Test a southern hemisphere point
crs = go.get_utm_epsg(-1, -51)
assert crs == 'EPSG:32730'


def test_reproject_raster():
"""Test the reproject_raster function."""
# Create a mock raster file
fid = 'test.tif'
data = np.random.randint(0, 255, (100, 100)).astype('uint8')
transform = rst.transform.from_origin(0, 0, 0.1, 0.1)
with rst.open(fid,
'w',
driver='GTiff',
height=100,
width=100,
count=1,
dtype='uint8',
crs='EPSG:4326',
transform=transform) as src:
src.write(data, 1)

# Define the input parameters
target_crs = 'EPSG:32630'
new_fid = 'test_reprojected.tif'

# Call the function
go.reproject_raster(target_crs, fid)

# Check if the reprojected file exists
assert os.path.exists(new_fid)

# Check if the reprojected file has the correct CRS
with rst.open(new_fid) as src:
assert src.crs.to_string() == target_crs

# Clean up the created files
os.remove(fid)
os.remove(new_fid)

0 comments on commit fd1d880

Please sign in to comment.