Skip to content

Commit

Permalink
Energy distance num drift test (#400)
Browse files Browse the repository at this point in the history
* Added energy distance num drift test

* removed nans from test example

* isort changes
  • Loading branch information
AntonisCSt authored Oct 26, 2022
1 parent 947b88f commit f6811df
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/book/customization/options-for-statistical-tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,7 @@ example_stat_test = StatTest(
- only for numerical features
- returns `p_value`
- drift detected when `p_value < threshold`
- `ed` - Energy distance
- only for numerical features
- returns `distance`
- drift detected when `distance >= threshold`
1 change: 1 addition & 0 deletions src/evidently/calculations/stattests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .anderson_darling_stattest import anderson_darling_test
from .chisquare_stattest import chi_stat_test
from .cramer_von_mises_stattest import cramer_von_mises
from .energy_distance import energy_dist_test
from .fisher_exact_stattest import fisher_exact_test
from .g_stattest import g_test
from .jensenshannon import jensenshannon_stat_test
Expand Down
30 changes: 30 additions & 0 deletions src/evidently/calculations/stattests/energy_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Tuple

import pandas as pd
from scipy.stats import energy_distance

from evidently.calculations.stattests.registry import StatTest
from evidently.calculations.stattests.registry import register_stattest


def _energy_dist(
reference_data: pd.Series, current_data: pd.Series, feature_type: str, threshold: float
) -> Tuple[float, bool]:
"""Run the energy_distance test of two samples.
Args:
reference_data: reference data
current_data: current data
threshold: all values above this threshold propose a data drift
Returns:
distance: energy distance
test_result: whether the drift is detected
"""
distance = energy_distance(reference_data, current_data)
return distance, distance > threshold


energy_dist_test = StatTest(
name="ed", display_name="Energy-distance", func=_energy_dist, allowed_feature_types=["num"], default_threshold=0.1
)

register_stattest(energy_dist_test)
8 changes: 8 additions & 0 deletions tests/stattests/test_stattests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from evidently.calculations.stattests.anderson_darling_stattest import anderson_darling_test
from evidently.calculations.stattests.chisquare_stattest import chi_stat_test
from evidently.calculations.stattests.cramer_von_mises_stattest import cramer_von_mises
from evidently.calculations.stattests.energy_distance import energy_dist_test
from evidently.calculations.stattests.fisher_exact_stattest import fisher_exact_test
from evidently.calculations.stattests.g_stattest import g_test
from evidently.calculations.stattests.hellinger_distance import hellinger_stat_test
Expand Down Expand Up @@ -231,3 +232,10 @@ def test_mann_whitney() -> None:
reference = pd.Series([1, 2, 3, 4, 5, 6]).repeat([16, 18, 16, 14, 12, 12])
current = pd.Series([1, 2, 3, 4, 5, 6]).repeat([16, 16, 16, 16, 16, 8])
assert mann_whitney_u_stat_test.func(reference, current, "num", 0.05) == (approx(0.481, abs=1e-2), False)


def test_energy_distance() -> None:
reference = pd.Series([38.7, 41.5, 43.8, 44.5, 45.5, 46.0, 47.7, 58.0])
current = pd.Series([38.7, 41.5, 43.8, 44.5, 45.5, 46.0, 47.7, 58.0])
assert energy_dist_test.func(reference, current, "num", 0.1) == (approx(0, abs=1e-5), False)
assert energy_dist_test.func(reference, current + 5, "num", 0.1) == (approx(1.9, abs=1e-1), True)

0 comments on commit f6811df

Please sign in to comment.