From 535ff6e21c136d840644a1de1ded8594adb2217d Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 15:12:15 +0100 Subject: [PATCH 1/7] Updated ruff call --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f039016..4b63577 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ dependencies = [ [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/cordage tests}" style = [ - "ruff {args:.}", + "ruff check {args:.}", "black --check --diff {args:.}", ] fmt = [ From 9ac1e5817b0980d1e378da443c926f010a14b9c2 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 15:23:14 +0100 Subject: [PATCH 2/7] Renamed workflow runs --- .github/workflows/docs.yml | 2 +- .github/workflows/publish.yml | 2 +- .github/workflows/tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d3a2dd1..ec536ad 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,7 +10,7 @@ on: permissions: contents: write jobs: - deploy: + deploy_documentation: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index bccd32a..66da8a3 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -8,7 +8,7 @@ permissions: contents: read jobs: - deploy: + publish: runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 267e06d..b4eed85 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ on: - "mkdocs.yml" jobs: - run: + test: name: "Test & Coverage" runs-on: ubuntu-latest strategy: From 1875a5576a2b25afd21c362a22f15226bd673e7c Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 15:27:49 +0100 Subject: [PATCH 3/7] Added job names to workflow --- .github/workflows/docs.yml | 1 + .github/workflows/publish.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index ec536ad..80decb6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -11,6 +11,7 @@ permissions: contents: write jobs: deploy_documentation: + name: "Deploy documentation to github pages" runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 66da8a3..e309f33 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -9,7 +9,7 @@ permissions: jobs: publish: - + name: "Publish to PyPI" runs-on: ubuntu-latest environment: release From 7576f6841fbf364aa9821d94670ca605e337ece1 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 15:41:33 +0100 Subject: [PATCH 4/7] Updated linting --- pyproject.toml | 12 +++++++----- src/cordage/context.py | 4 ++-- tests/test_cli.py | 2 +- tests/test_exceptions.py | 4 ++-- tests/test_field_types.py | 10 +++++----- tests/test_func_params.py | 2 +- tests/test_loading.py | 2 +- tests/test_misc_functionality.py | 8 ++++---- tests/test_nested_config.py | 2 +- tests/test_output_dir.py | 2 +- tests/test_series_creation.py | 2 +- tests/test_status_and_tags.py | 2 +- 12 files changed, 27 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4b63577..fec5320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ detached = true dependencies = [ "black>=23.1.0", "mypy>=1.0.0", - "ruff>=0.0.243", + "ruff>=0.7.0", ] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/cordage tests}" @@ -88,7 +88,7 @@ style = [ ] fmt = [ "black {args:.}", - "ruff --fix {args:.}", + "ruff check --fix {args:.}", "style", ] all = [ @@ -104,6 +104,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -148,13 +150,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["cordage"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/src/cordage/context.py b/src/cordage/context.py index 093d4f7..58add99 100644 --- a/src/cordage/context.py +++ b/src/cordage/context.py @@ -230,7 +230,7 @@ def _add_argument_to_parser(self, arg_name: str, arg_type: Any, help: str, **kw) ) elif get_origin(arg_type) is Union: - args = [arg for arg in get_args(arg_type) if arg != type(None)] + args = [arg for arg in get_args(arg_type) if arg is not type(None)] if len(args) == 1: # optional @@ -241,7 +241,7 @@ def _add_argument_to_parser(self, arg_name: str, arg_type: Any, help: str, **kw) raise TypeError(msg) # Boolean field - elif arg_type == bool: + elif arg_type is bool: # Create a true and a false flag -> the destination is identical self.arg_group_config.add_argument( f"--{arg_name}", action="store_true", default=MISSING, help=help + " (set the value to True)", **kw diff --git a/tests/test_cli.py b/tests/test_cli.py index fbfa6b9..79303f2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ def test_manual_output_dir(global_config, tmp_path): - def func(config: SimpleConfig): # noqa: ARG001 + def func(config: SimpleConfig): pass experiment = cordage.run( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fcf9e45..3315f12 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -36,7 +36,7 @@ def func(config: Config): # noqa: ARG001 def test_function_without_annotation(): - def func(config): # noqa: ARG001 + def func(config): pass with pytest.raises(TypeError) as e_info: @@ -56,7 +56,7 @@ def func(): def test_function_invalid_object_to_execute(global_config): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass context = cordage.FunctionContext(func, global_config=global_config) diff --git a/tests/test_field_types.py b/tests/test_field_types.py index a988e7e..48ed901 100644 --- a/tests/test_field_types.py +++ b/tests/test_field_types.py @@ -34,7 +34,7 @@ def func(config: Config): def test_literal_fields(global_config, resources_path): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass config_file = resources_path / "simple_b.json" @@ -44,7 +44,7 @@ def func(config: Config): # noqa: ARG001 def test_tuple_length_fields(global_config, resources_path): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass config_file = resources_path / "simple_c.toml" @@ -55,7 +55,7 @@ def func(config: Config): # noqa: ARG001 @pytest.mark.skip(reason="dacite currently does not properly support mixed tuples") def test_valid_mixed_tuple(global_config, resources_path): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass config_file = resources_path / "simple_d.json" @@ -65,7 +65,7 @@ def func(config: Config): # noqa: ARG001 @pytest.mark.skip(reason="dacite currently does not properly support mixed tuples") def test_invalid_mixed_tuple(global_config, resources_path): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass config_file = resources_path / "simple_e.toml" @@ -75,7 +75,7 @@ def func(config: Config): # noqa: ARG001 def test_valid_optional(global_config, resources_path): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass config_file = resources_path / "simple_f.json" diff --git a/tests/test_func_params.py b/tests/test_func_params.py index 0b530f6..c998b01 100644 --- a/tests/test_func_params.py +++ b/tests/test_func_params.py @@ -58,7 +58,7 @@ def func(config: Config, cordage_trial): def test_explicit_config_class(global_config): - def func(config): # noqa: ARG001 + def func(config): pass cordage.run(func, args=["--a", "1", "--b", "test"], global_config=global_config, config_cls=Config) diff --git a/tests/test_loading.py b/tests/test_loading.py index 69b2689..0802b41 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -10,7 +10,7 @@ @pytest.mark.timeout(1) def test_metadata_loading_config_class_casting(global_config): - def func(config: SimpleConfig): # noqa: ARG001 + def func(config: SimpleConfig): pass trial = cordage.run(func, args=["--a", "1", "--b", "2"], global_config=global_config) diff --git a/tests/test_misc_functionality.py b/tests/test_misc_functionality.py index 1c33e40..921efe1 100644 --- a/tests/test_misc_functionality.py +++ b/tests/test_misc_functionality.py @@ -50,7 +50,7 @@ def func(config: SimpleConfig, cordage_trial: cordage.Trial, trial_store=trial_s @pytest.mark.timeout(1) def test_function_context_from_configuration(global_config): - def func(config: SimpleConfig): # noqa: ARG001 + def func(config: SimpleConfig): pass context = FunctionContext(func, global_config=global_config) @@ -68,7 +68,7 @@ def func(config: SimpleConfig): # noqa: ARG001 @pytest.mark.timeout(1) def test_output_dir_path_correction(monkeypatch, tmp_path): - def func(config: SimpleConfig): # noqa: ARG001 + def func(config: SimpleConfig): pass run_dir = tmp_path / "run" @@ -92,7 +92,7 @@ def func(config: SimpleConfig): # noqa: ARG001 def test_config_only(global_config): - def func(config: SimpleConfig): # noqa: ARG001 + def func(config: SimpleConfig): pass cordage.run(func, args=[], global_config=global_config, config_only=True) @@ -102,7 +102,7 @@ def func(config: SimpleConfig): # noqa: ARG001 def test_config_only_func_params(global_config): - def func(config: SimpleConfig, output_dir): # noqa: ARG001 + def func(config: SimpleConfig, output_dir): pass with pytest.raises(TypeError): diff --git a/tests/test_nested_config.py b/tests/test_nested_config.py index ad8d9e9..08c7b36 100644 --- a/tests/test_nested_config.py +++ b/tests/test_nested_config.py @@ -93,7 +93,7 @@ def func(config: NestedConfig): def test_additional_keys_exception(global_config, resources_path): - def func(config: NestedConfig): # noqa: ARG001 + def func(config: NestedConfig): pass with pytest.raises(cordage.exceptions.CordageError): diff --git a/tests/test_output_dir.py b/tests/test_output_dir.py index 97e3b83..5bf6500 100644 --- a/tests/test_output_dir.py +++ b/tests/test_output_dir.py @@ -122,7 +122,7 @@ class ConfigWithOutputDir: b: str output_dir: int = 1 - def func(config: ConfigWithOutputDir): # noqa: ARG001 + def func(config: ConfigWithOutputDir): pass with pytest.raises(TypeError): diff --git a/tests/test_series_creation.py b/tests/test_series_creation.py index c55f04b..38807dc 100644 --- a/tests/test_series_creation.py +++ b/tests/test_series_creation.py @@ -68,7 +68,7 @@ def func(config: Config, cordage_trial: cordage.Trial, trial_store=trial_store): def test_invalid_trial_series(global_config, resources_path): - def func(config: Config, cordage_trial: cordage.Trial): # noqa: ARG001 + def func(config: Config, cordage_trial: cordage.Trial): pass config_file = resources_path / "series_invalid.json" diff --git a/tests/test_status_and_tags.py b/tests/test_status_and_tags.py index be9be97..5bddf90 100644 --- a/tests/test_status_and_tags.py +++ b/tests/test_status_and_tags.py @@ -107,7 +107,7 @@ def func(config: Config, cordage_trial): # noqa: ARG001 def test_function_name_saving(global_config, resources_path): - def func(config: Config): # noqa: ARG001 + def func(config: Config): pass conf_path = resources_path / "annotation.yaml" From 95e8d6b82cffd003122db93ca185693b1c1b327a Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 15:43:40 +0100 Subject: [PATCH 5/7] Updated black --- pyproject.toml | 2 +- src/cordage/experiment.py | 6 ++---- src/cordage/util.py | 6 ++---- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fec5320..de8d4ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true dependencies = [ - "black>=23.1.0", + "black>=24.8.0", "mypy>=1.0.0", "ruff>=0.7.0", ] diff --git a/src/cordage/experiment.py b/src/cordage/experiment.py index d584830..872b692 100644 --- a/src/cordage/experiment.py +++ b/src/cordage/experiment.py @@ -565,12 +565,10 @@ def __exit__(self, *args): # else: do nothing @overload - def get_changing_fields(self, sep: Literal[None] = None) -> Set[Tuple[Any, ...]]: - ... + def get_changing_fields(self, sep: Literal[None] = None) -> Set[Tuple[Any, ...]]: ... @overload - def get_changing_fields(self, sep: str) -> Set[str]: - ... + def get_changing_fields(self, sep: str) -> Set[str]: ... def get_changing_fields(self, sep: Optional[str] = None) -> Union[Set[Tuple[Any, ...]], Set[str]]: keys: Set = set() diff --git a/src/cordage/util.py b/src/cordage/util.py index 606d10a..3bfe404 100644 --- a/src/cordage/util.py +++ b/src/cordage/util.py @@ -149,16 +149,14 @@ def write_dict_to_file(path: PathLike, data: Mapping[str, Any]): @overload def flattened_items( nested_dict: Dict[Any, Any], *, sep: Literal[None] = None, prefix: Tuple[Any, ...] = () -) -> Generator[Tuple[Tuple[Any, ...], Any], None, None]: - ... +) -> Generator[Tuple[Tuple[Any, ...], Any], None, None]: ... # spearator given @overload def flattened_items( nested_dict: Dict[Any, Any], *, sep: str, prefix: Tuple[str, ...] = () -) -> Generator[Tuple[str, Any], None, None]: - ... +) -> Generator[Tuple[str, Any], None, None]: ... def flattened_items( From f68ea91822655fab07d0eaa3743936ec565425e3 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 17:04:22 +0100 Subject: [PATCH 6/7] Improved typing --- pyproject.toml | 2 +- src/cordage/context.py | 5 ++- src/cordage/experiment.py | 77 ++++++++++++++++++++++++++----------- src/cordage/util.py | 12 ++++-- tests/test_exceptions.py | 4 +- tests/test_global_config.py | 6 ++- tests/test_loading.py | 15 ++++++-- 7 files changed, 84 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index de8d4ec..37bacad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ python = ["3.8", "3.9", "3.10", "3.11"] detached = true dependencies = [ "black>=24.8.0", - "mypy>=1.0.0", + "mypy>=1.7.0", "ruff>=0.7.0", ] [tool.hatch.envs.lint.scripts] diff --git a/src/cordage/context.py b/src/cordage/context.py index 58add99..3848472 100644 --- a/src/cordage/context.py +++ b/src/cordage/context.py @@ -10,8 +10,8 @@ from cordage.experiment import Experiment, Series, Trial from cordage.global_config import GlobalConfig +from cordage.util import ConfigClass, logger, nest_items, nested_update, read_dict_from_file from cordage.util import from_dict as config_from_dict -from cordage.util import logger, nest_items, nested_update, read_dict_from_file class MissingType: @@ -105,7 +105,7 @@ def __init__( func: Callable, # expects a dataclass global_config: GlobalConfig, description: Optional[str] = None, - config_cls: Optional[Type] = None, + config_cls: Optional[Type[ConfigClass]] = None, ): self.global_config = global_config self.set_function(func) @@ -208,6 +208,7 @@ def construct_argument_parser(self): def _add_argument_to_parser(self, arg_name: str, arg_type: Any, help: str, **kw): # noqa: A002 # If the field is also a dataclass, recurse (nested config) if dataclasses.is_dataclass(arg_type): + assert isinstance(arg_type, type) self.arg_group_config.add_argument( f"--{arg_name}", type=Path, default=MISSING, help=help, metavar="PATH", **kw ) diff --git a/src/cordage/experiment.py b/src/cordage/experiment.py index 872b692..8953aa0 100644 --- a/src/cordage/experiment.py +++ b/src/cordage/experiment.py @@ -25,9 +25,9 @@ Set, Tuple, Type, + TypedDict, TypeVar, Union, - cast, overload, ) @@ -38,26 +38,49 @@ except ImportError: colorlog = None # type: ignore +import typing + from cordage.global_config import GlobalConfig -from cordage.util import config_output_dir_type, flattened_items, from_dict, logger, nest_items, nested_update, to_dict +from cordage.util import ( + config_output_dir_type, + flattened_items, + from_dict, + logger, + nest_items, + nested_update, + to_dict, +) + +if typing.TYPE_CHECKING: + from _typeshed import DataclassInstance + + +ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance") + + +class SeriesConfiguration(Generic[ConfigClass], TypedDict): + base_config: ConfigClass + series_spec: Union[List[Dict], Dict[str, List], None] + series_skip: Optional[int] -T = TypeVar("T") + +Configuration = TypeVar("Configuration", SeriesConfiguration, "DataclassInstance") @dataclass -class Metadata: +class Metadata(Generic[Configuration]): function: str global_config: GlobalConfig + configuration: Union[Configuration, Dict[str, Any]] + output_dir: Optional[Path] = None status: str = "pending" start_time: Optional[datetime] = None end_time: Optional[datetime] = None - configuration: Any = None - result: Any = None parent_dir: Optional[Path] = None @@ -85,10 +108,12 @@ def from_dict(cls, data: Mapping): return from_dict(cls, data) -class MetadataStore: +class MetadataStore(Generic[Configuration]): _warned_deprecated_nested_global_config: bool = False - def __init__(self, metadata: Optional[Metadata] = None, /, global_config: Optional[GlobalConfig] = None, **kw): + def __init__( + self, metadata: Optional[Metadata[Configuration]] = None, /, global_config: Optional[GlobalConfig] = None, **kw + ): self.metadata: Metadata if metadata is not None: @@ -209,7 +234,7 @@ def load_metadata(cls, path: PathLike) -> Metadata: return metadata -class Annotatable(MetadataStore): +class Annotatable(MetadataStore[Configuration]): TAG_PATTERN = re.compile(r"\B#(\w*[a-zA-Z]+\w*)") def __init__(self, *args, **kw): @@ -262,7 +287,7 @@ def load_annotations(self): self.annotations = json.load(fp) -class Experiment(Annotatable): +class Experiment(Annotatable[Configuration]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -359,7 +384,7 @@ def synchronize(self): return self @classmethod - def from_path(cls, path: PathLike, config_cls: Optional[Type] = None): + def from_path(cls, path: PathLike, config_cls: Optional[Type[ConfigClass]] = None): metadata: Metadata = cls.load_metadata(path) experiment: Experiment @@ -448,12 +473,12 @@ def teardown_log(self): logger.removeHandler(handler) -class Trial(Generic[T], Experiment): +class Trial(Experiment["DataclassInstance"], Generic[ConfigClass]): def __init__( self, - metadata: Optional[Metadata] = None, + metadata: Optional[Metadata["DataclassInstance"]] = None, /, - config: Optional[T] = None, + config: Optional[ConfigClass] = None, **kw, ): if metadata is not None: @@ -466,7 +491,13 @@ def __init__( super().__init__(configuration=config, **kw) @property - def config(self) -> T: + def config(self) -> ConfigClass: + if isinstance(self.metadata.configuration, dict): + msg = ( + "`trial.config` is only available if the configuration was loaded with a configuration dataclass " + "(you could use `trial.metadata.configuration` instead)." + ) + raise AttributeError(msg) return self.metadata.configuration def set_output_dir(self, path: Path): @@ -479,12 +510,12 @@ def set_output_dir(self, path: Path): logging.info("Config: %s", self.metadata.configuration) -class Series(Generic[T], Experiment): +class Series(Generic[ConfigClass], Experiment[SeriesConfiguration]): def __init__( self, metadata: Optional[Metadata] = None, /, - base_config: Optional[T] = None, + base_config: Optional[ConfigClass] = None, series_spec: Union[List[Dict], Dict[str, List], None] = None, series_skip: Optional[int] = None, **kw, @@ -502,7 +533,7 @@ def __init__( self.validate_series_spec() - self.trials: Optional[List[Trial[T]]] = None + self.trials: Optional[List[Trial[ConfigClass]]] = None self.make_all_trials() def set_output_dir(self, path: Path): @@ -534,7 +565,7 @@ def only_list_nodes(d): assert series_spec is None @property - def base_config(self) -> T: + def base_config(self) -> ConfigClass: return self.metadata.configuration["base_config"] @property @@ -661,11 +692,11 @@ def make_all_trials(self): nested_update(trial_config_data, trial_update) - trial_config: T + trial_config: ConfigClass if isinstance(self.base_config, dict): - trial_config = cast(T, trial_config_data) + trial_config = trial_config_data else: - trial_config = cast(T, from_dict(type(self.base_config), trial_config_data)) + trial_config = from_dict(type(self.base_config), trial_config_data) if i < self.series_skip: status = "skipped" @@ -678,7 +709,7 @@ def make_all_trials(self): def __iter__(self): return self.get_all_trials(include_skipped=False) - def get_all_trials(self, *, include_skipped: bool = False) -> Generator[Trial[T], None, None]: + def get_all_trials(self, *, include_skipped: bool = False) -> Generator[Trial[ConfigClass], None, None]: assert self.trials is not None if not self.is_singular: diff --git a/src/cordage/util.py b/src/cordage/util.py index 3bfe404..790c224 100644 --- a/src/cordage/util.py +++ b/src/cordage/util.py @@ -1,5 +1,6 @@ import dataclasses import logging +import typing from datetime import datetime, timedelta from os import PathLike from pathlib import Path @@ -24,12 +25,15 @@ import dacite import dacite.exceptions +if typing.TYPE_CHECKING: + from _typeshed import DataclassInstance + import cordage.exceptions logger = logging.getLogger("cordage") -T = TypeVar("T") +ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance") serialization_map: Dict[Type[Any], Callable[..., Any]] = { Path: str, @@ -230,7 +234,7 @@ def nest_items(flat_items: Iterable[Tuple[Union[str, Tuple[Any, ...]], Any]]) -> return nested_dict -def from_dict(data_class: Type[T], data: Mapping, *, strict: bool = True) -> T: +def from_dict(data_class: Type[ConfigClass], data: Mapping, *, strict: bool = True) -> ConfigClass: config = dacite.Config(cast=types_to_cast, type_hooks=deserialization_map, strict=strict) try: return dacite.from_dict(data_class, data, config) @@ -253,7 +257,7 @@ def from_dict(data_class: Type[T], data: Mapping, *, strict: bool = True) -> T: raise cordage.exceptions.CordageError(msg) from e -def from_file(config_cls: Type[T], path: PathLike, **kwargs) -> T: +def from_file(config_cls: Type[ConfigClass], path: PathLike, **kwargs) -> ConfigClass: data: Mapping = read_dict_from_file(path) return from_dict(config_cls, data, **kwargs) @@ -297,7 +301,7 @@ def set_nested_field(dataclass_instance, field_name: str, value: Any): setattr(obj, last_key, value) -def to_dict(data: Any) -> dict: +def to_dict(data: Union[ConfigClass, Mapping]) -> dict: """Represent the fields and values of configuration as a (nested) dict.""" mapping: Mapping diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3315f12..a8f7561 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -68,12 +68,12 @@ def func(config: Config): def test_multiple_runtime_exceptions(global_config): - metadata: cordage.Metadata = cordage.Metadata(function="no_function", global_config=global_config) + metadata: cordage.Metadata = cordage.Metadata(function="no_function", global_config=global_config, configuration={}) with pytest.raises(TypeError): exp = cordage.Experiment(metadata, global_config=global_config) - exp = cordage.Experiment(function="no_function", global_config=global_config) + exp = cordage.Experiment(function="no_function", global_config=global_config, configuration={}) with pytest.raises(RuntimeError): log.info(str(exp.output_dir)) diff --git a/tests/test_global_config.py b/tests/test_global_config.py index ce29799..65f3aff 100644 --- a/tests/test_global_config.py +++ b/tests/test_global_config.py @@ -1,13 +1,17 @@ import dataclasses import json +from typing import TYPE_CHECKING import pytest +if TYPE_CHECKING: + from _typeshed import DataclassInstance + from cordage import GlobalConfig from cordage.util import to_dict -def test_global_config(global_config): +def test_global_config(global_config: "DataclassInstance"): assert dataclasses.is_dataclass(global_config) json.dumps(to_dict(global_config)) diff --git a/tests/test_loading.py b/tests/test_loading.py index 0802b41..a43207d 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -22,9 +22,11 @@ def func(config: SimpleConfig): # try loading as dict experiment = Experiment.from_path(metadata_path) assert isinstance(experiment, Trial) - assert isinstance(experiment.config, dict) - assert experiment.config["a"] == 1 - assert experiment.config["b"] == "2" + with pytest.raises(AttributeError): + assert isinstance(experiment.config, dict) + assert isinstance(experiment.metadata.configuration, dict) + assert experiment.metadata.configuration["a"] == 1 + assert experiment.metadata.configuration["b"] == "2" # try loading with config class experiment = Experiment.from_path(metadata_path, config_cls=SimpleConfig) @@ -69,7 +71,12 @@ def func(config: NestedConfig, cordage_trial: cordage.Trial): # after loading the series trials, the configs are merely nested dictionaries for i, trial in enumerate(trial_store): - assert trial.config["alpha"]["b"] == f"b{i+1}" + with pytest.raises(AttributeError): + assert isinstance(trial.config, dict) + + config = trial.metadata.configuration + + assert config["alpha"]["b"] == f"b{i+1}" assert trial.has_tag(f"b{i+1}") assert isinstance(trial.metadata.start_time, datetime) From 59a4b9ea36c2bab71f7ca2823c2adf30cd6aae4f Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Mon, 11 Nov 2024 17:09:51 +0100 Subject: [PATCH 7/7] Fixed issue with generic TypedDict See https://alexocallaghan.com/python-typeddict-with-generics for more details on the issue --- src/cordage/experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cordage/experiment.py b/src/cordage/experiment.py index 8953aa0..a6d1432 100644 --- a/src/cordage/experiment.py +++ b/src/cordage/experiment.py @@ -58,8 +58,8 @@ ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance") -class SeriesConfiguration(Generic[ConfigClass], TypedDict): - base_config: ConfigClass +class SeriesConfiguration(TypedDict): + base_config: "DataclassInstance" series_spec: Union[List[Dict], Dict[str, List], None] series_skip: Optional[int]