Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ndfilters.mean_filter() function. #8

Merged
merged 5 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ndfilters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ._mean import mean_filter
from ._trimmed_mean import trimmed_mean_filter

__all__ = [
"mean_filter",
"trimmed_mean_filter",
]
137 changes: 137 additions & 0 deletions ndfilters/_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import numpy as np
import numba

__all__ = [
"mean_filter",
]


def mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
) -> np.ndarray:
"""
Calculate a multidimensional rolling mean.
The kernel is truncated at the edges of the array.

Parameters
----------
array
The input array to be filtered
size
The shape of the kernel over which the trimmed mean will be calculated.
axis
The axes over which to apply the kernel. If :class:`None` the kernel
is applied to every axis.
where
A boolean mask used to select which elements of the input array to filter.

Returns
-------
A copy of the array with a mean filter applied.

Examples
--------

.. jupyter-execute::

import matplotlib.pyplot as plt
import scipy.datasets
import ndfilters

img = scipy.datasets.ascent()
img_filtered = ndfilters.mean_filter(img, size=21)

fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
axs[0].set_title("original image");
axs[0].imshow(img, cmap="gray");
axs[1].set_title("mean filtered image");
axs[1].imshow(img_filtered, cmap="gray");
"""
array, where = np.broadcast_arrays(array, where, subok=True)

if axis is None:
axis = tuple(range(array.ndim))
else:
axis = np.core.numeric.normalize_axis_tuple(axis=axis, ndim=array.ndim)

if isinstance(size, int):
size = (size,) * len(axis)

result = array
for sz, ax in zip(size, axis, strict=True):
result = _mean_filter_1d(
array=result,
size=sz,
axis=ax,
where=where,
)

return result


def _mean_filter_1d(
array: np.ndarray,
size: int,
axis: int,
where: np.ndarray,
) -> np.ndarray:

array = np.moveaxis(array, axis, ~0)
where = np.moveaxis(where, axis, ~0)

shape = array.shape

array = array.reshape(-1, shape[~0])
where = where.reshape(-1, shape[~0])

result = _mean_filter_1d_numba(
array=array,
size=size,
where=where,
)

result = result.reshape(shape)

result = np.moveaxis(result, ~0, axis)

return result


@numba.njit(parallel=True)
def _mean_filter_1d_numba(
array: np.ndarray,
size: int,
where: np.ndarray,
) -> np.ndarray:

result = np.empty_like(array)
num_t, num_x = array.shape

halfsize = size // 2

for t in numba.prange(num_t):

for i in range(num_x):

sum = 0
count = 0
for j in range(size):

j2 = j - halfsize

k = i + j2
if k < 0:
continue
elif k >= num_x:
continue

if where[t, k]:
sum += array[t, k]
count += 1

result[t, i] = sum / count

return result
83 changes: 83 additions & 0 deletions ndfilters/_tests/test_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest
import numpy as np
import scipy.ndimage
import scipy.stats
import ndfilters


@pytest.mark.parametrize(
argnames="array",
argvalues=[
np.random.random(5),
np.random.random((5, 6)),
np.random.random((5, 6, 7)),
],
)
@pytest.mark.parametrize(
argnames="size",
argvalues=[2, (3,), (3, 4), (3, 4, 5)],
)
@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
0,
-1,
(0,),
(-1,),
(0, 1),
(-2, -1),
(0, 1, 2),
(2, 1, 0),
],
)
def test_mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
):
kwargs = dict(
array=array,
size=size,
axis=axis,
)

if axis is None:
axis_normalized = tuple(range(array.ndim))
else:
try:
axis_normalized = np.core.numeric.normalize_axis_tuple(
axis, ndim=array.ndim
)
except np.AxisError:
with pytest.raises(np.AxisError):
ndfilters.mean_filter(**kwargs)
return

if isinstance(size, int):
size_normalized = (size,) * len(axis_normalized)
else:
size_normalized = size

if len(size_normalized) != len(axis_normalized):
with pytest.raises(ValueError):
ndfilters.mean_filter(**kwargs)
return

result = ndfilters.mean_filter(**kwargs)

size_scipy = [1] * array.ndim
for i, ax in enumerate(axis_normalized):
size_scipy[ax] = size_normalized[i]

expected = scipy.ndimage.uniform_filter(
input=array,
size=size_scipy,
mode="constant",
) / scipy.ndimage.uniform_filter(
input=np.ones_like(array),
size=size_scipy,
mode="constant",
)

assert np.allclose(result, expected)
Loading