Skip to content

Commit

Permalink
Added ndfilters.mean_filter() function. (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie authored May 27, 2024
1 parent 96d63d8 commit e39a644
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ndfilters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ._mean import mean_filter
from ._trimmed_mean import trimmed_mean_filter

__all__ = [
"mean_filter",
"trimmed_mean_filter",
]
137 changes: 137 additions & 0 deletions ndfilters/_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import numpy as np
import numba

__all__ = [
"mean_filter",
]


def mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
) -> np.ndarray:
"""
Calculate a multidimensional rolling mean.
The kernel is truncated at the edges of the array.
Parameters
----------
array
The input array to be filtered
size
The shape of the kernel over which the trimmed mean will be calculated.
axis
The axes over which to apply the kernel. If :class:`None` the kernel
is applied to every axis.
where
A boolean mask used to select which elements of the input array to filter.
Returns
-------
A copy of the array with a mean filter applied.
Examples
--------
.. jupyter-execute::
import matplotlib.pyplot as plt
import scipy.datasets
import ndfilters
img = scipy.datasets.ascent()
img_filtered = ndfilters.mean_filter(img, size=21)
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
axs[0].set_title("original image");
axs[0].imshow(img, cmap="gray");
axs[1].set_title("mean filtered image");
axs[1].imshow(img_filtered, cmap="gray");
"""
array, where = np.broadcast_arrays(array, where, subok=True)

if axis is None:
axis = tuple(range(array.ndim))
else:
axis = np.core.numeric.normalize_axis_tuple(axis=axis, ndim=array.ndim)

if isinstance(size, int):
size = (size,) * len(axis)

result = array
for sz, ax in zip(size, axis, strict=True):
result = _mean_filter_1d(
array=result,
size=sz,
axis=ax,
where=where,
)

return result


def _mean_filter_1d(
array: np.ndarray,
size: int,
axis: int,
where: np.ndarray,
) -> np.ndarray:

array = np.moveaxis(array, axis, ~0)
where = np.moveaxis(where, axis, ~0)

shape = array.shape

array = array.reshape(-1, shape[~0])
where = where.reshape(-1, shape[~0])

result = _mean_filter_1d_numba(
array=array,
size=size,
where=where,
)

result = result.reshape(shape)

result = np.moveaxis(result, ~0, axis)

return result


@numba.njit(parallel=True)
def _mean_filter_1d_numba(
array: np.ndarray,
size: int,
where: np.ndarray,
) -> np.ndarray:

result = np.empty_like(array)
num_t, num_x = array.shape

halfsize = size // 2

for t in numba.prange(num_t):

for i in range(num_x):

sum = 0
count = 0
for j in range(size):

j2 = j - halfsize

k = i + j2
if k < 0:
continue
elif k >= num_x:
continue

if where[t, k]:
sum += array[t, k]
count += 1

result[t, i] = sum / count

return result
83 changes: 83 additions & 0 deletions ndfilters/_tests/test_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest
import numpy as np
import scipy.ndimage
import scipy.stats
import ndfilters


@pytest.mark.parametrize(
argnames="array",
argvalues=[
np.random.random(5),
np.random.random((5, 6)),
np.random.random((5, 6, 7)),
],
)
@pytest.mark.parametrize(
argnames="size",
argvalues=[2, (3,), (3, 4), (3, 4, 5)],
)
@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
0,
-1,
(0,),
(-1,),
(0, 1),
(-2, -1),
(0, 1, 2),
(2, 1, 0),
],
)
def test_mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
):
kwargs = dict(
array=array,
size=size,
axis=axis,
)

if axis is None:
axis_normalized = tuple(range(array.ndim))
else:
try:
axis_normalized = np.core.numeric.normalize_axis_tuple(
axis, ndim=array.ndim
)
except np.AxisError:
with pytest.raises(np.AxisError):
ndfilters.mean_filter(**kwargs)
return

if isinstance(size, int):
size_normalized = (size,) * len(axis_normalized)
else:
size_normalized = size

if len(size_normalized) != len(axis_normalized):
with pytest.raises(ValueError):
ndfilters.mean_filter(**kwargs)
return

result = ndfilters.mean_filter(**kwargs)

size_scipy = [1] * array.ndim
for i, ax in enumerate(axis_normalized):
size_scipy[ax] = size_normalized[i]

expected = scipy.ndimage.uniform_filter(
input=array,
size=size_scipy,
mode="constant",
) / scipy.ndimage.uniform_filter(
input=np.ones_like(array),
size=size_scipy,
mode="constant",
)

assert np.allclose(result, expected)

0 comments on commit e39a644

Please sign in to comment.