diff --git a/src/facet/inspection/_inspection.py b/src/facet/inspection/_inspection.py index 6b8bea5d..0a9b927e 100644 --- a/src/facet/inspection/_inspection.py +++ b/src/facet/inspection/_inspection.py @@ -11,8 +11,6 @@ from pytools.api import AllTracker -from ..data import Sample - log = logging.getLogger(__name__) __all__ = [ @@ -46,15 +44,22 @@ class ShapPlotData: """ def __init__( - self, shap_values: Union[FloatArray, List[FloatArray]], sample: Sample + self, + *, + shap_values: Union[FloatArray, List[FloatArray]], + features: pd.DataFrame, + target: Union[pd.Series, pd.DataFrame], ) -> None: """ :param shap_values: the shap values for all observations and outputs - :param sample: (sub)sample of all observations for which SHAP values are - available; aligned with param ``shap_values`` + :param features: features for which SHAP values are available; + aligned with param ``shap_values`` + :param target: target values for all observations; + aligned with param ``shap_values`` """ self._shap_values = shap_values - self._sample = sample + self._features = features + self._target = target @property def shap_values(self) -> Union[FloatArray, List[FloatArray]]: @@ -69,7 +74,7 @@ def features(self) -> pd.DataFrame: """ Matrix of feature values (number of observations by number of features). """ - return self._sample.features + return self._features @property def target(self) -> Union[pd.Series, pd.DataFrame]: @@ -78,7 +83,7 @@ def target(self) -> Union[pd.Series, pd.DataFrame]: or matrix of target values for multi-output models (number of observations by number of outputs). """ - return self._sample.target + return self._target __tracker.validate() diff --git a/src/facet/inspection/base/_model_inspector.py b/src/facet/inspection/base/_model_inspector.py index 86f3b210..ce663c9f 100644 --- a/src/facet/inspection/base/_model_inspector.py +++ b/src/facet/inspection/base/_model_inspector.py @@ -736,7 +736,8 @@ def shap_plot_data(self) -> ShapPlotData: return ShapPlotData( shap_values=shap_values_numpy, - sample=sample, + features=self.preprocess_features(sample.features), + target=sample.target, ) @property diff --git a/test/test/conftest.py b/test/test/conftest.py index cbeb1826..387301b5 100644 --- a/test/test/conftest.py +++ b/test/test/conftest.py @@ -453,7 +453,13 @@ def fit_classifier_selector( parameter_space = ParameterSpace( ClassifierPipelineDF( classifier=RandomForestClassifierDF(random_state=42), - preprocessing=None, + # this column transformer is a no-op, but we need it to + # run tests where preprocessing changes feature names + preprocessing=ColumnTransformerDF( + # we prefix all feature names with "pass__" except the last one + [("pass", "passthrough", sample.feature_names[:-1])], + remainder="passthrough", + ), ) ) parameter_space.classifier.n_estimators = [10, 50] diff --git a/test/test/facet/test_inspection.py b/test/test/facet/test_inspection.py index f59a8eb6..93bf6a22 100644 --- a/test/test/facet/test_inspection.py +++ b/test/test/facet/test_inspection.py @@ -45,6 +45,13 @@ T = TypeVar("T") +IRIS_FEATURE_NAMES_PREPROCESSED = [ + "pass__sepal length (cm)", + "pass__sepal width (cm)", + "pass__petal length (cm)", + "remainder__petal width (cm)", +] + def test_regressor_selector( regressor_selector: LearnerSelector[ @@ -343,7 +350,7 @@ def test_model_inspection_classifier_multi_class( feature_importance: pd.DataFrame = iris_inspector_multi_class.feature_importance() assert feature_importance.index.equals( - pd.Index(iris_sample.feature_names, name="feature") + pd.Index(IRIS_FEATURE_NAMES_PREPROCESSED, name="feature") ) assert feature_importance.columns.equals( pd.Index(iris_inspector_multi_class.output_names, name="class") @@ -515,22 +522,38 @@ def test_model_inspection_classifier_interaction( ) -> None: warnings.filterwarnings("ignore", message="You are accessing a training score") + assert ( + iris_classifier_binary.preprocessing is not None + ), "preprocessing step must be defined" + cls_inspector: Type[ Union[ LearnerInspector[RandomForestClassifierDF], NativeLearnerInspector[RandomForestClassifier], ] ] - learner: Union[RandomForestClassifierDF, RandomForestClassifier] + classifier: Union[ClassifierPipelineDF[RandomForestClassifierDF], Pipeline] if native: cls_inspector = NativeLearnerInspector[RandomForestClassifier] - learner = iris_classifier_binary.final_estimator.native_estimator + # create a native pipeline from the classifier pipeline + classifier = Pipeline( + steps=[ + ( + "preprocessing", + iris_classifier_binary.preprocessing.native_estimator, + ), + ( + "classifier", + iris_classifier_binary.classifier.native_estimator, + ), + ] + ) else: cls_inspector = LearnerInspector[RandomForestClassifierDF] - learner = iris_classifier_binary.final_estimator + classifier = iris_classifier_binary model_inspector = cls_inspector( - model=learner, + model=classifier, explainer_factory=TreeExplainerFactory( feature_perturbation="tree_path_dependent", uses_background_dataset=True ), @@ -538,7 +561,7 @@ def test_model_inspection_classifier_interaction( ).fit(iris_sample_binary) model_inspector_no_interaction = cls_inspector( - model=learner, + model=classifier, shap_interaction=False, explainer_factory=TreeExplainerFactory( feature_perturbation="tree_path_dependent", uses_background_dataset=True @@ -560,9 +583,8 @@ def test_model_inspection_classifier_interaction( ).abs().max().max() < 0.015 # the column names of the shap value data frames are the feature names - feature_columns = iris_sample_binary.feature_names - assert shap_values.columns.to_list() == feature_columns - assert shap_interaction_values.columns.to_list() == feature_columns + assert shap_values.columns.to_list() == IRIS_FEATURE_NAMES_PREPROCESSED + assert shap_interaction_values.columns.to_list() == IRIS_FEATURE_NAMES_PREPROCESSED # the length of rows in shap_values should be equal to the number of observations assert len(shap_values) == len(iris_sample_binary) @@ -570,7 +592,7 @@ def test_model_inspection_classifier_interaction( # the length of rows in shap_interaction_values should be equal to the number of # observations, times the number of features assert len(shap_interaction_values) == ( - len(iris_sample_binary) * len(feature_columns) + len(iris_sample_binary) * len(IRIS_FEATURE_NAMES_PREPROCESSED) ) # do the shap values add up to predictions minus a constant value? @@ -840,11 +862,20 @@ def test_shap_plot_data( assert all(shap.shape == features_shape for shap in shap_values) shap_index = shap_plot_data.features.index + preprocessing = iris_inspector_multi_class.model.preprocessing + assert preprocessing is not None, "preprocessing step must be defined" + assert_frame_equal( - shap_plot_data.features, iris_sample_multi_class.features.loc[shap_index] + # the shap plot data should contain the same observations as the + # preprocessed features in the sample + shap_plot_data.features, + preprocessing.transform(iris_sample_multi_class.features).loc[shap_index], ) assert_series_equal( - shap_plot_data.target, iris_sample_multi_class.target.loc[shap_index] + # the shap plot data should contain the same target values as the + # sample + shap_plot_data.target, + iris_sample_multi_class.target.loc[shap_index], )