diff --git a/src/cordage/experiment.py b/src/cordage/experiment.py index 0ce8f8d..3b86fa5 100644 --- a/src/cordage/experiment.py +++ b/src/cordage/experiment.py @@ -1,11 +1,7 @@ -import json import logging -import re import shutil -from collections.abc import Generator, Iterable, Mapping +from collections.abc import Generator from copy import deepcopy -from dataclasses import dataclass, field -from dataclasses import replace as dataclass_replace from datetime import datetime, timezone from itertools import count, product from json.decoder import JSONDecodeError @@ -31,9 +27,8 @@ colorlog = None # type: ignore import typing -from enum import Enum -from cordage.global_config import GlobalConfig +from cordage.metadata import Annotatable, Metadata, Status from cordage.util import ( config_output_dir_type, flattened_items, @@ -41,7 +36,6 @@ logger, nest_items, nested_update, - to_dict, ) if typing.TYPE_CHECKING: @@ -51,256 +45,6 @@ ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance") -class Status(str, Enum): - UNKOWN = "unkown" - PENDING = "pending" - RUNNING = "running" - COMPLETE = "complete" - FAILED = "failed" - ABORTED = "aborted" - SKIPPED = "skipped" - - def __str__(self) -> str: - return self.value - - @property - def has_started(self): - """The pname property.""" - return self not in (self.UNKOWN, self.PENDING) - - -@dataclass -class Metadata: - function: str - - global_config: GlobalConfig - - configuration: dict[str, Any] - - output_dir: Optional[Path] = None - status: Status = Status.UNKOWN - - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None - - result: Any = None - - parent_dir: Optional[Path] = None - - additional_info: dict = field(default_factory=dict) - - @property - def duration(self): - assert self.end_time is not None and self.start_time is not None - - return self.end_time - self.start_time - - def replace(self, **changes): - return dataclass_replace(self, **changes) - - @property - def is_series(self): - return isinstance(self.configuration, dict) and "series_spec" in self.configuration - - def to_dict(self): - return to_dict(self) - - @classmethod - def from_dict(cls, data: Mapping): - return from_dict(cls, data) - - -class MetadataStore: - _warned_deprecated_nested_global_config: bool = False - - def __init__( - self, - metadata: Optional[Metadata] = None, - /, - global_config: Optional[GlobalConfig] = None, - **kw, - ): - self.metadata: Metadata - - if metadata is not None: - if global_config is not None or len(kw) > 0: - msg = "Using the `metadata` argument is incompatible with using other arguments." - raise TypeError(msg) - else: - self.metadata = metadata - else: - if global_config is None: - global_config = GlobalConfig() - - self.metadata = Metadata(global_config=global_config, **kw) - - @property - def global_config(self) -> GlobalConfig: - return self.metadata.global_config - - @property - def output_dir(self) -> Path: - if self.metadata.output_dir is None: - msg = f"{self.__class__.__name__} has not been started yet." - raise RuntimeError(msg) - else: - return self.metadata.output_dir - - @property - def parent_dir(self) -> Optional[Path]: - return self.metadata.parent_dir - - def set_output_dir(self, path: Path): - self.metadata.output_dir = path - - def create_output_dir(self): - if self.metadata.output_dir is not None: - self.output_dir.mkdir(parents=True, exist_ok=True) - self.set_output_dir(self.output_dir) - return self.output_dir - - tried_paths: set[Path] = set() - suffix = "" - - for i in count(1): - if i > 1: - level = floor(log10(i) / 2) + 1 - suffix = "_" * level + str(i).zfill(2 * level) - - path = ( - self.global_config.base_output_dir - / self.global_config.output_dir_format.format( - **self.metadata.__dict__, - collision_suffix=suffix, - ) - ) - - if path in tried_paths: - # suffix was already tried: assume that further tries - # wont resolve this collision - msg = f"Path {path} does already exist - collision could not be avoided." - raise RuntimeError(msg) - - try: - path.mkdir(parents=True, exist_ok=False) - self.set_output_dir(path) - return path - except FileExistsError: - if self.global_config.overwrite_existing: - logger.warning( - "Path %s does existing. Replacing directory with new one.", str(path) - ) - shutil.rmtree(path) - path.mkdir(parents=True) - self.set_output_dir(path) - return path - else: - tried_paths.add(path) - - @property - def metadata_path(self): - return self.output_dir / "cordage.json" - - def save_metadata(self): - md_dict = self.metadata.to_dict() - - with open(self.metadata_path, "w", encoding="utf-8") as fp: - - def invalid_obj_default(obj): - logger.warning("Cannot serialize %s", str(obj)) - - json.dump(md_dict, fp, indent=4, default=invalid_obj_default) - - @classmethod - def _convert_old_global_config_to_new(cls, d: dict[str, Any]) -> dict[str, Any]: - if "logging" in d["global_config"]: - if not cls._warned_deprecated_nested_global_config: - logger.warning("Using deprecated nested global_config format.") - cls._warned_deprecated_nested_global_config = True - - for k, v in d["global_config"]["logging"].items(): - d["global_config"][f"logging_{k}"] = v - - for k, v in d["global_config"]["param_names"].items(): - d["global_config"][f"param_name_{k}"] = v - - del d["global_config"]["logging"] - del d["global_config"]["param_names"] - return d - - @classmethod - def load_metadata(cls, path: PathLike) -> Metadata: - path = Path(path) - if not path.suffix == ".json": - path = path / "cordage.json" - - with path.open("r", encoding="utf-8") as fp: - metadata_dict = cls._convert_old_global_config_to_new(json.load(fp)) - metadata = Metadata.from_dict(metadata_dict) - - if metadata.output_dir != path.parent: - logger.info( - f"Output dir is not correct anymore. Changing it to the actual directory" - f"({metadata.output_dir} -> {path.parent})" - ) - metadata.output_dir = path.parent - - return metadata - - -class Annotatable(MetadataStore): - TAG_PATTERN = re.compile(r"\B#(\w*[a-zA-Z]+\w*)") - - def __init__(self, *args, **kw): - super().__init__(*args, **kw) - - self.annotations = {} - - @property - def tags(self): - tags = set(self.explicit_tags) - - # implicit tags - tags.update(re.findall(self.TAG_PATTERN, self.comment)) - - return list(tags) - - @property - def explicit_tags(self): - if "tags" not in self.annotations: - self.annotations["tags"] = [] - return self.annotations["tags"] - - def add_tag(self, *tags: Iterable): - for t in tags: - if t not in self.explicit_tags: - self.explicit_tags.append(t) - - def has_tag(self, *tags: str): - return len(tags) == 0 or any(t in tags for t in self.tags) - - @property - def comment(self): - return self.annotations.get("comment", "") or "" - - @comment.setter - def comment(self, value): - self.annotations["comment"] = value - - @property - def annotations_path(self): - return self.output_dir / "annotations.json" - - def save_annotations(self): - with open(self.annotations_path, "w", encoding="utf-8") as fp: - json.dump(self.annotations, fp, indent=4) - - def load_annotations(self): - if self.annotations_path.exists(): - with self.annotations_path.open("r", encoding="utf-8") as fp: - self.annotations = json.load(fp) - - class Experiment(Annotatable): def __init__(self, *args, config_cls: Optional[type] = None, **kwargs): super().__init__(*args, **kwargs) @@ -510,6 +254,50 @@ def teardown_log(self): handler.close() logger.removeHandler(handler) + def create_output_dir(self): + if self.metadata.output_dir is not None: + self.output_dir.mkdir(parents=True, exist_ok=True) + self.set_output_dir(self.output_dir) + return self.output_dir + + tried_paths: set[Path] = set() + suffix = "" + + for i in count(1): + if i > 1: + level = floor(log10(i) / 2) + 1 + suffix = "_" * level + str(i).zfill(2 * level) + + path = ( + self.global_config.base_output_dir + / self.global_config.output_dir_format.format( + **self.metadata.__dict__, + collision_suffix=suffix, + ) + ) + + if path in tried_paths: + # suffix was already tried: assume that further tries + # wont resolve this collision + msg = f"Path {path} does already exist - collision could not be avoided." + raise RuntimeError(msg) + + try: + path.mkdir(parents=True, exist_ok=False) + self.set_output_dir(path) + return path + except FileExistsError: + if self.global_config.overwrite_existing: + logger.warning( + "Path %s does existing. Replacing directory with new one.", str(path) + ) + shutil.rmtree(path) + path.mkdir(parents=True) + self.set_output_dir(path) + return path + else: + tried_paths.add(path) + class Trial(Experiment, Generic[ConfigClass]): def __init__( @@ -571,6 +359,8 @@ def set_output_dir(self, path: Path): class Series(Generic[ConfigClass], Experiment): + trials: list[Trial[ConfigClass]] + def __init__( self, metadata: Optional[Metadata] = None, @@ -602,8 +392,6 @@ def __init__( ) self.validate_series_spec() - - self.trials: Optional[list[Trial[ConfigClass]]] = None self.make_all_trials() def validate_series_spec(self): @@ -698,25 +486,27 @@ def get_trial_updates(self) -> Generator[dict, None, None]: else: yield {} - def __len__(self): + def _derive_len(self) -> int: if isinstance(self.series_spec, list): - assert self.trials is None or len(self.trials) == len(self.series_spec), ( - f"Number of existing ({len(self.trials)}) and expected trials " - f"({len(self.series_spec)}) do not match." - ) return len(self.series_spec) elif isinstance(self.series_spec, dict): num_trials = 1 for _, values in flattened_items(self.series_spec): num_trials *= len(values) - assert self.trials is None or len(self.trials) == num_trials, ( - f"Number of existing ({len(self.trials)}) and expected trials ({num_trials}) do " - "not match." - ) return num_trials else: return 1 + def __len__(self) -> int: + if len(self.trials) != self._derive_len(): + msg = ( + f"Number of existing ({len(self.trials)}) and expected trials " + f"({self._derive_len()}) do not match." + ) + raise RuntimeError(msg) + + return len(self.trials) + def make_trial(self, **kw): additional_info = kw.pop("additional_info", None) @@ -752,7 +542,7 @@ def make_all_trials(self): else: logger.debug( "The given configuration yields an experiment series with %d experiments.", - len(self), + self._derive_len(), ) self.trials = [] diff --git a/src/cordage/global_config.py b/src/cordage/global_config.py index 09f496a..a060bae 100644 --- a/src/cordage/global_config.py +++ b/src/cordage/global_config.py @@ -2,12 +2,14 @@ from datetime import datetime, timezone from os import PathLike from pathlib import Path -from typing import Union +from typing import Any, Union from cordage.util import from_dict as config_from_dict from cordage.util import from_file as config_from_file from cordage.util import logger +_warned_deprecated_nested_global_config: bool = False + @dataclass class GlobalConfig: @@ -110,3 +112,22 @@ def resolve(cls, global_config: Union[str, PathLike, dict, "GlobalConfig", None] else: msg = "`global_config` must be one of str, PathLike, dict, cordage.GlobalConfig, None" raise TypeError(msg) + + @classmethod + def _convert_old_to_new(cls, d: dict[str, Any]) -> dict[str, Any]: + global _warned_deprecated_nested_global_config # noqa: PLW0603 + + if "logging" in d["global_config"]: + if not _warned_deprecated_nested_global_config: + logger.warning("Using deprecated nested global_config format.") + _warned_deprecated_nested_global_config = True + + for k, v in d["global_config"]["logging"].items(): + d["global_config"][f"logging_{k}"] = v + + for k, v in d["global_config"]["param_names"].items(): + d["global_config"][f"param_name_{k}"] = v + + del d["global_config"]["logging"] + del d["global_config"]["param_names"] + return d diff --git a/src/cordage/metadata.py b/src/cordage/metadata.py new file mode 100644 index 0000000..e7bd896 --- /dev/null +++ b/src/cordage/metadata.py @@ -0,0 +1,221 @@ +import json +import re +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from dataclasses import replace as dataclass_replace +from datetime import datetime +from os import PathLike +from pathlib import Path +from typing import ( + Any, + Optional, + TypeVar, +) + +try: + import colorlog +except ImportError: + colorlog = None # type: ignore + +import typing +from enum import Enum + +from cordage.global_config import GlobalConfig +from cordage.util import ( + from_dict, + logger, + to_dict, +) + +if typing.TYPE_CHECKING: + from _typeshed import DataclassInstance + + +ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance") + + +class Status(str, Enum): + UNKOWN = "unkown" + PENDING = "pending" + RUNNING = "running" + COMPLETE = "complete" + FAILED = "failed" + ABORTED = "aborted" + SKIPPED = "skipped" + + def __str__(self) -> str: + return self.value + + @property + def has_started(self): + """The pname property.""" + return self not in (self.UNKOWN, self.PENDING) + + +@dataclass +class Metadata: + function: str + + global_config: GlobalConfig + + configuration: dict[str, Any] + + output_dir: Optional[Path] = None + status: Status = Status.UNKOWN + + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + + result: Any = None + + parent_dir: Optional[Path] = None + + additional_info: dict = field(default_factory=dict) + + @property + def duration(self): + assert self.end_time is not None and self.start_time is not None + + return self.end_time - self.start_time + + def replace(self, **changes): + return dataclass_replace(self, **changes) + + @property + def is_series(self): + return isinstance(self.configuration, dict) and "series_spec" in self.configuration + + def to_dict(self): + return to_dict(self) + + @classmethod + def from_dict(cls, data: Mapping): + return from_dict(cls, data) + + +class MetadataStore: + def __init__( + self, + metadata: Optional[Metadata] = None, + /, + global_config: Optional[GlobalConfig] = None, + **kw, + ): + self.metadata: Metadata + + if metadata is not None: + if global_config is not None or len(kw) > 0: + msg = "Using the `metadata` argument is incompatible with using other arguments." + raise TypeError(msg) + else: + self.metadata = metadata + else: + if global_config is None: + global_config = GlobalConfig() + + self.metadata = Metadata(global_config=global_config, **kw) + + @property + def global_config(self) -> GlobalConfig: + return self.metadata.global_config + + @property + def output_dir(self) -> Path: + if self.metadata.output_dir is None: + msg = f"{self.__class__.__name__} has not been started yet." + raise RuntimeError(msg) + else: + return self.metadata.output_dir + + @property + def parent_dir(self) -> Optional[Path]: + return self.metadata.parent_dir + + def set_output_dir(self, path: Path): + self.metadata.output_dir = path + + @property + def metadata_path(self): + return self.output_dir / "cordage.json" + + def save_metadata(self): + md_dict = self.metadata.to_dict() + + with open(self.metadata_path, "w", encoding="utf-8") as fp: + + def invalid_obj_default(obj): + logger.warning("Cannot serialize %s", str(obj)) + + json.dump(md_dict, fp, indent=4, default=invalid_obj_default) + + @classmethod + def load_metadata(cls, path: PathLike) -> Metadata: + path = Path(path) + if not path.suffix == ".json": + path = path / "cordage.json" + + with path.open("r", encoding="utf-8") as fp: + metadata_dict = GlobalConfig._convert_old_to_new(json.load(fp)) + metadata = Metadata.from_dict(metadata_dict) + + if metadata.output_dir != path.parent: + logger.info( + f"Output dir is not correct anymore. Changing it to the actual directory" + f"({metadata.output_dir} -> {path.parent})" + ) + metadata.output_dir = path.parent + + return metadata + + +class Annotatable(MetadataStore): + TAG_PATTERN = re.compile(r"\B#(\w*[a-zA-Z]+\w*)") + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + + self.annotations = {} + + @property + def tags(self): + tags = set(self.explicit_tags) + + # implicit tags + tags.update(re.findall(self.TAG_PATTERN, self.comment)) + + return list(tags) + + @property + def explicit_tags(self): + if "tags" not in self.annotations: + self.annotations["tags"] = [] + return self.annotations["tags"] + + def add_tag(self, *tags: Iterable): + for t in tags: + if t not in self.explicit_tags: + self.explicit_tags.append(t) + + def has_tag(self, *tags: str): + return len(tags) == 0 or any(t in tags for t in self.tags) + + @property + def comment(self): + return self.annotations.get("comment", "") or "" + + @comment.setter + def comment(self, value): + self.annotations["comment"] = value + + @property + def annotations_path(self): + return self.output_dir / "annotations.json" + + def save_annotations(self): + with open(self.annotations_path, "w", encoding="utf-8") as fp: + json.dump(self.annotations, fp, indent=4) + + def load_annotations(self): + if self.annotations_path.exists(): + with self.annotations_path.open("r", encoding="utf-8") as fp: + self.annotations = json.load(fp)