diff --git a/pyproject.toml b/pyproject.toml index 041db53..7ab0a8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,6 @@ dependencies = [ "xarray", "flox", "scipy", - "sparse", - "opt-einsum", ] [tool.hatch.build] @@ -41,10 +39,15 @@ Issues = "https://github.com/EXCITED-CO2/xarray-regrid/issues" Source = "https://github.com/EXCITED-CO2/xarray-regrid" [project.optional-dependencies] +accel = [ + "sparse", + "opt-einsum", +] benchmarking = [ "dask[distributed]", "matplotlib", "zarr", + "h5netcdf", "requests", "aiohttp", ] @@ -71,7 +74,7 @@ docs = [ # Required for ReadTheDocs path = "src/xarray_regrid/__init__.py" [tool.hatch.envs.default] -features = ["dev", "benchmarking"] +features = ["accel", "dev", "benchmarking"] [tool.hatch.envs.default.scripts] lint = [ diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 30a96a4..8c384b6 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -5,7 +5,11 @@ import numpy as np import xarray as xr -from sparse import COO # type: ignore + +try: + import sparse # type: ignore +except ImportError: + sparse = None from xarray_regrid import utils @@ -126,7 +130,11 @@ def conservative_regrid_dataset( for array in data_vars.keys(): if coord in data_vars[array].dims: - var_weights = sparsify_weights(weights, data_vars[array]) + if sparse is not None: + var_weights = sparsify_weights(weights, data_vars[array]) + else: + var_weights = weights + data_vars[array], valid_fracs[array] = apply_weights( da=data_vars[array], weights=var_weights, @@ -200,8 +208,12 @@ def apply_weights( valid_frac = valid_frac.clip(0, 1) # In some cases, dot product of dask data and sparse weights fails - # to densify, which prevents future conversion to numpy - if da_reduced.chunks and isinstance(da_reduced.data._meta, COO): + # to automatically densify, which prevents future conversion to numpy + if ( + sparse is not None + and da_reduced.chunks + and isinstance(da_reduced.data._meta, sparse.COO) + ): da_reduced.data = da_reduced.data.map_blocks( lambda x: x.todense(), dtype=da_reduced.dtype ) @@ -268,8 +280,8 @@ def sparsify_weights(weights: xr.DataArray, da: xr.DataArray) -> xr.DataArray: new_weights = weights.copy().astype(da.dtype) if da.chunks: chunks = {k: v for k, v in da.chunksizes.items() if k in weights.dims} - new_weights.data = new_weights.chunk(chunks).data.map_blocks(COO) + new_weights.data = new_weights.chunk(chunks).data.map_blocks(sparse.COO) else: - new_weights.data = COO(weights.data) + new_weights.data = sparse.COO(weights.data) return new_weights diff --git a/tests/test_regrid.py b/tests/test_regrid.py index b720a3a..be9770a 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -208,7 +208,7 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold): @pytest.mark.skipif(xesmf is None, reason="xesmf required") def test_conservative_nan_thresholds_against_xesmf(): - ds = xr.tutorial.open_dataset("ersstv5").sst.compute().isel(time=[0]) + ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).compute() ds = ds.rename(lon="longitude", lat="latitude") new_grid = xarray_regrid.Grid( north=90,