Skip to content

Commit

Permalink
Fixed ndfilters.mean_filter() to support instances of `astropy.unit…
Browse files Browse the repository at this point in the history
…s.Quantity`.
  • Loading branch information
byrdie committed May 27, 2024
1 parent ecbe490 commit 916d47f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
7 changes: 4 additions & 3 deletions ndfilters/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _mean_filter_1d(
array=array,
size=size,
where=where,
out=np.empty_like(array),
)

result = result.reshape(shape)
Expand All @@ -105,9 +106,9 @@ def _mean_filter_1d_numba(
array: np.ndarray,
size: int,
where: np.ndarray,
out: np.ndarray,
) -> np.ndarray:

result = np.empty_like(array)
num_t, num_x = array.shape

halfsize = size // 2
Expand All @@ -132,6 +133,6 @@ def _mean_filter_1d_numba(
sum += array[t, k]
count += 1

result[t, i] = sum / count
out[t, i] = sum / count

return result
return out
11 changes: 8 additions & 3 deletions ndfilters/_tests/test_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import scipy.ndimage
import scipy.stats
import astropy.units as u
import ndfilters


Expand All @@ -10,7 +11,7 @@
argvalues=[
np.random.random(5),
np.random.random((5, 6)),
np.random.random((5, 6, 7)),
np.random.random((5, 6, 7)) * u.mm,
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -75,9 +76,13 @@ def test_mean_filter(
size=size_scipy,
mode="constant",
) / scipy.ndimage.uniform_filter(
input=np.ones_like(array),
input=np.ones(array.shape),
size=size_scipy,
mode="constant",
)

assert np.allclose(result, expected)
if isinstance(result, u.Quantity):
assert np.allclose(result.value, expected)
assert result.unit == array.unit
else:
assert np.allclose(result, expected)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dynamic = ["version"]
test = [
"pytest",
"scipy",
"astropy",
]
doc = [
"pytest",
Expand Down

0 comments on commit 916d47f

Please sign in to comment.