-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
improved autofix strategy #148
base: main
Are you sure you want to change the base?
Changes from 4 commits
4615637
2a7cf91
e7a3d07
72fc919
d67bbc3
fc4bf7c
9f00909
d2a3432
6bcec4c
cc52ce2
62efa2d
e5c4872
02294c8
1d644a0
3ff2507
7235b40
1b99d60
330aa44
a19c88c
69ccda6
19143a3
e5b97f5
20a532c
3bbfc1c
b892e87
b54a0a7
f870e04
eb106d1
a7acfa6
1f0344d
692efe4
afbe4a9
7b96faa
b31674c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
import pathlib | ||
from typing import Any, Optional, TypeVar, Union | ||
from typing import Any, Optional, TypeVar, Union, List | ||
import math | ||
|
||
import numpy as np | ||
import copy | ||
|
||
import pandas as pd | ||
|
||
try: | ||
|
@@ -63,3 +64,130 @@ def check_none(x: Any) -> bool: | |
|
||
def check_not_none(x: Any) -> bool: | ||
return not check_none(x) | ||
|
||
|
||
def _get_autofix_default_params() -> dict: | ||
"""returns default percentage-wise params of autofix""" | ||
return { | ||
"ambiguous": 0.2, | ||
"label_issue": 0.5, | ||
"near_duplicate": 0.2, | ||
"outlier": 0.5, | ||
"confidence_threshold": 0.95, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to: "relabel_confidence_threshold" |
||
} | ||
|
||
|
||
def _get_autofix_defaults(cleanset_df: pd.DataFrame) -> dict: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: Studio team should move this function to backend of the app so it happens on server (eventually should be used in web app too) |
||
""" | ||
Generate default values for autofix parameters based on the size of the cleaned dataset. | ||
""" | ||
default_params = _get_autofix_default_params() | ||
default_values = {} | ||
|
||
for param_name, param_value in default_params.items(): | ||
if param_name != "confidence_threshold": | ||
num_rows = cleanset_df[f"is_{param_name}"].sum() | ||
default_values[f"drop_{param_name}"] = math.ceil(num_rows * param_value) | ||
else: | ||
default_values[f"drop_{param_name}"] = param_value | ||
return default_values | ||
|
||
|
||
def _get_top_fraction_ids( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: Studio team should move this function to backend of the app so it happens on server (eventually should be used in web app too) |
||
cleanset_df: pd.DataFrame, name_col: str, num_rows: int, asc=True | ||
) -> List[str]: | ||
""" | ||
Extracts the top specified number of rows based on a specified score column from a DataFrame. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will only return the IDs of datapoints to drop for a given setting of the num_rows to drop during autofix |
||
|
||
Parameters: | ||
- cleanset_df (pd.DataFrame): The input DataFrame containing the cleanset. | ||
- name_col (str): The name of the column indicating the category for which the top rows should be extracted. | ||
- num_rows (int): The number of rows to be extracted. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In autofix, we can simply multiply the fraction of issues that are the cleanset defaults by the number of datapoints to get this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. right when we spoke originally, we wanted this call to be similar to the Studio web interface call, hence I rewrote it this way, it was floating percentage before. |
||
- asc (bool, optional): If True, the rows are sorted in ascending order based on the score column; if False, in descending order. | ||
Default is True. | ||
|
||
Returns: | ||
- list: A list of row indices representing the top specified number of rows based on the specified score column. | ||
""" | ||
bool_column_name = f"is_{name_col}" | ||
|
||
# Construct a filter based on the 'label_issue' variable | ||
filter_condition = cleanset_df[bool_column_name] | ||
|
||
# Create a new DataFrame based on the filter | ||
filtered_df = cleanset_df[filter_condition] | ||
if name_col == "near_duplicate": | ||
# Group by the 'near_duplicate_cluster_ID' column | ||
df_n = filtered_df.sort_values(by="near_duplicate_score").reset_index(drop=True) | ||
sorted_df = df_n.head(num_rows) | ||
grouped_df = sorted_df.groupby("near_duplicate_cluster_id") | ||
|
||
# Initialize an empty list to store the aggregated indices | ||
aggregated_indices = [] | ||
|
||
# Iterate over each group | ||
for group_name, group_df in grouped_df: | ||
# Sort the group DataFrame by the 'near_duplicate_score' column in ascending order | ||
sorted_group_df = group_df.sort_values( | ||
by=f"{name_col}_score", ascending=asc | ||
).reset_index(drop=True) | ||
|
||
# Extract every other index and append to the aggregated indices list | ||
selected_indices = sorted_group_df.loc[::2, "cleanlab_row_ID"] | ||
aggregated_indices.extend(selected_indices) | ||
|
||
return aggregated_indices | ||
else: | ||
# Construct the boolean column name with 'is_' prefix and 'label_issue_score' suffix | ||
score_col_name = f"{name_col}_score" | ||
|
||
# Sort the filtered DataFrame by the constructed boolean column in descending order | ||
sorted_df = filtered_df.sort_values(by=score_col_name, ascending=asc) | ||
|
||
# Extract the top specified number of rows and return the 'cleanlab_row_ID' column | ||
top_rows_ids = sorted_df["cleanlab_row_ID"].head(num_rows) | ||
|
||
return top_rows_ids | ||
|
||
|
||
def _update_label_based_on_confidence(row, conf_threshold): | ||
"""Update the label and is_issue based on confidence threshold if there is a label issue. | ||
|
||
Args: | ||
row (pd.Series): The row containing label information. | ||
conf_threshold (float): The confidence threshold for updating the label. | ||
|
||
Returns: | ||
pd.Series: The updated row. | ||
""" | ||
if row["is_label_issue"] and row["suggested_label_confidence_score"] > conf_threshold: | ||
row["is_issue"] = False | ||
aditya1503 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
row["label"] = row["suggested_label"] | ||
return row | ||
|
||
|
||
def _apply_autofixed_cleanset_to_new_dataframe( | ||
original_df: pd.DataFrame, cleanset_df: pd.DataFrame, parameters: pd.DataFrame | ||
) -> pd.DataFrame: | ||
"""Apply a cleanset to update original dataaset labels and remove top rows based on specified parameters.""" | ||
original_df_copy = copy.deepcopy(original_df) | ||
original_columns = original_df_copy.columns | ||
merged_df = pd.merge(original_df_copy, cleanset_df, left_index=True, right_on="cleanlab_row_ID") | ||
|
||
merged_df = merged_df.apply( | ||
lambda row: _update_label_based_on_confidence( | ||
row, conf_threshold=parameters["drop_confidence_threshold"] | ||
), | ||
axis=1, | ||
) | ||
|
||
indices_to_drop = set() | ||
for drop_name, top_num in parameters.items(): | ||
column_name = drop_name.replace("drop_", "") | ||
if column_name == "confidence_threshold": | ||
continue | ||
top_percent_ids = _get_top_fraction_ids(merged_df, column_name, top_num, asc=False) | ||
indices_to_drop.update(top_percent_ids) | ||
|
||
merged_df = merged_df.drop(list(indices_to_drop), axis=0).reset_index(drop=True) | ||
return merged_df[original_columns] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can choose more specific key names here