From fd1d8804111db0c9d27cb4a23cd12d70751fae16 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 15:42:18 +0000 Subject: [PATCH] Update geospatial -Add reproject raster and tests -Add get_utm_epsg and tests --- swmmanywhere/geospatial_operations.py | 64 ++++++++++++++++++++++++++- tests/test_geospatial.py | 50 ++++++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index bde8de0e..9f095a75 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -3,11 +3,34 @@ @author: Barnaby Dobson """ +from typing import Optional + import numpy as np import rasterio as rst +from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator +def get_utm_epsg(lon: float, lat: float) -> str: + """Get the formatted UTM EPSG code for a given longitude and latitude. + + Args: + lon (float): Longitude in EPSG:4326 (x) + lat (float): Latitude in EPSG:4326 (y) + + Returns: + str: Formatted EPSG code for the UTM zone. + + Example: + >>> get_utm_epsg(-0.1276, 51.5074) + 'EPSG:32630' + """ + # Determine the UTM zone number + zone_number = int((lon + 180) / 6) + 1 + # Determine the UTM EPSG code based on the hemisphere + utm_epsg = 32600 + zone_number if lat >= 0 else 32700 + zone_number + return 'EPSG:{0}'.format(utm_epsg) + def interp_wrap(xy: tuple[float,float], interp: RegularGridInterpolator, grid: np.ndarray, @@ -78,4 +101,43 @@ def interpolate_points_on_raster(x: list[float], bounds_error=False, fill_value=None) # Interpolate for x,y - return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] \ No newline at end of file + return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] + +def reproject_raster(target_crs: str, + fid: str, + new_fid: Optional[str] = None): + """Reproject a raster to a new CRS. + + Args: + target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). + fid (str): Filepath to the raster to reproject. + new_fid (str, optional): Filepath to save the reprojected raster. + Defaults to None, which will just use fid with '_reprojected'. + """ + with rst.open(fid) as src: + # Define the transformation parameters for reprojection + transform, width, height = calculate_default_transform( + src.crs, target_crs, src.width, src.height, *src.bounds) + + # Create the output raster file + kwargs = src.meta.copy() + kwargs.update({ + 'crs': target_crs, + 'transform': transform, + 'width': width, + 'height': height + }) + if new_fid is None: + new_fid = fid.replace('.tif','_reprojected.tif') + + with rst.open(new_fid, 'w', **kwargs) as dst: + # Reproject the data + reproject( + source=rst.band(src, 1), + destination=rst.band(dst, 1), + src_transform=src.transform, + src_crs=src.crs, + dst_transform=transform, + dst_crs=target_crs, + resampling=Resampling.bilinear + ) \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index 1dade0ce..76859a7d 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -5,9 +5,11 @@ """ # import pytest +import os from unittest.mock import MagicMock, patch import numpy as np +import rasterio as rst from scipy.interpolate import RegularGridInterpolator from swmmanywhere import geospatial_operations as go @@ -62,4 +64,50 @@ def test_interpolate_points_on_raster(mock_rst_open): result = go.interpolate_points_on_raster(x, y, 'fake_path') # [3,2] feels unintuitive but it's because rasters measure from the top - assert result == [3.0, 2.0] \ No newline at end of file + assert result == [3.0, 2.0] + +def test_get_utm(): + """Test the get_utm_epsg function.""" + # Test a northern hemisphere point + crs = go.get_utm_epsg(-1, 51) + assert crs == 'EPSG:32630' + + # Test a southern hemisphere point + crs = go.get_utm_epsg(-1, -51) + assert crs == 'EPSG:32730' + + +def test_reproject_raster(): + """Test the reproject_raster function.""" + # Create a mock raster file + fid = 'test.tif' + data = np.random.randint(0, 255, (100, 100)).astype('uint8') + transform = rst.transform.from_origin(0, 0, 0.1, 0.1) + with rst.open(fid, + 'w', + driver='GTiff', + height=100, + width=100, + count=1, + dtype='uint8', + crs='EPSG:4326', + transform=transform) as src: + src.write(data, 1) + + # Define the input parameters + target_crs = 'EPSG:32630' + new_fid = 'test_reprojected.tif' + + # Call the function + go.reproject_raster(target_crs, fid) + + # Check if the reprojected file exists + assert os.path.exists(new_fid) + + # Check if the reprojected file has the correct CRS + with rst.open(new_fid) as src: + assert src.crs.to_string() == target_crs + + # Clean up the created files + os.remove(fid) + os.remove(new_fid) \ No newline at end of file