Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: initialize ShapPlotData with preprocessed feature names #381

Open
wants to merge 4 commits into
base: 2.1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/facet/inspection/_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

from pytools.api import AllTracker

from ..data import Sample

log = logging.getLogger(__name__)

__all__ = [
Expand Down Expand Up @@ -46,15 +44,22 @@ class ShapPlotData:
"""

def __init__(
self, shap_values: Union[FloatArray, List[FloatArray]], sample: Sample
self,
*,
shap_values: Union[FloatArray, List[FloatArray]],
features: pd.DataFrame,
target: Union[pd.Series, pd.DataFrame],
) -> None:
"""
:param shap_values: the shap values for all observations and outputs
:param sample: (sub)sample of all observations for which SHAP values are
available; aligned with param ``shap_values``
:param features: features for which SHAP values are available;
aligned with param ``shap_values``
:param target: target values for all observations;
aligned with param ``shap_values``
"""
self._shap_values = shap_values
self._sample = sample
self._features = features
self._target = target

@property
def shap_values(self) -> Union[FloatArray, List[FloatArray]]:
Expand All @@ -69,7 +74,7 @@ def features(self) -> pd.DataFrame:
"""
Matrix of feature values (number of observations by number of features).
"""
return self._sample.features
return self._features

@property
def target(self) -> Union[pd.Series, pd.DataFrame]:
Expand All @@ -78,7 +83,7 @@ def target(self) -> Union[pd.Series, pd.DataFrame]:
or matrix of target values for multi-output models
(number of observations by number of outputs).
"""
return self._sample.target
return self._target


__tracker.validate()
3 changes: 2 additions & 1 deletion src/facet/inspection/base/_model_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,8 @@ def shap_plot_data(self) -> ShapPlotData:

return ShapPlotData(
shap_values=shap_values_numpy,
sample=sample,
features=self.preprocess_features(sample.features),
target=sample.target,
)

@property
Expand Down
8 changes: 7 additions & 1 deletion test/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,13 @@ def fit_classifier_selector(
parameter_space = ParameterSpace(
ClassifierPipelineDF(
classifier=RandomForestClassifierDF(random_state=42),
preprocessing=None,
# this column transformer is a no-op, but we need it to
# run tests where preprocessing changes feature names
preprocessing=ColumnTransformerDF(
# we prefix all feature names with "pass__" except the last one
[("pass", "passthrough", sample.feature_names[:-1])],
remainder="passthrough",
),
)
)
parameter_space.classifier.n_estimators = [10, 50]
Expand Down
55 changes: 43 additions & 12 deletions test/test/facet/test_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@

T = TypeVar("T")

IRIS_FEATURE_NAMES_PREPROCESSED = [
"pass__sepal length (cm)",
"pass__sepal width (cm)",
"pass__petal length (cm)",
"remainder__petal width (cm)",
]


def test_regressor_selector(
regressor_selector: LearnerSelector[
Expand Down Expand Up @@ -343,7 +350,7 @@ def test_model_inspection_classifier_multi_class(

feature_importance: pd.DataFrame = iris_inspector_multi_class.feature_importance()
assert feature_importance.index.equals(
pd.Index(iris_sample.feature_names, name="feature")
pd.Index(IRIS_FEATURE_NAMES_PREPROCESSED, name="feature")
)
assert feature_importance.columns.equals(
pd.Index(iris_inspector_multi_class.output_names, name="class")
Expand Down Expand Up @@ -515,30 +522,46 @@ def test_model_inspection_classifier_interaction(
) -> None:
warnings.filterwarnings("ignore", message="You are accessing a training score")

assert (
iris_classifier_binary.preprocessing is not None
), "preprocessing step must be defined"

cls_inspector: Type[
Union[
LearnerInspector[RandomForestClassifierDF],
NativeLearnerInspector[RandomForestClassifier],
]
]
learner: Union[RandomForestClassifierDF, RandomForestClassifier]
classifier: Union[ClassifierPipelineDF[RandomForestClassifierDF], Pipeline]
if native:
cls_inspector = NativeLearnerInspector[RandomForestClassifier]
learner = iris_classifier_binary.final_estimator.native_estimator
# create a native pipeline from the classifier pipeline
classifier = Pipeline(
steps=[
(
"preprocessing",
iris_classifier_binary.preprocessing.native_estimator,
),
(
"classifier",
iris_classifier_binary.classifier.native_estimator,
),
]
)
else:
cls_inspector = LearnerInspector[RandomForestClassifierDF]
learner = iris_classifier_binary.final_estimator
classifier = iris_classifier_binary

model_inspector = cls_inspector(
model=learner,
model=classifier,
explainer_factory=TreeExplainerFactory(
feature_perturbation="tree_path_dependent", uses_background_dataset=True
),
n_jobs=n_jobs,
).fit(iris_sample_binary)

model_inspector_no_interaction = cls_inspector(
model=learner,
model=classifier,
shap_interaction=False,
explainer_factory=TreeExplainerFactory(
feature_perturbation="tree_path_dependent", uses_background_dataset=True
Expand All @@ -560,17 +583,16 @@ def test_model_inspection_classifier_interaction(
).abs().max().max() < 0.015

# the column names of the shap value data frames are the feature names
feature_columns = iris_sample_binary.feature_names
assert shap_values.columns.to_list() == feature_columns
assert shap_interaction_values.columns.to_list() == feature_columns
assert shap_values.columns.to_list() == IRIS_FEATURE_NAMES_PREPROCESSED
assert shap_interaction_values.columns.to_list() == IRIS_FEATURE_NAMES_PREPROCESSED

# the length of rows in shap_values should be equal to the number of observations
assert len(shap_values) == len(iris_sample_binary)

# the length of rows in shap_interaction_values should be equal to the number of
# observations, times the number of features
assert len(shap_interaction_values) == (
len(iris_sample_binary) * len(feature_columns)
len(iris_sample_binary) * len(IRIS_FEATURE_NAMES_PREPROCESSED)
)

# do the shap values add up to predictions minus a constant value?
Expand Down Expand Up @@ -840,11 +862,20 @@ def test_shap_plot_data(
assert all(shap.shape == features_shape for shap in shap_values)

shap_index = shap_plot_data.features.index
preprocessing = iris_inspector_multi_class.model.preprocessing
assert preprocessing is not None, "preprocessing step must be defined"

assert_frame_equal(
shap_plot_data.features, iris_sample_multi_class.features.loc[shap_index]
# the shap plot data should contain the same observations as the
# preprocessed features in the sample
shap_plot_data.features,
preprocessing.transform(iris_sample_multi_class.features).loc[shap_index],
)
assert_series_equal(
shap_plot_data.target, iris_sample_multi_class.target.loc[shap_index]
# the shap plot data should contain the same target values as the
# sample
shap_plot_data.target,
iris_sample_multi_class.target.loc[shap_index],
)


Expand Down