Skip to content

Commit

Permalink
Merge pull request #2 from plonerma/config_output_dir_field
Browse files Browse the repository at this point in the history
Updated source code to comply with new ruff version
  • Loading branch information
plonerma authored Nov 11, 2024
2 parents a2f2912 + 59a4b9e commit 62e4bb1
Show file tree
Hide file tree
Showing 18 changed files with 122 additions and 76 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ on:
permissions:
contents: write
jobs:
deploy:
deploy_documentation:
name: "Deploy documentation to github pages"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ permissions:
contents: read

jobs:
deploy:

publish:
name: "Publish to PyPI"
runs-on: ubuntu-latest

environment: release
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:
- "mkdocs.yml"

jobs:
run:
test:
name: "Test & Coverage"
runs-on: ubuntu-latest
strategy:
Expand Down
18 changes: 10 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ python = ["3.8", "3.9", "3.10", "3.11"]
[tool.hatch.envs.lint]
detached = true
dependencies = [
"black>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
"black>=24.8.0",
"mypy>=1.7.0",
"ruff>=0.7.0",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/cordage tests}"
style = [
"ruff {args:.}",
"ruff check {args:.}",
"black --check --diff {args:.}",
]
fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"ruff check --fix {args:.}",
"style",
]
all = [
Expand All @@ -104,6 +104,8 @@ skip-string-normalization = true
[tool.ruff]
target-version = "py38"
line-length = 120

[tool.ruff.lint]
select = [
"A",
"ARG",
Expand Down Expand Up @@ -148,13 +150,13 @@ unfixable = [
"F401",
]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["cordage"]

[tool.ruff.flake8-tidy-imports]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

Expand Down
9 changes: 5 additions & 4 deletions src/cordage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from cordage.experiment import Experiment, Series, Trial
from cordage.global_config import GlobalConfig
from cordage.util import ConfigClass, logger, nest_items, nested_update, read_dict_from_file
from cordage.util import from_dict as config_from_dict
from cordage.util import logger, nest_items, nested_update, read_dict_from_file


class MissingType:
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(
func: Callable, # expects a dataclass
global_config: GlobalConfig,
description: Optional[str] = None,
config_cls: Optional[Type] = None,
config_cls: Optional[Type[ConfigClass]] = None,
):
self.global_config = global_config
self.set_function(func)
Expand Down Expand Up @@ -208,6 +208,7 @@ def construct_argument_parser(self):
def _add_argument_to_parser(self, arg_name: str, arg_type: Any, help: str, **kw): # noqa: A002
# If the field is also a dataclass, recurse (nested config)
if dataclasses.is_dataclass(arg_type):
assert isinstance(arg_type, type)
self.arg_group_config.add_argument(
f"--{arg_name}", type=Path, default=MISSING, help=help, metavar="PATH", **kw
)
Expand All @@ -230,7 +231,7 @@ def _add_argument_to_parser(self, arg_name: str, arg_type: Any, help: str, **kw)
)

elif get_origin(arg_type) is Union:
args = [arg for arg in get_args(arg_type) if arg != type(None)]
args = [arg for arg in get_args(arg_type) if arg is not type(None)]

if len(args) == 1:
# optional
Expand All @@ -241,7 +242,7 @@ def _add_argument_to_parser(self, arg_name: str, arg_type: Any, help: str, **kw)
raise TypeError(msg)

# Boolean field
elif arg_type == bool:
elif arg_type is bool:
# Create a true and a false flag -> the destination is identical
self.arg_group_config.add_argument(
f"--{arg_name}", action="store_true", default=MISSING, help=help + " (set the value to True)", **kw
Expand Down
83 changes: 56 additions & 27 deletions src/cordage/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
overload,
)

Expand All @@ -38,26 +38,49 @@
except ImportError:
colorlog = None # type: ignore

import typing

from cordage.global_config import GlobalConfig
from cordage.util import config_output_dir_type, flattened_items, from_dict, logger, nest_items, nested_update, to_dict
from cordage.util import (
config_output_dir_type,
flattened_items,
from_dict,
logger,
nest_items,
nested_update,
to_dict,
)

if typing.TYPE_CHECKING:
from _typeshed import DataclassInstance


ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance")


class SeriesConfiguration(TypedDict):
base_config: "DataclassInstance"
series_spec: Union[List[Dict], Dict[str, List], None]
series_skip: Optional[int]

T = TypeVar("T")

Configuration = TypeVar("Configuration", SeriesConfiguration, "DataclassInstance")


@dataclass
class Metadata:
class Metadata(Generic[Configuration]):
function: str

global_config: GlobalConfig

configuration: Union[Configuration, Dict[str, Any]]

output_dir: Optional[Path] = None
status: str = "pending"

start_time: Optional[datetime] = None
end_time: Optional[datetime] = None

configuration: Any = None

result: Any = None

parent_dir: Optional[Path] = None
Expand Down Expand Up @@ -85,10 +108,12 @@ def from_dict(cls, data: Mapping):
return from_dict(cls, data)


class MetadataStore:
class MetadataStore(Generic[Configuration]):
_warned_deprecated_nested_global_config: bool = False

def __init__(self, metadata: Optional[Metadata] = None, /, global_config: Optional[GlobalConfig] = None, **kw):
def __init__(
self, metadata: Optional[Metadata[Configuration]] = None, /, global_config: Optional[GlobalConfig] = None, **kw
):
self.metadata: Metadata

if metadata is not None:
Expand Down Expand Up @@ -209,7 +234,7 @@ def load_metadata(cls, path: PathLike) -> Metadata:
return metadata


class Annotatable(MetadataStore):
class Annotatable(MetadataStore[Configuration]):
TAG_PATTERN = re.compile(r"\B#(\w*[a-zA-Z]+\w*)")

def __init__(self, *args, **kw):
Expand Down Expand Up @@ -262,7 +287,7 @@ def load_annotations(self):
self.annotations = json.load(fp)


class Experiment(Annotatable):
class Experiment(Annotatable[Configuration]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -359,7 +384,7 @@ def synchronize(self):
return self

@classmethod
def from_path(cls, path: PathLike, config_cls: Optional[Type] = None):
def from_path(cls, path: PathLike, config_cls: Optional[Type[ConfigClass]] = None):
metadata: Metadata = cls.load_metadata(path)

experiment: Experiment
Expand Down Expand Up @@ -448,12 +473,12 @@ def teardown_log(self):
logger.removeHandler(handler)


class Trial(Generic[T], Experiment):
class Trial(Experiment["DataclassInstance"], Generic[ConfigClass]):
def __init__(
self,
metadata: Optional[Metadata] = None,
metadata: Optional[Metadata["DataclassInstance"]] = None,
/,
config: Optional[T] = None,
config: Optional[ConfigClass] = None,
**kw,
):
if metadata is not None:
Expand All @@ -466,7 +491,13 @@ def __init__(
super().__init__(configuration=config, **kw)

@property
def config(self) -> T:
def config(self) -> ConfigClass:
if isinstance(self.metadata.configuration, dict):
msg = (
"`trial.config` is only available if the configuration was loaded with a configuration dataclass "
"(you could use `trial.metadata.configuration` instead)."
)
raise AttributeError(msg)
return self.metadata.configuration

def set_output_dir(self, path: Path):
Expand All @@ -479,12 +510,12 @@ def set_output_dir(self, path: Path):
logging.info("Config: %s", self.metadata.configuration)


class Series(Generic[T], Experiment):
class Series(Generic[ConfigClass], Experiment[SeriesConfiguration]):
def __init__(
self,
metadata: Optional[Metadata] = None,
/,
base_config: Optional[T] = None,
base_config: Optional[ConfigClass] = None,
series_spec: Union[List[Dict], Dict[str, List], None] = None,
series_skip: Optional[int] = None,
**kw,
Expand All @@ -502,7 +533,7 @@ def __init__(

self.validate_series_spec()

self.trials: Optional[List[Trial[T]]] = None
self.trials: Optional[List[Trial[ConfigClass]]] = None
self.make_all_trials()

def set_output_dir(self, path: Path):
Expand Down Expand Up @@ -534,7 +565,7 @@ def only_list_nodes(d):
assert series_spec is None

@property
def base_config(self) -> T:
def base_config(self) -> ConfigClass:
return self.metadata.configuration["base_config"]

@property
Expand Down Expand Up @@ -565,12 +596,10 @@ def __exit__(self, *args):
# else: do nothing

@overload
def get_changing_fields(self, sep: Literal[None] = None) -> Set[Tuple[Any, ...]]:
...
def get_changing_fields(self, sep: Literal[None] = None) -> Set[Tuple[Any, ...]]: ...

@overload
def get_changing_fields(self, sep: str) -> Set[str]:
...
def get_changing_fields(self, sep: str) -> Set[str]: ...

def get_changing_fields(self, sep: Optional[str] = None) -> Union[Set[Tuple[Any, ...]], Set[str]]:
keys: Set = set()
Expand Down Expand Up @@ -663,11 +692,11 @@ def make_all_trials(self):

nested_update(trial_config_data, trial_update)

trial_config: T
trial_config: ConfigClass
if isinstance(self.base_config, dict):
trial_config = cast(T, trial_config_data)
trial_config = trial_config_data
else:
trial_config = cast(T, from_dict(type(self.base_config), trial_config_data))
trial_config = from_dict(type(self.base_config), trial_config_data)

if i < self.series_skip:
status = "skipped"
Expand All @@ -680,7 +709,7 @@ def make_all_trials(self):
def __iter__(self):
return self.get_all_trials(include_skipped=False)

def get_all_trials(self, *, include_skipped: bool = False) -> Generator[Trial[T], None, None]:
def get_all_trials(self, *, include_skipped: bool = False) -> Generator[Trial[ConfigClass], None, None]:
assert self.trials is not None

if not self.is_singular:
Expand Down
Loading

0 comments on commit 62e4bb1

Please sign in to comment.