From 33648488484052307f43749f8de2cee1f7cbb51f Mon Sep 17 00:00:00 2001 From: Kin Chan Date: Wed, 2 Nov 2022 16:53:09 -0700 Subject: [PATCH] Score card add fairness for classification notebook --- ...using-classification-model-debugging.ipynb | 33 +++++++++++-------- .../_score_card/classification_components.py | 10 +++--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/examples/notebooks/responsibleaidashboard-housing-classification-model-debugging.ipynb b/examples/notebooks/responsibleaidashboard-housing-classification-model-debugging.ipynb index 7fefb941..d8377310 100644 --- a/examples/notebooks/responsibleaidashboard-housing-classification-model-debugging.ipynb +++ b/examples/notebooks/responsibleaidashboard-housing-classification-model-debugging.ipynb @@ -134,7 +134,7 @@ " return X, y\n", "\n", "target_feature = 'Sold_HigherThan_Median'\n", - "categorical_features = []\n", + "categorical_features = [\"OverallQual\", \"OverallCond\"]\n", "\n", "all_data = pd.read_csv(data_path)\n", "all_data = all_data.drop(['SalePrice','SalePriceK'], axis=1)\n", @@ -660,6 +660,11 @@ " \"threshold\": \">=0.7\"\n", " },\n", " \"precision_score\": {}\n", + " },\n", + " \"Fairness\": {\n", + " \"metric\": [\"accuracy_score\"],\n", + " \"sensitive_features\": [\"OverallQual\", \"OverallCond\"],\n", + " \"fairness_evaluation_kind\": \"ratio\"\n", " }\n", "}\n", "\n", @@ -730,19 +735,19 @@ " explain_job.set_limits(timeout=120)\n", " \n", " # Add causal analysis\n", - " causal_job = rai_causal_component(\n", - " rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n", - " treatment_features=treatment_features,\n", - " )\n", - " causal_job.set_limits(timeout=120)\n", + " # causal_job = rai_causal_component(\n", + " # rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n", + " # treatment_features=treatment_features,\n", + " # )\n", + " # causal_job.set_limits(timeout=120)\n", " \n", " # Add counterfactual analysis\n", - " counterfactual_job = rai_counterfactual_component(\n", - " rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n", - " total_cfs=10,\n", - " desired_class='opposite',\n", - " )\n", - " counterfactual_job.set_limits(timeout=600)\n", + " # counterfactual_job = rai_counterfactual_component(\n", + " # rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n", + " # total_cfs=10,\n", + " # desired_class='opposite',\n", + " # )\n", + " # counterfactual_job.set_limits(timeout=600)\n", " \n", " # Add error analysis\n", " erroranalysis_job = rai_erroranalysis_component(\n", @@ -754,8 +759,8 @@ " rai_gather_job = rai_gather_component(\n", " constructor=create_rai_job.outputs.rai_insights_dashboard,\n", " insight_1=explain_job.outputs.explanation,\n", - " insight_2=causal_job.outputs.causal,\n", - " insight_3=counterfactual_job.outputs.counterfactual,\n", + " # insight_2=causal_job.outputs.causal,\n", + " # insight_3=counterfactual_job.outputs.counterfactual,\n", " insight_4=erroranalysis_job.outputs.error_analysis,\n", " )\n", " rai_gather_job.set_limits(timeout=120)\n", diff --git a/src/responsibleai/rai_analyse/_score_card/classification_components.py b/src/responsibleai/rai_analyse/_score_card/classification_components.py index bac84ddd..0907f12c 100644 --- a/src/responsibleai/rai_analyse/_score_card/classification_components.py +++ b/src/responsibleai/rai_analyse/_score_card/classification_components.py @@ -357,11 +357,11 @@ def get_fairness_bar_plot(data): ] x_data = [ 100 * (get_metric( - "selection_rate", - data[c]["y_test"], - data[c]["y_pred"]), - data[c]["pos_label"] - ) + metric="selection_rate", + y_test=data[c]["y_test"], + y_pred=data[c]["y_pred"], + pos_label=data[c]["pos_label"] + )) for c in data ] x_data = [[x, 100 - x] for x in x_data]