Skip to content

Commit

Permalink
Unittest coverage for Data Porter feature
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-paul committed May 6, 2024
1 parent fe83237 commit b86f12c
Show file tree
Hide file tree
Showing 23 changed files with 1,297 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

"""
List of changes:
1. Rename `unit_review.created_at` -> `unit_review.creation_date`
2. Remove autoincrement parameter for all Primary Keys
3. Add missed Foreign Keys in `agents` table
Expand All @@ -13,7 +14,7 @@
"""


PREPARING_DB_FOR_MERGE_DBS_COMMAND = """
MODIFICATIONS_QUERY_FOR_DATA_PORTER = """
ALTER TABLE unit_review RENAME COLUMN created_at TO creation_date;
/* Disable FK constraints */
Expand Down
4 changes: 2 additions & 2 deletions mephisto/abstractions/databases/migrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ._001_20240325_preparing_db_for_merge_dbs_command import *
from ._001_20240325_data_porter_feature import *


migrations = {
"20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND,
"20240418_data_porter_feature": MODIFICATIONS_QUERY_FOR_DATA_PORTER,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.

"""
1. Modified default value for `creation_date`
List of changes:
1. Modify default value for `creation_date`
"""


PREPARING_DB_FOR_MERGE_DBS_COMMAND = """
MODIFICATIONS_QUERY_FOR_DATA_PORTER = """
/* Disable FK constraints */
PRAGMA foreign_keys = off;
Expand All @@ -36,8 +37,8 @@
INSERT INTO _run_mappings SELECT * FROM run_mappings;
DROP TABLE run_mappings;
ALTER TABLE _run_mappings RENAME TO run_mappings;
/* Runs */
CREATE TABLE IF NOT EXISTS _runs (
run_id TEXT PRIMARY KEY UNIQUE,
Expand All @@ -50,8 +51,8 @@
INSERT INTO _runs SELECT * FROM runs;
DROP TABLE runs;
ALTER TABLE _runs RENAME TO runs;
/* Qualifications */
CREATE TABLE IF NOT EXISTS _qualifications (
qualification_name TEXT PRIMARY KEY UNIQUE,
Expand Down
4 changes: 2 additions & 2 deletions mephisto/abstractions/providers/mturk/migrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ._001_20240325_preparing_db_for_merge_dbs_command import *
from ._001_20240325_data_porter_feature import *


migrations = {
"20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND,
"20240418_data_porter_feature": MODIFICATIONS_QUERY_FOR_DATA_PORTER,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
# LICENSE file in the root directory of this source tree.

"""
List of changes:
1. Remove autoincrement parameter for all Primary Keys
2. Added `update_date` and `creation_date` in `workers` table
3. Added `creation_date` in `units` table
2. Add `update_date` and `creation_date` in `workers` table
3. Add `creation_date` in `units` table
4. Rename field `run_id` -> `task_run_id`
5. Remove table `requesters`
6. Modified default value for `creation_date`
6. Modify default value for `creation_date`
"""


PREPARING_DB_FOR_MERGE_DBS_COMMAND = """
MODIFICATIONS_QUERY_FOR_DATA_PORTER = """
/* Disable FK constraints */
PRAGMA foreign_keys = off;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ._001_20240325_preparing_db_for_merge_dbs_command import *
from ._001_20240325_data_porter_feature import *


migrations = {
"20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND,
"20240418_data_porter_feature": MODIFICATIONS_QUERY_FOR_DATA_PORTER,
}
3 changes: 3 additions & 0 deletions mephisto/data_model/task_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import os
import json
from dataclasses import dataclass, field
from datetime import datetime
from dateutil.parser import parse

from mephisto.data_model.requester import Requester
from mephisto.data_model.constants.assignment_state import AssignmentState
Expand Down Expand Up @@ -202,6 +204,7 @@ def __init__(
self.task_type: str = row["task_type"]
self.sandbox: bool = row["sandbox"]
self.assignments_generator_done: bool = False
self.creation_date: Optional[datetime] = parse(row["creation_date"])

# properties with deferred loading
self.__is_completed = row["is_completed"]
Expand Down
2 changes: 1 addition & 1 deletion mephisto/tools/db_data_porter/backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def restore_from_backup(
if remove_backup:
Path(backup_file_path).unlink(missing_ok=True)
except Exception as e:
logger.exception(f"[red]Could not restore backup '{backup_file_path}'. Error: {e}[/red]")
logger.exception(f"[red]Could not restore backup {backup_file_path}. Error: {e}[/red]")
exit()
5 changes: 3 additions & 2 deletions mephisto/tools/db_data_porter/db_data_porter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from mephisto.abstractions.databases.local_database import LocalMephistoDB
from mephisto.generators.form_composer.config_validation.utils import make_error_message
from mephisto.tools.db_data_porter import backups
from mephisto.tools.db_data_porter import export_dump
from mephisto.tools.db_data_porter import dumps
from mephisto.tools.db_data_porter import export_dump
from mephisto.tools.db_data_porter import import_dump
from mephisto.tools.db_data_porter.constants import BACKUP_OUTPUT_DIR
from mephisto.tools.db_data_porter.constants import DEFAULT_ARCHIVE_FORMAT
from mephisto.tools.db_data_porter.constants import DEFAULT_CONFLICT_RESOLVER
from mephisto.tools.db_data_porter.constants import EXPORT_OUTPUT_DIR
from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAME
from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY
Expand Down Expand Up @@ -301,7 +302,7 @@ def export_dump(
def import_dump(
self,
dump_archive_file_name_or_path: str,
conflict_resolver_name: str,
conflict_resolver_name: Optional[str] = DEFAULT_CONFLICT_RESOLVER,
labels: Optional[List[str]] = None,
keep_import_metadata: Optional[bool] = None,
verbosity: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions mephisto/tools/db_data_porter/dumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def prepare_partial_dump_data(
task_ids = task_ids or []

# Get TaskRun IDs by Task IDs
task_run_ids = db_utils.get_task_run_ids_ids_by_task_ids(db, task_ids)
task_run_ids = db_utils.get_task_run_ids_by_task_ids(db, task_ids)
elif task_runs_labels:
# Validate on correct values of passed TaskRun labels
db_labels = db_utils.get_list_of_available_labels(db)
Expand All @@ -117,7 +117,7 @@ def prepare_partial_dump_data(
exit()

# Get TaskRun IDs
task_run_ids = db_utils.get_task_run_ids_ids_by_labels(db, task_runs_labels)
task_run_ids = db_utils.get_task_run_ids_by_labels(db, task_runs_labels)
elif since_datetime:
# Get TaskRun IDs
task_run_ids = db_utils.select_task_run_ids_since_date(db, since_datetime)
Expand Down
2 changes: 1 addition & 1 deletion mephisto/tools/db_data_porter/import_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _update_row_with_pks_from_resolvings_mappings(
row: dict,
resolvings_mapping: MappingResolvingsType,
) -> dict:
table_fks = db_utils.select_fk_mappings_for_table(db, table_name)
table_fks = db_utils.select_fk_mappings_for_single_table(db, table_name)

# Update FK fields from resolving mappings if needed
for fk_table, fk_table_fields in table_fks.items():
Expand Down
2 changes: 1 addition & 1 deletion mephisto/tools/db_data_porter/randomize_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _randomize_ids_for_mephisto(
table_names = [t for t in mephisto_dump.keys() if t not in [IMPORTED_DATA_TABLE_NAME]]

# Find Foreign Keys' field names for all tables in Mephist DB
tables_fks = db_utils.select_fk_mappings_for_all_tables(db, table_names)
tables_fks = db_utils.select_fk_mappings_for_tables(db, table_names)

# Make new Primary Keys for all or legacy values
mephisto_pk_substitutions = {}
Expand Down
28 changes: 11 additions & 17 deletions mephisto/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,6 @@ class EntryDoesNotExistException(MephistoDBException):

# --- Functions ---


def _select_all_rows_from_table(db: "MephistoDB", table_name: str) -> List[dict]:
with db.table_access_condition, db.get_connection() as conn:
c = conn.cursor()
c.execute(f"SELECT * FROM {table_name};")
rows = c.fetchall()
return [dict(row) for row in rows]


def _select_rows_from_table_related_to_task(
db: "MephistoDB",
table_name: str,
Expand Down Expand Up @@ -121,7 +112,7 @@ def get_task_ids_by_task_names(db: "MephistoDB", task_names: List[str]) -> List[
return [r["task_id"] for r in rows]


def get_task_run_ids_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> List[str]:
def get_task_run_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> List[str]:
with db.table_access_condition, db.get_connection() as conn:
c = conn.cursor()
task_ids_string = ",".join([f"'{s}'" for s in task_ids])
Expand All @@ -135,7 +126,7 @@ def get_task_run_ids_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> L
return [r["task_run_id"] for r in rows]


def get_task_run_ids_ids_by_labels(db: "MephistoDB", labels: List[str]) -> List[str]:
def get_task_run_ids_by_labels(db: "MephistoDB", labels: List[str]) -> List[str]:
with db.table_access_condition, db.get_connection() as conn:
if not labels:
return []
Expand Down Expand Up @@ -174,10 +165,13 @@ def get_table_pk_field_name(db: "MephistoDB", table_name: str):
return table_unique_field_name


def select_all_table_rows(db: "MephistoDB", table_name: str) -> List[dict]:
def select_all_table_rows(
db: "MephistoDB", table_name: str, order_by: Optional[str] = None,
) -> List[dict]:
order_by_string = f" ORDER BY {order_by}" if order_by else ""
with db.table_access_condition, db.get_connection() as conn:
c = conn.cursor()
c.execute(f"SELECT * FROM {table_name};")
c.execute(f"SELECT * FROM {table_name}{order_by_string};")
rows = c.fetchall()
return [dict(row) for row in rows]

Expand Down Expand Up @@ -428,7 +422,7 @@ def db_or_datastore_to_dict(db: "MephistoDB") -> dict:
dump_data = {}
table_names = get_list_of_tables_to_export(db)
for table_name in table_names:
table_rows = _select_all_rows_from_table(db, table_name)
table_rows = select_all_table_rows(db, table_name)
table_data = serialize_data_for_table(table_rows)
dump_data[table_name] = table_data

Expand Down Expand Up @@ -547,7 +541,7 @@ def select_task_run_ids_since_date(db: "MephistoDB", since: datetime) -> List[st
return task_run_ids_since


def select_fk_mappings_for_table(db: "MephistoDB", table_name: str) -> dict:
def select_fk_mappings_for_single_table(db: "MephistoDB", table_name: str) -> dict:
with db.table_access_condition, db.get_connection() as conn:
c = conn.cursor()
c.execute(f"SELECT * FROM pragma_foreign_key_list('{table_name}');")
Expand All @@ -567,10 +561,10 @@ def select_fk_mappings_for_table(db: "MephistoDB", table_name: str) -> dict:
return table_fks


def select_fk_mappings_for_all_tables(db: "MephistoDB", table_names: List[str]) -> dict:
def select_fk_mappings_for_tables(db: "MephistoDB", table_names: List[str]) -> dict:
tables_fks = {}
for table_name in table_names:
table_fks = select_fk_mappings_for_table(db, table_name)
table_fks = select_fk_mappings_for_single_table(db, table_name)
tables_fks.update({table_name: table_fks})
return tables_fks

Expand Down
16 changes: 9 additions & 7 deletions mephisto/utils/dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def get_root_data_dir() -> str:
actual_data_dir = get_config_arg(CORE_SECTION, DATA_STORAGE_KEY)
if actual_data_dir is None:
data_dir_location = input(
"Please enter the full path to a location to store Mephisto run data. By default this "
f"would be at '{default_data_dir}'. This dir should NOT be on a distributed file "
"store. Press enter to use the default: "
"Please enter the full path to a location to store Mephisto run data. "
"By default this would be at '{default_data_dir}'. "
"This dir should NOT be on a distributed file store. "
"Press enter to use the default: "
).strip()
if len(data_dir_location) == 0:
data_dir_location = default_data_dir
Expand All @@ -85,17 +86,18 @@ def get_root_data_dir() -> str:
if os.path.exists(database_loc) and data_dir_location != default_data_dir:
should_migrate = (
input(
"We have found an existing database in the default data directory, do you want to "
f"copy any existing data from the default location to {data_dir_location}? (y)es/no: "
f"We have found an existing database in the default data directory, "
f"do you want to copy any existing data from the default location to "
f"{data_dir_location}? (y)es/no: "
)
.lower()
.strip()
)
if len(should_migrate) == 0 or should_migrate[0] == "y":
copy_tree(default_data_dir, data_dir_location)
print(
"Mephisto data successfully copied, once you've confirmed the migration worked, "
"feel free to remove all of the contents in "
"Mephisto data successfully copied, once you've confirmed "
"the migration worked, feel free to remove all of the contents in "
f"{default_data_dir} EXCEPT for `README.md`."
)
add_config_arg(CORE_SECTION, DATA_STORAGE_KEY, data_dir_location)
Expand Down
34 changes: 21 additions & 13 deletions mephisto/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,27 @@
)


def get_test_project(db: MephistoDB) -> Tuple[str, str]:
def get_test_project(db: MephistoDB, project_name: Optional[str] = None) -> Tuple[str, str]:
"""Helper to create a project for tests"""
project_name = "test_project"
project_name = project_name or "test_project"
project_id = db.new_project(project_name)
return project_name, project_id


def get_test_worker(db: MephistoDB) -> Tuple[str, str]:
def get_test_worker(db: MephistoDB, worker_name: Optional[str] = None) -> Tuple[str, str]:
"""Helper to create a worker for tests"""
worker_name = "test_worker"
worker_name = worker_name or "test_worker"
provider_type = "mock"
worker_id = db.new_worker(worker_name, provider_type)
return worker_name, worker_id


def get_test_requester(db: MephistoDB) -> Tuple[str, str]:
def get_test_requester(
db: MephistoDB, requester_name: Optional[str] = None, provider_type: Optional[str] = None,
) -> Tuple[str, str]:
"""Helper to create a requester for tests"""
requester_name = "test_requester"
provider_type = "mock"
requester_name = requester_name or "test_requester"
provider_type = provider_type or "mock"
requester_id = db.new_requester(requester_name, provider_type)
return requester_name, requester_id

Expand All @@ -78,18 +80,24 @@ def get_mock_requester(db) -> "Requester":
return mock_requesters[0]


def get_test_task(db: MephistoDB) -> Tuple[str, str]:
def get_test_task(db: MephistoDB, task_name: Optional[str] = None) -> Tuple[str, str]:
"""Helper to create a task for tests"""
task_name = "test_task"
task_name = task_name or "test_task"
task_type = "mock"
task_id = db.new_task(task_name, task_type)
return task_name, task_id


def get_test_task_run(db: MephistoDB) -> str:
def get_test_task_run(
db: MephistoDB, task_id: Optional[str] = None, requester_id: Optional[str] = None,
) -> str:
"""Helper to create a task run for tests"""
task_name, task_id = get_test_task(db)
requester_name, requester_id = get_test_requester(db)
if not task_id:
_, task_id = get_test_task(db)

if not requester_id:
_, requester_id = get_test_requester(db)

init_params = OmegaConf.to_yaml(OmegaConf.structured(MOCK_CONFIG))
return db.new_task_run(task_id, requester_id, json.dumps(init_params), "mock", "mock")

Expand Down Expand Up @@ -191,7 +199,7 @@ def make_completed_unit(db: MephistoDB) -> str:
return unit.db_id


def get_test_qualification(db: MephistoDB, name="test_qualification") -> str:
def get_test_qualification(db: MephistoDB, name: str = "test_qualification") -> str:
return db.make_qualification(name)


Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ addopts = -ra -q -s
markers =
req_creds: test which requires credentials
prolific: Prolific tests
utils: Mephisto utils
db_data_porter: DB Data Porter tool
testpaths =
test
Empty file.
Empty file.
Loading

0 comments on commit b86f12c

Please sign in to comment.