Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified ndfilters.mean_filter() to use ndfilters.generic_filter(). #18

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 27 additions & 86 deletions ndfilters/_mean.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Literal
import numpy as np
import numba
import astropy.units as u
import ndfilters

__all__ = [
"mean_filter",
]


def mean_filter(
array: np.ndarray,
array: np.ndarray | u.Quantity,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
) -> np.ndarray:
"""
Calculate a multidimensional rolling mean.
The kernel is truncated at the edges of the array.

Parameters
----------
Expand All @@ -23,14 +26,21 @@ def mean_filter(
size
The shape of the kernel over which the mean will be calculated.
axis
The axes over which to apply the kernel. If :obj:`None` the kernel
is applied to every axis.
The axes over which to apply the kernel.
Should either be a scalar or have the same number of items as `size`.
If :obj:`None` (the default) the kernel spans every axis of the array.
where
A boolean mask used to select which elements of the input array to filter.
An optional mask that can be used to exclude parts of the array during
filtering.
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.

Returns
-------
A copy of the array with a mean filter applied.
A copy of the array with the mean filter applied.

Examples
--------
Expand All @@ -47,92 +57,23 @@ def mean_filter(
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
axs[0].set_title("original image");
axs[0].imshow(img, cmap="gray");
axs[1].set_title("mean filtered image");
axs[1].set_title("filtered image");
axs[1].imshow(img_filtered, cmap="gray");
"""
array, where = np.broadcast_arrays(array, where, subok=True)

if axis is None:
axis = tuple(range(array.ndim))
else:
axis = np.core.numeric.normalize_axis_tuple(axis=axis, ndim=array.ndim)

if isinstance(size, int):
size = (size,) * len(axis)

result = array
for sz, ax in zip(size, axis, strict=True):
result = _mean_filter_1d(
array=result,
size=sz,
axis=ax,
where=where,
)

return result


def _mean_filter_1d(
array: np.ndarray,
size: int,
axis: int,
where: np.ndarray,
) -> np.ndarray:

array = np.moveaxis(array, axis, ~0)
where = np.moveaxis(where, axis, ~0)

shape = array.shape

array = array.reshape(-1, shape[~0])
where = where.reshape(-1, shape[~0])

result = _mean_filter_1d_numba(
"""
return ndfilters.generic_filter(
array=array,
function=_mean,
size=size,
axis=axis,
where=where,
out=np.empty_like(array),
mode=mode,
)

result = result.reshape(shape)

result = np.moveaxis(result, ~0, axis)

return result


@numba.njit(parallel=True, cache=True)
def _mean_filter_1d_numba(
@numba.njit
def _mean(
array: np.ndarray,
size: int,
where: np.ndarray,
out: np.ndarray,
) -> np.ndarray:

num_t, num_x = array.shape

halfsize = size // 2

for t in numba.prange(num_t):

for i in range(num_x):

sum = 0
count = 0
for j in range(size):

j2 = j - halfsize

k = i + j2
if k < 0:
continue
elif k >= num_x:
continue

if where[t, k]:
sum += array[t, k]
count += 1

out[t, i] = sum / count

return out
args: tuple[float],
) -> float:
return np.mean(array)
17 changes: 12 additions & 5 deletions ndfilters/_tests/test_mean.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
import pytest
import numpy as np
import scipy.ndimage
Expand Down Expand Up @@ -32,15 +33,25 @@
(2, 1, 0),
],
)
@pytest.mark.parametrize(
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
],
)
def test_mean_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
mode: Literal["mirror", "nearest", "wrap", "truncate"],
):
kwargs = dict(
array=array,
size=size,
axis=axis,
mode=mode,
)

if axis is None:
Expand Down Expand Up @@ -74,11 +85,7 @@ def test_mean_filter(
expected = scipy.ndimage.uniform_filter(
input=array,
size=size_scipy,
mode="constant",
) / scipy.ndimage.uniform_filter(
input=np.ones(array.shape),
size=size_scipy,
mode="constant",
mode=mode,
)

if isinstance(result, u.Quantity):
Expand Down
Loading