Skip to content

Commit

Permalink
Fixed excluding workers in studies with no qualifications
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-paul committed Nov 2, 2023
1 parent 31374ca commit b527961
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ def _base_request(
else:
result = response.json()

logger.debug(f"{log_prefix} Response: {result}")

return result

except ProlificException:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def remove_participants_from_group(
https://docs.prolific.co/docs/api-docs/public/#tag/
Participant-Groups/paths/~1api~1v1~1participant-groups~1%7Bid%7D~1participants~1/delete
"""
from mephisto.utils.logger_core import get_logger
logger = get_logger(name=__name__)
endpoint = cls.list_participants_for_group_api_endpoint.format(id=id)
params = dict(participant_ids=participant_ids)
response_json = cls.delete(endpoint, params=params)
Expand Down
20 changes: 5 additions & 15 deletions mephisto/abstractions/providers/prolific/prolific_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,6 @@ def new_from_provider_data(
agent._unit = unit
task_run: "TaskRun" = agent.get_task_run()

# In case provider API wasn't responsive, we ensure this submission
# doesn't exceed per-worker cap for this Task. Otherwise don't process submission.
if not worker.can_send_more_submissions_for_task(task_run):
logger.info(
f'Submission from worker "{worker.db_id}" is over the Task\'s submission cap.'
)
try:
worker.exclude_worker_from_task(task_run)
except Exception:
logger.exception(
f"Failed to exclude worker {worker.db_id} in TaskRun {task_run.db_id}."
)
return agent

prolific_study_id = provider_data["prolific_study_id"]
prolific_submission_id = provider_data["assignment_id"]
unit.register_from_provider_data(prolific_study_id, prolific_submission_id)
Expand All @@ -119,6 +105,8 @@ def new_from_provider_data(

# Check whether we need to prevent this worker from future submissions in this Task
if not worker.can_send_more_submissions_for_task(task_run):
# Excluding worker from Participant Group (instead of adding to Block List)
# only because Prolific cannot update Block List for an in-progress Study
try:
worker.exclude_worker_from_task(task_run)
except Exception:
Expand Down Expand Up @@ -267,7 +255,6 @@ def get_status(self) -> str:
if prolific_submission_id:
prolific_submission = prolific_utils.get_submission(client, prolific_submission_id)
else:
# TODO: Not sure about this
self.update_status(AgentState.STATUS_EXPIRED)
return self.db_status

Expand All @@ -277,6 +264,9 @@ def get_status(self) -> str:

if prolific_submission.status == SubmissionStatus.RESERVED:
provider_status = local_status
elif prolific_submission.status == SubmissionStatus.ACTIVE:
# We don't need to map this status in our DB
pass
else:
provider_status = SUBMISSION_STATUS_TO_AGENT_STATE_MAP.get(
prolific_submission.status,
Expand Down
10 changes: 6 additions & 4 deletions mephisto/abstractions/providers/prolific/prolific_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def get_blocked_workers(self) -> List[dict]:
results = c.fetchall()
return results

def get_bloked_participant_ids(self) -> List[str]:
def get_blocked_participant_ids(self) -> List[str]:
return [w["worker_id"] for w in self.get_blocked_workers()]

def ensure_unit_exists(self, unit_id: str) -> None:
Expand Down Expand Up @@ -629,7 +629,7 @@ def find_qualifications_by_ids(
task_run_ids: Optional[List[str]] = None,
) -> List[dict]:
"""Find qualifications by Mephisto ids of qualifications and task runs"""
if not qualification_ids:
if not (qualification_ids or task_run_ids):
return []

with self.table_access_condition, self._get_connection() as conn:
Expand All @@ -645,12 +645,14 @@ def find_qualifications_by_ids(
task_run_ids_block = ""
if task_run_ids:
task_run_ids_str = ",".join([f'"{tid}"' for tid in task_run_ids])
task_run_ids_block = f"AND task_run_id IN ({task_run_ids_str})"
task_run_ids_block = f"task_run_id IN ({task_run_ids_str})"

where_block = " AND ".join(filter(bool, [qualification_ids_block, task_run_ids_block]))

c.execute(
f"""
SELECT * FROM qualifications
WHERE {qualification_ids_block} {task_run_ids_block};
WHERE {where_block};
"""
)
results = c.fetchall()
Expand Down
42 changes: 28 additions & 14 deletions mephisto/abstractions/providers/prolific/prolific_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mephisto.abstractions.providers.prolific.prolific_unit import ProlificUnit
from mephisto.abstractions.providers.prolific.prolific_worker import ProlificWorker
from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE
from mephisto.data_model.worker import Worker
from mephisto.operations.registry import register_mephisto_abstraction
from mephisto.utils.logger_core import get_logger
from mephisto.utils.qualifications import QualificationType
Expand All @@ -44,14 +45,13 @@
from .api.exceptions import ProlificException

if TYPE_CHECKING:
from mephisto.data_model.task import Task
from mephisto.data_model.task_run import TaskRun
from mephisto.data_model.unit import Unit
from mephisto.data_model.worker import Worker
from mephisto.data_model.requester import Requester
from mephisto.data_model.agent import Agent
from mephisto.abstractions.blueprint import SharedTaskState


DEFAULT_FRAME_HEIGHT = 0
DEFAULT_PROLIFIC_GROUP_NAME_ALLOW_LIST = "Allow list"
DEFAULT_PROLIFIC_GROUP_NAME_BLOCK_LIST = "Block list"
Expand Down Expand Up @@ -173,18 +173,16 @@ def _get_client(self, requester_name: str) -> ProlificClient:
def _get_qualified_workers(
self,
qualifications: List[QualificationType],
bloked_participant_ids: List[str],
blocked_participant_ids: List[str],
task_run: "TaskRun",
) -> List["Worker"]:
qualified_workers = []
workers: List[Worker] = self.db.find_workers(provider_type="prolific")
# `worker_name` is Prolific Participant ID in provider-specific datastore
available_workers = [w for w in workers if w.worker_name not in bloked_participant_ids]
available_workers = [w for w in workers if w.worker_name not in blocked_participant_ids]

for worker in available_workers:
if worker.can_send_more_submissions_for_task(task_run) and worker_is_qualified(
worker, qualifications
):
if worker_is_qualified(worker, qualifications):
qualified_workers.append(worker)

return qualified_workers
Expand Down Expand Up @@ -216,6 +214,21 @@ def _create_participant_group_with_qualified_workers(
)
return prolific_participant_group

def _get_excluded_participant_ids(self, task_run: "TaskRun") -> List[str]:
    """
    Find participant IDs that can no longer send submissions within this Task.

    Collects every worker linked to at least one unit of the Task that
    `task_run` belongs to, and keeps those for which
    `can_send_more_submissions_for_task` returns False (i.e. they hit
    the `maximum_units_per_worker` cap).

    Args:
        task_run: TaskRun whose parent Task's units are inspected.

    Returns:
        De-duplicated list (in no particular order) of `worker_name` values
        of the workers that cannot send more submissions for this Task.
    """
    task: "Task" = task_run.get_task()
    task_units: List["Unit"] = self.db.find_units(task_id=task.db_id)

    # De-duplicate worker ids up front: a worker may be linked to several units
    # of the same Task, and the submission-cap check only needs to run once per worker
    linked_worker_ids = {unit.worker_id for unit in task_units if unit.worker_id}

    excluded_participant_ids = set()
    for worker_id in linked_worker_ids:
        worker: "Worker" = Worker.get(self.db, worker_id)
        if not worker.can_send_more_submissions_for_task(task_run):
            excluded_participant_ids.add(worker.worker_name)

    return list(excluded_participant_ids)

def setup_resources_for_task_run(
self,
task_run: "TaskRun",
Expand Down Expand Up @@ -264,11 +277,12 @@ def setup_resources_for_task_run(
title=args.provider.prolific_project_name,
)

blocked_participant_ids = self.datastore.get_bloked_participant_ids()

blocked_participant_ids: List[str] = self.datastore.get_blocked_participant_ids()
excluded_participant_ids: List[str] = self._get_excluded_participant_ids(task_run)
# If no Mephisto qualifications found,
# we need to block Mephisto workers on Prolific as well
if blocked_participant_ids:
participant_ids_to_add_to_block_list = blocked_participant_ids + excluded_participant_ids
if participant_ids_to_add_to_block_list:
new_prolific_specific_qualifications = []
# Add an empty Blacklist in case there is none in state or config
blacklist_qualification = DictConfig(
Expand All @@ -288,29 +302,29 @@ def setup_resources_for_task_run(
whitelist_qualification = prolific_specific_qualification
prev_value = whitelist_qualification["white_list"]
whitelist_qualification["white_list"] = [
p for p in prev_value if p not in blocked_participant_ids
p for p in prev_value if p not in participant_ids_to_add_to_block_list
]
new_prolific_specific_qualifications.append(whitelist_qualification)
elif name == ParticipantGroupEligibilityRequirement.name:
# Remove blocked Participant IDs from Participant Group Eligibility Requirement
client.ParticipantGroups.remove_participants_from_group(
id=prolific_specific_qualification["id"],
participant_ids=blocked_participant_ids,
participant_ids=participant_ids_to_add_to_block_list,
)
else:
new_prolific_specific_qualifications.append(prolific_specific_qualification)

# Set Blacklist Eligibility Requirement
blacklist_qualification["black_list"] = list(
set(blacklist_qualification["black_list"] + blocked_participant_ids)
set(blacklist_qualification["black_list"] + participant_ids_to_add_to_block_list)
)
new_prolific_specific_qualifications.append(blacklist_qualification)
prolific_specific_qualifications = new_prolific_specific_qualifications

if qualifications:
qualified_workers = self._get_qualified_workers(
qualifications,
blocked_participant_ids,
participant_ids_to_add_to_block_list,
task_run,
)

Expand Down
9 changes: 9 additions & 0 deletions mephisto/abstractions/providers/prolific/prolific_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ def compose_completion_codes(code_suffix: str) -> List[dict]:
),
],
),
dict(
code=f"{constants.StudyCodeType.OTHER}_{code_suffix}",
code_type=constants.StudyCodeType.OTHER,
actions=[
dict(
action=constants.StudyAction.MANUALLY_REVIEW,
),
],
),
]

# Task info
Expand Down
3 changes: 1 addition & 2 deletions mephisto/data_model/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,7 @@ def new_from_provider_data(
unit.worker_id = worker.db_id
agent._unit = unit

# In case provider API wasn't responsive, we ensure this submission
# doesn't exceed per-worker cap for this Task. Otherwise don't process submission.
# Prevent sending more units to worker if worker exceeded submission cap within this Task
task_run: "TaskRun" = agent.get_task_run()
if not worker.can_send_more_submissions_for_task(task_run):
try:
Expand Down
12 changes: 8 additions & 4 deletions mephisto/data_model/task_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,14 @@ def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]:

# Cannot pair with self
units: List["Unit"] = []
for unit_set in unit_assigns.values():
is_self_set = map(lambda u: u.worker_id == worker.db_id, unit_set)
if not any(is_self_set):
units += unit_set
for unit_list in unit_assigns.values():
self_linked_units = [
u
for u in unit_list
if u.worker_id == worker.db_id and u.db_status == AssignmentState.LAUNCHED
]
if not self_linked_units:
units += unit_list

# Valid units must be launched and must not be special units (negative indices)
# Can use db_status directly rather than polling in the critical path, as in
Expand Down
3 changes: 1 addition & 2 deletions mephisto/data_model/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ def can_send_more_submissions_for_task(self, task_run: "TaskRun") -> bool:
completed_task_units = [
u for u in task_units if u.get_status() in AssignmentState.completed()
]

if len(completed_task_units) >= maximum_units_per_worker:
if len(completed_task_units) >= maximum_units_per_worker - 1:
return False

return True
Expand Down

0 comments on commit b527961

Please sign in to comment.