From 916d47fa30b39526163a18380ec371d71d878e64 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Mon, 27 May 2024 12:35:42 -0600 Subject: [PATCH] Fixed `ndfilters.mean_filter()` to support instances of `astropy.units.Quantity`. --- ndfilters/_mean.py | 7 ++++--- ndfilters/_tests/test_mean.py | 11 ++++++++--- pyproject.toml | 1 + 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/ndfilters/_mean.py b/ndfilters/_mean.py index 4f9c079..02bf59b 100644 --- a/ndfilters/_mean.py +++ b/ndfilters/_mean.py @@ -91,6 +91,7 @@ def _mean_filter_1d( array=array, size=size, where=where, + out=np.empty_like(array), ) result = result.reshape(shape) @@ -105,9 +106,9 @@ def _mean_filter_1d_numba( array: np.ndarray, size: int, where: np.ndarray, + out: np.ndarray, ) -> np.ndarray: - result = np.empty_like(array) num_t, num_x = array.shape halfsize = size // 2 @@ -132,6 +133,6 @@ def _mean_filter_1d_numba( sum += array[t, k] count += 1 - result[t, i] = sum / count + out[t, i] = sum / count - return result + return out diff --git a/ndfilters/_tests/test_mean.py b/ndfilters/_tests/test_mean.py index ac9197d..e13c64a 100644 --- a/ndfilters/_tests/test_mean.py +++ b/ndfilters/_tests/test_mean.py @@ -2,6 +2,7 @@ import numpy as np import scipy.ndimage import scipy.stats +import astropy.units as u import ndfilters @@ -10,7 +11,7 @@ argvalues=[ np.random.random(5), np.random.random((5, 6)), - np.random.random((5, 6, 7)), + np.random.random((5, 6, 7)) * u.mm, ], ) @pytest.mark.parametrize( @@ -75,9 +76,13 @@ def test_mean_filter( size=size_scipy, mode="constant", ) / scipy.ndimage.uniform_filter( - input=np.ones_like(array), + input=np.ones(array.shape), size=size_scipy, mode="constant", ) - assert np.allclose(result, expected) + if isinstance(result, u.Quantity): + assert np.allclose(result.value, expected) + assert result.unit == array.unit + else: + assert np.allclose(result, expected) diff --git a/pyproject.toml b/pyproject.toml index 15998be..f464ac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dynamic = ["version"] test = [ "pytest", "scipy", + "astropy", ] doc = [ "pytest",