From 87a38b5f0bcf1856b4963f4e9efbee783d872f34 Mon Sep 17 00:00:00 2001 From: Tommaso Comparin <3862206+tcompa@users.noreply.github.com> Date: Tue, 30 May 2023 16:04:19 +0200 Subject: [PATCH] Revamp argument validation (ref #377) * Use PydanticV1 `validate_arguments` instead of creating `TaskArguments` models by hand. * Update tests. * Update `args_schema` script. --- fractal_tasks_core/__FRACTAL_MANIFEST__.json | 55 +++++--- fractal_tasks_core/_utils.py | 25 ++-- fractal_tasks_core/cellpose_segmentation.py | 34 +---- fractal_tasks_core/copy_ome_zarr.py | 13 +- fractal_tasks_core/create_ome_zarr.py | 29 +--- .../create_ome_zarr_multiplex.py | 16 +-- fractal_tasks_core/illumination_correction.py | 17 +-- .../maximum_intensity_projection.py | 11 +- .../napari_workflows_wrapper.py | 19 +-- fractal_tasks_core/yokogawa_to_ome_zarr.py | 13 +- scripts/update_args_schemas.py | 49 ++++++- tests/test_valid_args_schemas.py | 133 ++++++++++++++---- tests/test_valid_task_interface.py | 7 +- 13 files changed, 216 insertions(+), 205 deletions(-) diff --git a/fractal_tasks_core/__FRACTAL_MANIFEST__.json b/fractal_tasks_core/__FRACTAL_MANIFEST__.json index 6cf8c99b6..85da01000 100644 --- a/fractal_tasks_core/__FRACTAL_MANIFEST__.json +++ b/fractal_tasks_core/__FRACTAL_MANIFEST__.json @@ -17,7 +17,6 @@ }, "coarsening_xy": { "default": 2, - "description": "TBD", "title": "Coarsening Xy", "type": "integer" }, @@ -27,7 +26,6 @@ "type": "string" }, "image_glob_patterns": { - "description": "TBD", "items": { "type": "string" }, @@ -35,7 +33,6 @@ "type": "array" }, "input_paths": { - "description": "TBD", "items": { "type": "string" }, @@ -48,18 +45,15 @@ }, "metadata_table": { "default": "mrf_mlf", - "description": "TBD", "title": "Metadata Table", "type": "string" }, "num_levels": { "default": 2, - "description": "TBD", "title": "Num Levels", "type": "integer" }, "output_path": { - "description": "TBD", "title": "Output Path", "type": "string" } @@ -70,7 +64,7 @@ "metadata", 
"allowed_channels" ], - "title": "TaskArguments", + "title": "CreateOmeZarr", "type": "object" }, "default_args": { @@ -97,6 +91,7 @@ "type": "string" }, "delete_input": { + "default": false, "title": "Delete Input", "type": "boolean" }, @@ -119,10 +114,10 @@ "required": [ "input_paths", "output_path", - "metadata", - "component" + "component", + "metadata" ], - "title": "TaskArguments", + "title": "YokogawaToOmeZarr", "type": "object" }, "executable": "yokogawa_to_ome_zarr.py", @@ -162,6 +157,7 @@ "type": "string" }, "project_to_2D": { + "default": true, "title": "Project To 2D", "type": "boolean" }, @@ -175,7 +171,7 @@ "output_path", "metadata" ], - "title": "TaskArguments", + "title": "CopyOmeZarr", "type": "object" }, "default_args": { @@ -218,10 +214,10 @@ "required": [ "input_paths", "output_path", - "metadata", - "component" + "component", + "metadata" ], - "title": "TaskArguments", + "title": "MaximumIntensityProjection", "type": "object" }, "executable": "maximum_intensity_projection.py", @@ -243,10 +239,12 @@ "type": "number" }, "augment": { + "default": false, "title": "Augment", "type": "boolean" }, "cellprob_threshold": { + "default": 0.0, "title": "Cellprob Threshold", "type": "number" }, @@ -263,14 +261,17 @@ "type": "string" }, "diameter_level0": { + "default": 30.0, "title": "Diameter Level0", "type": "number" }, "flow_threshold": { + "default": 0.4, "title": "Flow Threshold", "type": "number" }, "input_ROI_table": { + "default": "FOV_ROI_table", "title": "Input Roi Table", "type": "string" }, @@ -290,14 +291,17 @@ "type": "object" }, "min_size": { + "default": 15, "title": "Min Size", "type": "integer" }, "model_type": { + "default": "cyto2", "title": "Model Type", "type": "string" }, "net_avg": { + "default": false, "title": "Net Avg", "type": "boolean" }, @@ -323,10 +327,12 @@ "type": "boolean" }, "use_gpu": { + "default": true, "title": "Use Gpu", "type": "boolean" }, "use_masks": { + "default": true, "title": "Use Masks", "type": "boolean" }, 
@@ -346,7 +352,7 @@ "metadata", "level" ], - "title": "TaskArguments", + "title": "CellposeSegmentation", "type": "object" }, "executable": "cellpose_segmentation.py", @@ -365,6 +371,7 @@ "additionalProperties": false, "properties": { "background": { + "default": 100, "title": "Background", "type": "integer" }, @@ -396,6 +403,7 @@ "type": "string" }, "overwrite": { + "default": false, "title": "Overwrite", "type": "boolean" } @@ -407,7 +415,7 @@ "metadata", "dict_corr" ], - "title": "TaskArguments", + "title": "IlluminationCorrection", "type": "object" }, "default_args": { @@ -433,10 +441,12 @@ "type": "string" }, "expected_dimensions": { + "default": 3, "title": "Expected Dimensions", "type": "integer" }, "input_ROI_table": { + "default": "FOV_ROI_table", "title": "Input Roi Table", "type": "string" }, @@ -458,6 +468,7 @@ "type": "object" }, "level": { + "default": 0, "title": "Level", "type": "integer" }, @@ -480,6 +491,7 @@ "type": "object" }, "relabeling": { + "default": true, "title": "Relabeling", "type": "boolean" }, @@ -491,13 +503,13 @@ "required": [ "input_paths", "output_path", - "metadata", "component", + "metadata", "workflow_file", "input_specs", "output_specs" ], - "title": "TaskArguments", + "title": "NapariWorkflowsWrapper", "type": "object" }, "default_args": { @@ -531,10 +543,12 @@ "type": "object" }, "coarsening_xy": { + "default": 2, "title": "Coarsening Xy", "type": "integer" }, "image_extension": { + "default": "tif", "title": "Image Extension", "type": "string" }, @@ -571,9 +585,11 @@ "type": "object" } ], + "default": "mrf_mlf", "title": "Metadata Table" }, "num_levels": { + "default": 2, "title": "Num Levels", "type": "integer" }, @@ -586,10 +602,9 @@ "input_paths", "output_path", "metadata", - "image_extension", "allowed_channels" ], - "title": "TaskArguments", + "title": "CreateOmeZarrMultiplex", "type": "object" }, "default_args": { diff --git a/fractal_tasks_core/_utils.py b/fractal_tasks_core/_utils.py index 287243804..1f764b16d 
100644 --- a/fractal_tasks_core/_utils.py +++ b/fractal_tasks_core/_utils.py @@ -25,8 +25,6 @@ from pathlib import Path from typing import Callable -from pydantic import BaseModel - class TaskParameterEncoder(JSONEncoder): """ @@ -42,15 +40,16 @@ def default(self, value): def run_fractal_task( *, task_function: Callable, - TaskArgsModel: type[BaseModel] = None, + validate: bool = False, logger_name: str = None, ): """ - Implement standard task interface and call task_function. If TaskArgsModel - is not None, validate arguments against given model. + Implement standard task interface and call task_function. If `validate`, + validate arguments via `pydantic.decorator.validate_arguments`. :param task_function: the callable function that runs the task - :param TaskArgsModel: a class specifying all types for task arguments + :param validate: TBD + :param logger_name: TBD """ # Parse `-j` and `--metadata-out` arguments @@ -79,16 +78,24 @@ def run_fractal_task( with open(args.json, "r") as f: pars = json.load(f) - if TaskArgsModel is None: + if not validate: # Run task without validating arguments' types logger.info(f"START {task_function.__name__} task") metadata_update = task_function(**pars) logger.info(f"END {task_function.__name__} task") else: + from pydantic.decorator import validate_arguments + from devtools import debug + + debug(pars) # Validating arguments' types and run task - task_args = TaskArgsModel(**pars) logger.info(f"START {task_function.__name__} task") + debug(task_function) + + vf = validate_arguments(task_function) + debug(vf) + + metadata_update = vf(**pars) logger.info(f"END {task_function.__name__} task") # Write output metadata to file, with custom JSON encoder diff --git a/fractal_tasks_core/cellpose_segmentation.py b/fractal_tasks_core/cellpose_segmentation.py index 373be80ad..91bf26a18 100644 --- a/fractal_tasks_core/cellpose_segmentation.py +++ 
b/fractal_tasks_core/cellpose_segmentation.py @@ -33,8 +33,6 @@ import zarr from anndata.experimental import write_elem from cellpose import models -from pydantic import BaseModel -from pydantic import Extra import fractal_tasks_core from fractal_tasks_core.lib_channels import ChannelNotFoundError @@ -663,42 +661,12 @@ def cellpose_segmentation( return {} -class TaskArguments(BaseModel, extra=Extra.forbid): - # Fractal arguments - input_paths: Sequence[str] - output_path: str - component: str - metadata: Dict[str, Any] - # Task-specific arguments - channel_label: Optional[str] - channel_label_c2: Optional[str] - wavelength_id: Optional[str] - wavelength_id_c2: Optional[str] - level: int - relabeling: bool = True - input_ROI_table: Optional[str] - output_ROI_table: Optional[str] - output_label_name: Optional[str] - # Cellpose-related arguments: - use_gpu: Optional[bool] - anisotropy: Optional[float] - diameter_level0: Optional[float] - cellprob_threshold: Optional[float] - flow_threshold: Optional[float] - model_type: Optional[str] - pretrained_model: Optional[str] - min_size: Optional[int] - augment: Optional[bool] - net_avg: Optional[bool] - use_masks: Optional[bool] - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=cellpose_segmentation, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/copy_ome_zarr.py b/fractal_tasks_core/copy_ome_zarr.py index 0369f56dd..9ebc8e44b 100644 --- a/fractal_tasks_core/copy_ome_zarr.py +++ b/fractal_tasks_core/copy_ome_zarr.py @@ -24,8 +24,6 @@ import anndata as ad import zarr from anndata.experimental import write_elem -from pydantic import BaseModel -from pydantic import Extra import fractal_tasks_core from fractal_tasks_core.lib_regions_of_interest import ( @@ -198,20 +196,11 @@ def copy_ome_zarr( return meta_update -class TaskArguments(BaseModel, extra=Extra.forbid): - input_paths: Sequence[str] - 
output_path: str - metadata: Dict[str, Any] - project_to_2D: Optional[bool] - suffix: Optional[str] - ROI_table_names: Optional[Sequence[str]] - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=copy_ome_zarr, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/create_ome_zarr.py b/fractal_tasks_core/create_ome_zarr.py index 3f868ce36..e63398929 100644 --- a/fractal_tasks_core/create_ome_zarr.py +++ b/fractal_tasks_core/create_ome_zarr.py @@ -24,9 +24,6 @@ import pandas as pd import zarr from anndata.experimental import write_elem -from pydantic import BaseModel -from pydantic import Extra -from pydantic import Field import fractal_tasks_core from fractal_tasks_core.lib_channels import check_well_channel_labels @@ -47,31 +44,17 @@ logger = logging.getLogger(__name__) -class TaskArguments(BaseModel, extra=Extra.forbid): - input_paths: Sequence[str] = Field(description="TBD") - output_path: str = Field(description="TBD") - metadata: Dict[str, Any] - image_extension: str = "tif" - image_glob_patterns: Optional[list[str]] = Field( - description="TBD", default=None - ) - allowed_channels: Sequence[Dict[str, Any]] - num_levels: int = Field(description="TBD", default=2) - coarsening_xy: int = Field(description="TBD", default=2) - metadata_table: str = Field(description="TBD", default="mrf_mlf") - - def create_ome_zarr( *, input_paths: Sequence[str], output_path: str, metadata: Dict[str, Any], - image_extension: str = "tif", # FIXME: remove default - image_glob_patterns: Optional[list[str]] = None, # FIXME: remove default + image_extension: str = "tif", + image_glob_patterns: Optional[list[str]] = None, allowed_channels: Sequence[Dict[str, Any]], - num_levels: int = 2, # FIXME: remove default - coarsening_xy: int = 2, # FIXME: remove default - metadata_table: str = "mrf_mlf", # FIXME: remove default + num_levels: int = 2, + coarsening_xy: int = 2, + 
metadata_table: str = "mrf_mlf", ) -> Dict[str, Any]: """ Create a OME-NGFF zarr folder, without reading/writing image data @@ -451,6 +434,6 @@ def create_ome_zarr( run_fractal_task( task_function=create_ome_zarr, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/create_ome_zarr_multiplex.py b/fractal_tasks_core/create_ome_zarr_multiplex.py index bdc271ad2..0bee2cfdb 100644 --- a/fractal_tasks_core/create_ome_zarr_multiplex.py +++ b/fractal_tasks_core/create_ome_zarr_multiplex.py @@ -26,8 +26,6 @@ import pandas as pd import zarr from anndata.experimental import write_elem -from pydantic import BaseModel -from pydantic import Extra import fractal_tasks_core from fractal_tasks_core.lib_channels import check_well_channel_labels @@ -482,23 +480,11 @@ def create_ome_zarr_multiplex( return metadata_update -class TaskArguments(BaseModel, extra=Extra.forbid): - input_paths: Sequence[str] - output_path: str - metadata: Dict[str, Any] - image_extension: str - image_glob_patterns: Optional[list[str]] - allowed_channels: Dict[str, Sequence[Dict[str, Any]]] - num_levels: Optional[int] - coarsening_xy: Optional[int] - metadata_table: Optional[Union[Literal["mrf_mlf"], Dict[str, str]]] - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=create_ome_zarr_multiplex, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/illumination_correction.py b/fractal_tasks_core/illumination_correction.py index c5e0bc41b..04f80a2ab 100644 --- a/fractal_tasks_core/illumination_correction.py +++ b/fractal_tasks_core/illumination_correction.py @@ -27,8 +27,6 @@ import dask.array as da import numpy as np import zarr -from pydantic import BaseModel -from pydantic import Extra from skimage.io import imread from fractal_tasks_core.lib_channels import get_omero_channel_list @@ -272,24 +270,11 @@ def illumination_correction( return 
{} -class TaskArguments(BaseModel, extra=Extra.forbid): - # Fractal arguments - input_paths: Sequence[str] - output_path: str - component: str - metadata: Dict[str, Any] - # Task-specific arguments - overwrite: Optional[bool] - new_component: Optional[str] - dict_corr: dict - background: Optional[int] - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=illumination_correction, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/maximum_intensity_projection.py b/fractal_tasks_core/maximum_intensity_projection.py index 20410854f..645f60e5b 100644 --- a/fractal_tasks_core/maximum_intensity_projection.py +++ b/fractal_tasks_core/maximum_intensity_projection.py @@ -22,8 +22,6 @@ import anndata as ad import dask.array as da -from pydantic import BaseModel -from pydantic import Extra from fractal_tasks_core.lib_pyramid_creation import build_pyramid from fractal_tasks_core.lib_regions_of_interest import ( @@ -132,19 +130,12 @@ def maximum_intensity_projection( return {} -class TaskArguments(BaseModel, extra=Extra.forbid): - input_paths: Sequence[str] - output_path: str - metadata: Dict[str, Any] - component: str - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=maximum_intensity_projection, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/napari_workflows_wrapper.py b/fractal_tasks_core/napari_workflows_wrapper.py index 07c463f72..7e363b886 100644 --- a/fractal_tasks_core/napari_workflows_wrapper.py +++ b/fractal_tasks_core/napari_workflows_wrapper.py @@ -20,7 +20,6 @@ from typing import Any from typing import Dict from typing import List -from typing import Optional from typing import Sequence import anndata as ad @@ -31,8 +30,6 @@ import zarr from anndata.experimental import write_elem from napari_workflows._io_yaml_v1 import 
load_workflow -from pydantic import BaseModel -from pydantic import Extra import fractal_tasks_core from fractal_tasks_core.lib_channels import get_channel_from_image_zarr @@ -605,25 +602,11 @@ def napari_workflows_wrapper( return {} -class TaskArguments(BaseModel, extra=Extra.forbid): - input_paths: Sequence[str] - output_path: str - metadata: Dict[str, Any] - component: str - workflow_file: str - input_specs: Dict[str, Dict[str, str]] - output_specs: Dict[str, Dict[str, str]] - input_ROI_table: Optional[str] - level: Optional[int] - relabeling: Optional[bool] - expected_dimensions: Optional[int] - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=napari_workflows_wrapper, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/fractal_tasks_core/yokogawa_to_ome_zarr.py b/fractal_tasks_core/yokogawa_to_ome_zarr.py index bb1c414b1..6f8f83dfe 100644 --- a/fractal_tasks_core/yokogawa_to_ome_zarr.py +++ b/fractal_tasks_core/yokogawa_to_ome_zarr.py @@ -19,15 +19,12 @@ from pathlib import Path from typing import Any from typing import Dict -from typing import Optional from typing import Sequence import dask.array as da import zarr from anndata import read_zarr from dask.array.image import imread -from pydantic import BaseModel -from pydantic import Extra from fractal_tasks_core.lib_channels import get_omero_channel_list from fractal_tasks_core.lib_glob import glob_with_multiple_patterns @@ -218,19 +215,11 @@ def yokogawa_to_ome_zarr( return {} -class TaskArguments(BaseModel, extra=Extra.forbid): - input_paths: Sequence[str] - output_path: str - metadata: Dict[str, Any] - component: str - delete_input: Optional[bool] - - if __name__ == "__main__": from fractal_tasks_core._utils import run_fractal_task run_fractal_task( task_function=yokogawa_to_ome_zarr, - TaskArgsModel=TaskArguments, + validate=True, logger_name=logger.name, ) diff --git a/scripts/update_args_schemas.py 
b/scripts/update_args_schemas.py index b24fe0d9b..5054d049d 100644 --- a/scripts/update_args_schemas.py +++ b/scripts/update_args_schemas.py @@ -6,6 +6,13 @@ import json from importlib import import_module from pathlib import Path +from typing import Any + +from pydantic.decorator import ALT_V_ARGS +from pydantic.decorator import ALT_V_KWARGS +from pydantic.decorator import V_DUPLICATE_KWARGS +from pydantic.decorator import V_POSITIONAL_ONLY_NAME +from pydantic.decorator import ValidatedFunction import fractal_tasks_core @@ -13,6 +20,42 @@ FRACTAL_TASKS_CORE_DIR = Path(fractal_tasks_core.__file__).parent +def _clean_up_pydantic_generated_schema(old_schema: dict[str, Any]): + """ + FIXME: duplicate of the function in tests/test_valid_args_schemas.py + + Strip some properties from the generated JSON Schema, see + https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/decorator.py. + """ + new_schema = old_schema.copy() + + # Check that args and kwargs properties match with some expected dummy + # values, and remove them from the schema properties. 
+ args_property = new_schema["properties"].pop("args") + kwargs_property = new_schema["properties"].pop("kwargs") + expected_args_property = {"title": "Args", "type": "array", "items": {}} + expected_kwargs_property = {"title": "Kwargs", "type": "object"} + if args_property != expected_args_property: + raise ValueError( + f"{args_property=}\ndiffers from\n{expected_args_property=}" + ) + if kwargs_property != expected_kwargs_property: + raise ValueError( + f"{kwargs_property=}\ndiffers from\n" + f"{expected_kwargs_property=}" + ) + + # Remove other properties, since they come from pydantic internals + for key in ( + V_POSITIONAL_ONLY_NAME, + V_DUPLICATE_KWARGS, + ALT_V_ARGS, + ALT_V_KWARGS, + ): + new_schema["properties"].pop(key, None) + return new_schema + + def get_task_list_from_manifest() -> list[dict]: with (FRACTAL_TASKS_CORE_DIR / "__FRACTAL_MANIFEST__.json").open("r") as f: manifest = json.load(f) @@ -26,8 +69,10 @@ def create_schema_for_single_task(task: dict): raise ValueError(f"Invalid {executable=}") module_name = executable[:-3] module = import_module(f"fractal_tasks_core.{module_name}") - TaskArguments = getattr(module, "TaskArguments") - schema = TaskArguments.schema() + task_function = getattr(module, module_name) + vf = ValidatedFunction(task_function, config=None) + schema = vf.model.schema() + schema = _clean_up_pydantic_generated_schema(schema) return schema, module_name diff --git a/tests/test_valid_args_schemas.py b/tests/test_valid_args_schemas.py index e75455f68..394fa61ab 100644 --- a/tests/test_valid_args_schemas.py +++ b/tests/test_valid_args_schemas.py @@ -1,31 +1,66 @@ import json from importlib import import_module +from inspect import signature from pathlib import Path +from typing import Any +from typing import Callable import pytest from devtools import debug from jsonschema.validators import Draft201909Validator from jsonschema.validators import Draft202012Validator from jsonschema.validators import Draft7Validator +from 
pydantic.decorator import ALT_V_ARGS +from pydantic.decorator import ALT_V_KWARGS +from pydantic.decorator import V_DUPLICATE_KWARGS +from pydantic.decorator import V_POSITIONAL_ONLY_NAME +from pydantic.decorator import ValidatedFunction import fractal_tasks_core +def _clean_up_pydantic_generated_schema(old_schema: dict[str, Any]): + """ + FIXME: duplicate of the same function in scripts/update_args_schemas.py + + Strip some properties from the generated JSON Schema, see + https://github.com/pydantic/pydantic/blob/1.10.X-fixes/pydantic/decorator.py. + """ + new_schema = old_schema.copy() + + # Check that args and kwargs properties match with some expected dummy + # values, and remove them from the schema properties. + args_property = new_schema["properties"].pop("args") + kwargs_property = new_schema["properties"].pop("kwargs") + expected_args_property = {"title": "Args", "type": "array", "items": {}} + expected_kwargs_property = {"title": "Kwargs", "type": "object"} + assert args_property == expected_args_property + assert kwargs_property == expected_kwargs_property + + # Remove other properties, since they come from pydantic internals + for key in ( + V_POSITIONAL_ONLY_NAME, + V_DUPLICATE_KWARGS, + ALT_V_ARGS, + ALT_V_KWARGS, + ): + new_schema["properties"].pop(key, None) + return new_schema + + FRACTAL_TASKS_CORE_DIR = Path(fractal_tasks_core.__file__).parent with (FRACTAL_TASKS_CORE_DIR / "__FRACTAL_MANIFEST__.json").open("r") as f: MANIFEST = json.load(f) TASK_LIST = MANIFEST["task_list"] - -def _create_schema_for_single_task(task: dict): - executable = task["executable"] - if not executable.endswith(".py"): - raise ValueError(f"Invalid {executable=}") - module_name = executable[:-3] - module = import_module(f"fractal_tasks_core.{module_name}") - TaskArguments = getattr(module, "TaskArguments") - schema = TaskArguments.schema() - return schema +FORBIDDEN_PARAM_NAMES = ( + "args", + "kwargs", + V_POSITIONAL_ONLY_NAME, + V_DUPLICATE_KWARGS, + ALT_V_ARGS, + 
ALT_V_KWARGS, +) def _extract_function(task: dict): @@ -38,20 +73,72 @@ def _extract_function(task: dict): return task_function +def _create_schema_for_single_task(task: dict): + task_function = _extract_function(task) + vf = ValidatedFunction(task_function, config=None) + schema = vf.model.schema() + schema = _clean_up_pydantic_generated_schema(schema) + return schema + + +def _validate_function_signature(function: Callable): + """ + Check that function parameters do not have forbidden names + """ + for param in signature(function).parameters.values(): + if param.name in FORBIDDEN_PARAM_NAMES: + raise ValueError( + f"Function {function} has argument with name {param.name}" + ) + + +def test_validate_function_signature(): + """ + Showcase the expected behavior of _validate_function_signature + """ + + def fun1(x: int): + pass + + _validate_function_signature(fun1) + + def fun2(x, *args): + pass + + with pytest.raises(ValueError): + _validate_function_signature(fun2) + + def fun3(x, **kwargs): + pass + + with pytest.raises(ValueError): + _validate_function_signature(fun3) + + def test_manifest_has_args_schemas_is_true(): debug(MANIFEST) assert MANIFEST["has_args_schemas"] +def test_task_functions_have_no_args_or_kwargs(): + """ + Test that task functions do not use forbidden parameter names (e.g. 
`args` + or `kwargs`) + """ + for ind_task, task in enumerate(TASK_LIST): + task_function = _extract_function(task) + _validate_function_signature(task_function) + + def test_args_schemas_are_up_to_date(): + """ + Test that args_schema attributes in the manifest are up-to-date + """ for ind_task, task in enumerate(TASK_LIST): print(f"Now handling {task['executable']}") - new_schema = _create_schema_for_single_task(task) old_schema = TASK_LIST[ind_task]["args_schema"] - if not new_schema == old_schema: - raise ValueError("Schemas are different.") - print(f"Schema for task {task['executable']} is up-to-date.") - print() + new_schema = _create_schema_for_single_task(task) + assert new_schema == old_schema @pytest.mark.parametrize( @@ -70,19 +157,3 @@ def test_args_schema_comply_with_jsonschema_specs(jsonschema_validator): f"Schema for task {task['executable']} is valid for " f"{jsonschema_validator}." ) - print() - - -def test_args_schema_match_with_function_arguments(): - for ind_task, task in enumerate(TASK_LIST): - print(f"Now handling {task['executable']}") - schema = _create_schema_for_single_task(task) - fun = _extract_function(task) - debug(fun) - names_from_signature = set( - name - for name, _type in fun.__annotations__.items() - if name != "return" - ) - name_from_args_schema = set(schema["properties"].keys()) - assert names_from_signature == name_from_args_schema diff --git a/tests/test_valid_task_interface.py b/tests/test_valid_task_interface.py index e4f519140..64cc27959 100644 --- a/tests/test_valid_task_interface.py +++ b/tests/test_valid_task_interface.py @@ -27,10 +27,9 @@ def validate_command(cmd: str): # Valid stderr includes pydantic.error_wrappers.ValidationError (type # match between model and function, but tmp_file_args has wrong arguments) assert "pydantic.error_wrappers.ValidationError" in stderr - # Valid stderr must include a mention of "extra fields not permitted". 
If - # this is missing, it probably means that we forgot to add - # `extra=Extra.forbid` in a `TaskArguments` definition - assert "extra fields not permitted (type=value_error.extra)" in stderr + # Valid stderr must include a mention of "unexpected keyword arguments", + # because we are including some invalid arguments + assert "unexpected keyword arguments" in stderr # Invalid stderr includes ValueError assert "ValueError" not in stderr