Skip to content

Commit

Permalink
major updates
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Oct 21, 2024
1 parent 3fde74e commit c04ef57
Show file tree
Hide file tree
Showing 12 changed files with 6,683 additions and 1,144 deletions.
4 changes: 2 additions & 2 deletions mlscorecheck/auc/_acc_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def max_acc_lower_from(*, scores: dict, eps: float, p: int, n: int, lower: str =
else:
raise ValueError(f"unsupported lower bound {lower}")

return lower0
return lower0, 1


def max_acc_upper_from(*, scores: dict, eps: float, p: int, n: int, upper: str = "min"):
Expand Down Expand Up @@ -369,7 +369,7 @@ def max_acc_upper_from(*, scores: dict, eps: float, p: int, n: int, upper: str =
else:
raise ValueError(f"unsupported upper bound {upper}")

return upper0
return upper0, 1


def max_acc_from(
Expand Down
81 changes: 68 additions & 13 deletions mlscorecheck/auc/_auc_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@
translate_scores,
prepare_intervals,
)
from ._auc_single import auc_maxa, auc_armin
from ._auc_single import (
auc_maxa,
auc_armin,
auc_min_grad,
auc_rmin_grad,
auc_max_grad,
auc_maxa_grad,
auc_min_profile,
auc_rmin_profile,
auc_max_profile,
auc_maxa_profile,
)

__all__ = [
"auc_min_aggregated",
Expand Down Expand Up @@ -821,6 +832,7 @@ def auc_lower_from_aggregated(
ns: np.array = None,
folding: dict = None,
lower: str = "min",
correction: str = None
):
"""
This function applies the lower bound estimation schemes to estimate
Expand Down Expand Up @@ -862,20 +874,34 @@ def auc_lower_from_aggregated(

check_applicability_lower_aggregated(intervals, lower, ps, ns)

corr = 1.0

if lower == "min":
lower0 = auc_min_aggregated(intervals["fpr"][1], intervals["tpr"][0], k)
elif lower == "onmin":
lower0 = auc_onmin_aggregated(intervals["fpr"][1], intervals["tpr"][0], k)
if correction == 'gradient':
corr = auc_min_grad(intervals["fpr"][1], intervals["tpr"][0])
elif correction == 'profile':
corr = auc_min_profile(intervals["fpr"][1], intervals["tpr"][0])
elif lower == "rmin":
lower0 = auc_rmin_aggregated(intervals["fpr"][0], intervals["tpr"][1], k)
if correction == 'gradient':
corr = auc_rmin_grad(intervals["fpr"][1], intervals["tpr"][0])
elif correction == 'profile':
corr = auc_rmin_profile(intervals["fpr"][1], intervals["tpr"][0])
elif lower == "onmin":
lower0 = auc_onmin_aggregated(intervals["fpr"][1], intervals["tpr"][0], k)
if correction == 'gradient':
corr = auc_onmin_grad(intervals["fpr"][1], intervals["tpr"][0])
elif correction == 'profile':
corr = auc_onmin_profile(intervals["fpr"][1], intervals["tpr"][0])
elif lower == "amin":
lower0 = auc_amin_aggregated(intervals["acc"][0], ps, ns)
elif lower == "armin":
lower0 = auc_armin_aggregated(intervals["acc"][0], ps, ns)
else:
raise ValueError(f"unsupported lower bound {lower}")

return lower0
return lower0, corr


def auc_upper_from_aggregated(
Expand All @@ -886,7 +912,8 @@ def auc_upper_from_aggregated(
ps: np.array = None,
ns: np.array = None,
folding: dict = None,
upper: str = "min",
upper: str = "max",
correction: str = None
):
"""
This function applies the upper bound estimation schemes to estimate
Expand Down Expand Up @@ -928,16 +955,26 @@ def auc_upper_from_aggregated(

check_applicability_upper_aggregated(intervals, upper, ps, ns)

corr = 1.0

if upper == "max":
upper0 = auc_max_aggregated(intervals["fpr"][0], intervals["tpr"][1], k)
if correction == 'gradient':
corr = auc_max_grad(intervals["fpr"][0], intervals["tpr"][1])
elif correction == 'profile':
corr = auc_max_profile(intervals["fpr"][0], intervals["tpr"][1])
elif upper == "amax":
upper0 = auc_amax_aggregated(intervals["acc"][1], ps, ns)
elif upper == "maxa":
upper0 = auc_maxa_aggregated(intervals["acc"][1], ps, ns)
if correction == 'gradient':
corr = auc_maxa_grad(intervals["acc"][1], p, n)
elif correction == 'profile':
corr = auc_maxa_profile(intervals["acc"][1], p, n)
else:
raise ValueError(f"unsupported upper bound {upper}")

return upper0
return upper0, corr


def auc_from_aggregated(
Expand All @@ -950,6 +987,7 @@ def auc_from_aggregated(
folding: dict = None,
lower: str = "min",
upper: str = "max",
correction: str = None
) -> tuple:
"""
This function applies the estimation schemes to estimate AUC from scores
Expand Down Expand Up @@ -977,12 +1015,29 @@ def auc_from_aggregated(
infeasible, or not enough data is provided for the estimation method
"""

lower0 = auc_lower_from_aggregated(
scores=scores, eps=eps, k=k, ps=ps, ns=ns, folding=folding, lower=lower
)
try:
lower0, corr_lower = auc_lower_from_aggregated(
scores=scores, eps=eps, k=k, ps=ps, ns=ns, folding=folding, lower=lower, correction=correction
)

upper0 = auc_upper_from_aggregated(
scores=scores, eps=eps, k=k, ps=ps, ns=ns, folding=folding, upper=upper
)
upper0, corr_upper = auc_upper_from_aggregated(
scores=scores, eps=eps, k=k, ps=ps, ns=ns, folding=folding, upper=upper, correction=correction
)

if corr_lower == 1.0 and corr_upper == 1.0:
return (lower0, upper0)

print(corr_lower, corr_upper)

corr_lower = corr_lower + 0.01
corr_upper = corr_upper + 0.01

corr_sum = corr_lower + corr_upper
corr_lower = corr_lower / corr_sum
corr_upper = corr_upper / corr_sum

midpoint = lower0 * corr_upper + upper0 * corr_lower

return (lower0, upper0)
return (midpoint, midpoint)
except:
return np.nan, np.nan
Loading

0 comments on commit c04ef57

Please sign in to comment.