Skip to content

Commit

Permalink
Score card add fairness for classification notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
kicha0 committed Nov 2, 2022
1 parent 345364c commit 3364848
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
" return X, y\n",
"\n",
"target_feature = 'Sold_HigherThan_Median'\n",
"categorical_features = []\n",
"categorical_features = [\"OverallQual\", \"OverallCond\"]\n",
"\n",
"all_data = pd.read_csv(data_path)\n",
"all_data = all_data.drop(['SalePrice','SalePriceK'], axis=1)\n",
Expand Down Expand Up @@ -660,6 +660,11 @@
" \"threshold\": \">=0.7\"\n",
" },\n",
" \"precision_score\": {}\n",
" },\n",
" \"Fairness\": {\n",
" \"metric\": [\"accuracy_score\"],\n",
" \"sensitive_features\": [\"OverallQual\", \"OverallCond\"],\n",
" \"fairness_evaluation_kind\": \"ratio\"\n",
" }\n",
"}\n",
"\n",
Expand Down Expand Up @@ -730,19 +735,19 @@
" explain_job.set_limits(timeout=120)\n",
" \n",
" # Add causal analysis\n",
" causal_job = rai_causal_component(\n",
" rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n",
" treatment_features=treatment_features,\n",
" )\n",
" causal_job.set_limits(timeout=120)\n",
" # causal_job = rai_causal_component(\n",
" # rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n",
" # treatment_features=treatment_features,\n",
" # )\n",
" # causal_job.set_limits(timeout=120)\n",
" \n",
" # Add counterfactual analysis\n",
" counterfactual_job = rai_counterfactual_component(\n",
" rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n",
" total_cfs=10,\n",
" desired_class='opposite',\n",
" )\n",
" counterfactual_job.set_limits(timeout=600)\n",
" # counterfactual_job = rai_counterfactual_component(\n",
" # rai_insights_dashboard=create_rai_job.outputs.rai_insights_dashboard,\n",
" # total_cfs=10,\n",
" # desired_class='opposite',\n",
" # )\n",
" # counterfactual_job.set_limits(timeout=600)\n",
" \n",
" # Add error analysis\n",
" erroranalysis_job = rai_erroranalysis_component(\n",
Expand All @@ -754,8 +759,8 @@
" rai_gather_job = rai_gather_component(\n",
" constructor=create_rai_job.outputs.rai_insights_dashboard,\n",
" insight_1=explain_job.outputs.explanation,\n",
" insight_2=causal_job.outputs.causal,\n",
" insight_3=counterfactual_job.outputs.counterfactual,\n",
" # insight_2=causal_job.outputs.causal,\n",
" # insight_3=counterfactual_job.outputs.counterfactual,\n",
" insight_4=erroranalysis_job.outputs.error_analysis,\n",
" )\n",
" rai_gather_job.set_limits(timeout=120)\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,11 @@ def get_fairness_bar_plot(data):
]
x_data = [
100 * (get_metric(
"selection_rate",
data[c]["y_test"],
data[c]["y_pred"]),
data[c]["pos_label"]
)
metric="selection_rate",
y_test=data[c]["y_test"],
y_pred=data[c]["y_pred"],
pos_label=data[c]["pos_label"]
))
for c in data
]
x_data = [[x, 100 - x] for x in x_data]
Expand Down

0 comments on commit 3364848

Please sign in to comment.