From 492f8f376457af625285f707de5e4560951570ad Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 17 Sep 2024 14:02:56 -0600 Subject: [PATCH] Added short description and installation --- README.md | 9 ++- ndfilters/__init__.py | 2 + ndfilters/_tests/test_variance.py | 96 +++++++++++++++++++++++++++++++ ndfilters/_variance.py | 79 +++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 ndfilters/_tests/test_variance.py create mode 100644 ndfilters/_variance.py diff --git a/README.md b/README.md index 1f5fba0..8a2f93c 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ pip install ndfilters The [mean filter](https://ndfilters.readthedocs.io/en/latest/_autosummary/ndfilters.mean_filter.html#ndfilters.mean_filter) calculates a multidimensional rolling mean for the given kernel shape. -![mean filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.mean_filter_0_2.png) +![mean filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.mean_filter_0_0.png) ### Trimmed mean filter @@ -33,3 +33,10 @@ The [trimmed mean filter](https://ndfilters.readthedocs.io/en/latest/_autosumma is like the mean filter except it ignores a given portion of the dataset before calculating the mean at each pixel. ![trimmed mean filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.trimmed_mean_filter_0_0.png) + +### Variance filter + +The [variance filter](https://ndfilters.readthedocs.io/en/latest/_autosummary/ndfilters.variance_filter.html#ndfilters.variance_filter) +calculates the rolling variance for the given kernel shape. + +![variance filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.variance_filter_0_0.png) diff --git a/ndfilters/__init__.py b/ndfilters/__init__.py index 941cc1d..4c9c6a9 100644 --- a/ndfilters/__init__.py +++ b/ndfilters/__init__.py @@ -5,9 +5,11 @@ from ._generic import generic_filter from ._mean import mean_filter from ._trimmed_mean import trimmed_mean_filter +from ._variance import variance_filter __all__ = [ "generic_filter", "mean_filter", "trimmed_mean_filter", + "variance_filter", ] diff --git a/ndfilters/_tests/test_variance.py b/ndfilters/_tests/test_variance.py new file mode 100644 index 0000000..35e2904 --- /dev/null +++ b/ndfilters/_tests/test_variance.py @@ -0,0 +1,96 @@ +from typing import Literal +import pytest +import numpy as np +import scipy.ndimage +import scipy.stats +import astropy.units as u +import ndfilters + + +@pytest.mark.parametrize( + argnames="array", + argvalues=[ + np.random.random(5), + np.random.random((5, 6)), + np.random.random((5, 6, 7)) * u.mm, + ], +) +@pytest.mark.parametrize( + argnames="size", + argvalues=[2, (3,), (3, 4), (3, 4, 5)], +) +@pytest.mark.parametrize( + argnames="axis", + argvalues=[ + None, + 0, + -1, + (0,), + (-1,), + (0, 1), + (-2, -1), + (0, 1, 2), + (2, 1, 0), + ], +) +@pytest.mark.parametrize( + argnames="mode", + argvalues=[ + "mirror", + "nearest", + "wrap", + ], +) +def test_variance_filter( + array: np.ndarray, + size: int | tuple[int, ...], + axis: None | int | tuple[int, ...], + mode: Literal["mirror", "nearest", "wrap", "truncate"], +): + kwargs = dict( + array=array, + size=size, + axis=axis, + mode=mode, + ) + + if axis is None: + axis_normalized = tuple(range(array.ndim)) + else: + try: + axis_normalized = np.core.numeric.normalize_axis_tuple( + axis, ndim=array.ndim + ) + except np.AxisError: + with pytest.raises(np.AxisError): + ndfilters.variance_filter(**kwargs) + return + + if isinstance(size, int): + size_normalized = (size,) * len(axis_normalized) + else: + size_normalized = size + + if len(size_normalized) != len(axis_normalized): + with pytest.raises(ValueError): + ndfilters.variance_filter(**kwargs) + return + + result = ndfilters.variance_filter(**kwargs) + + size_scipy = [1] * array.ndim + for i, ax in enumerate(axis_normalized): + size_scipy[ax] = size_normalized[i] + + expected = scipy.ndimage.generic_filter( + input=array, + function=np.var, + size=size_scipy, + mode=mode, + ) + + if isinstance(result, u.Quantity): + assert np.allclose(result.value, expected) + assert result.unit == array.unit + else: + assert np.allclose(result, expected) diff --git a/ndfilters/_variance.py b/ndfilters/_variance.py new file mode 100644 index 0000000..bef88c1 --- /dev/null +++ b/ndfilters/_variance.py @@ -0,0 +1,79 @@ +from typing import Literal +import numpy as np +import numba +import astropy.units as u +import ndfilters + +__all__ = [ + "variance_filter", +] + + +def variance_filter( + array: np.ndarray | u.Quantity, + size: int | tuple[int, ...], + axis: None | int | tuple[int, ...] = None, + where: bool | np.ndarray = True, + mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", +) -> np.ndarray: + """ + Calculate a multidimensional rolling variance. + + Parameters + ---------- + array + The input array to be filtered + size + The shape of the kernel over which the variance will be calculated. + axis + The axes over which to apply the kernel. + Should either be a scalar or have the same number of items as `size`. + If :obj:`None` (the default) the kernel spans every axis of the array. + where + An optional mask that can be used to exclude parts of the array during + filtering. + mode + The method used to extend the input array beyond its boundaries. + See :func:`scipy.ndimage.generic_filter` for the definitions. + Currently, only "mirror", "nearest", "wrap", and "truncate" modes are + supported. + + Returns + ------- + A copy of the array with the variance filter applied. + + Examples + -------- + + .. jupyter-execute:: + + import matplotlib.pyplot as plt + import scipy.datasets + import ndfilters + + img = scipy.datasets.ascent() + img_filtered = ndfilters.variance_filter(img, size=21) + + fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True) + axs[0].set_title("original image"); + axs[0].imshow(img, cmap="gray"); + axs[1].set_title("filtered image"); + axs[1].imshow(img_filtered, cmap="gray"); + + """ + return ndfilters.generic_filter( + array=array, + function=_variance, + size=size, + axis=axis, + where=where, + mode=mode, + ) + + +@numba.njit +def _variance( + array: np.ndarray, + args: tuple[float], +) -> float: + return np.var(array)