Skip to content

Commit

Permalink
Merge pull request #7 from plonerma/improved_field_handling
Browse files Browse the repository at this point in the history
Improved non-init field handling
  • Loading branch information
plonerma authored Nov 18, 2024
2 parents ebb944e + 7703101 commit 577fd4c
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 123 deletions.
9 changes: 3 additions & 6 deletions src/cordage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from cordage.experiment import Experiment, Series, Trial
from cordage.global_config import GlobalConfig
from cordage.util import ConfigClass, logger, nest_items, nested_update, read_dict_from_file
from cordage.util import from_dict as config_from_dict


class MissingType:
Expand Down Expand Up @@ -225,7 +224,7 @@ def construct_argument_parser(self):
)

self.argument_parser.add_argument(
"--output-dir",
"--output_dir",
type=Path,
help="Path to use as the output directory.",
default=MISSING,
Expand Down Expand Up @@ -442,10 +441,8 @@ def parse_args(self, args: Optional[List[str]] = None) -> Experiment:
# The series skip may be given via the command line
# ("--series-skip <n>") or via the "__series-skip__" key in a config file.
series_kw["series_skip"] = argument_data.pop(self.global_config._series_skip_key, None)

series_kw["base_config"] = config_from_dict(
self.main_config_cls, argument_data, strict=self.global_config.strict_mode
)
series_kw["base_config"] = argument_data
series_kw["config_cls"] = self.main_config_cls

series: Series = Series(**series_kw)

Expand Down
134 changes: 60 additions & 74 deletions src/cordage/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -58,22 +57,13 @@
ConfigClass = TypeVar("ConfigClass", bound="DataclassInstance")


class SeriesConfiguration(TypedDict):
base_config: "DataclassInstance"
series_spec: Union[List[Dict], Dict[str, List], None]
series_skip: Optional[int]


Configuration = TypeVar("Configuration", SeriesConfiguration, "DataclassInstance")


@dataclass
class Metadata(Generic[Configuration]):
class Metadata:
function: str

global_config: GlobalConfig

configuration: Union[Configuration, Dict[str, Any]]
configuration: Dict[str, Any]

output_dir: Optional[Path] = None
status: str = "pending"
Expand Down Expand Up @@ -108,12 +98,12 @@ def from_dict(cls, data: Mapping):
return from_dict(cls, data)


class MetadataStore(Generic[Configuration]):
class MetadataStore:
_warned_deprecated_nested_global_config: bool = False

def __init__(
self,
metadata: Optional[Metadata[Configuration]] = None,
metadata: Optional[Metadata] = None,
/,
global_config: Optional[GlobalConfig] = None,
**kw,
Expand Down Expand Up @@ -244,7 +234,7 @@ def load_metadata(cls, path: PathLike) -> Metadata:
return metadata


class Annotatable(MetadataStore[Configuration]):
class Annotatable(MetadataStore):
TAG_PATTERN = re.compile(r"\B#(\w*[a-zA-Z]+\w*)")

def __init__(self, *args, **kw):
Expand Down Expand Up @@ -297,10 +287,11 @@ def load_annotations(self):
self.annotations = json.load(fp)


class Experiment(Annotatable[Configuration]):
def __init__(self, *args, **kwargs):
class Experiment(Annotatable):
def __init__(self, *args, config_cls: Optional[Type] = None, **kwargs):
super().__init__(*args, **kwargs)

self.config_cls = config_cls
self.log_handlers: List[logging.Handler] = []

def __repr__(self):
Expand Down Expand Up @@ -329,6 +320,7 @@ def start(self):
Set start time, create output directory, registers run, etc.
"""
assert self.config_cls is not None
assert self.status == "pending", f"{self.__class__.__name__} has already been started."
self.metadata.start_time = datetime.now(timezone.utc).astimezone()
self.metadata.status = "running"
Expand Down Expand Up @@ -404,18 +396,10 @@ def from_path(cls, path: PathLike, config_cls: Optional[Type[ConfigClass]] = Non

experiment: Experiment
if not metadata.is_series:
if config_cls is not None:
metadata.configuration = from_dict(config_cls, metadata.configuration)

experiment = Trial(metadata)
experiment = Trial(metadata, config_cls=config_cls)

else:
if config_cls is not None:
metadata.configuration["base_config"] = from_dict(
config_cls, metadata.configuration["base_config"]
)

experiment = Series(metadata)
experiment = Series(metadata, config_cls=config_cls)

experiment.load_annotations()

Expand Down Expand Up @@ -500,58 +484,79 @@ def teardown_log(self):
logger.removeHandler(handler)


class Trial(Experiment["DataclassInstance"], Generic[ConfigClass]):
class Trial(Experiment, Generic[ConfigClass]):
def __init__(
self,
metadata: Optional[Metadata["DataclassInstance"]] = None,
metadata: Optional[Metadata] = None,
/,
config: Optional[ConfigClass] = None,
config: Optional[Dict[str, Any]] = None,
config_cls=None,
**kw,
):
if metadata is not None:
if len(kw) == 0 and config is None:
super().__init__(metadata)
super().__init__(metadata, config_cls=config_cls)
else:
msg = "If metadata are provided, config and additional keywords can not be set."
raise TypeError(msg)
else:
super().__init__(configuration=config, **kw)
super().__init__(configuration=config, config_cls=config_cls, **kw)

self._config: Optional[ConfigClass] = None

@property
def config(self) -> ConfigClass:
if isinstance(self.metadata.configuration, dict):
msg = (
"`trial.config` is only available if the configuration was loaded with a "
"configuration dataclass (you could use `trial.metadata.configuration` instead)."
if self._config is None:
if self.config_cls is None:
msg = (
"`trial.config` is only available if the configuration was loaded with a "
"configuration dataclass. You could use `trial.metadata.configuration` "
"instead or pass `config_cls` to the trial initializer."
)
raise AttributeError(msg)

if self.metadata.output_dir is not None:
self.set_output_dir(self.metadata.output_dir)

# Create the config object
self._config = from_dict(
self.config_cls,
self.metadata.configuration,
strict=self.metadata.global_config.strict_mode,
)
raise AttributeError(msg)
return self.metadata.configuration

return self._config

def set_output_dir(self, path: Path):
super().set_output_dir(path)

output_dir_type = config_output_dir_type(
self.config, self.global_config.param_name_output_dir
)
if output_dir_type is not None:
self.config.output_dir = output_dir_type(path) # type: ignore
if self.config_cls is not None:
output_dir_type = config_output_dir_type(
self.config_cls, self.global_config.param_name_output_dir
)

if output_dir_type is not None:
self.metadata.configuration["output_dir"] = path

logging.info("Config: %s", self.metadata.configuration)
# Config has attribute output_dir, mypy does not know it
if self._config is not None:
self.config.output_dir = output_dir_type(path) # type: ignore


class Series(Generic[ConfigClass], Experiment[SeriesConfiguration]):
class Series(Generic[ConfigClass], Experiment):
def __init__(
self,
metadata: Optional[Metadata] = None,
/,
base_config: Optional[ConfigClass] = None,
base_config: Optional[Dict[str, Any]] = None,
series_spec: Union[List[Dict], Dict[str, List], None] = None,
series_skip: Optional[int] = None,
config_cls=None,
**kw,
):
if metadata is not None:
assert len(kw) == 0 and base_config is None and series_spec is None
super().__init__(metadata)
super().__init__(metadata, config_cls=config_cls)
else:
if isinstance(series_spec, list):
series_spec = [
Expand All @@ -565,6 +570,7 @@ def __init__(
"series_spec": series_spec,
"series_skip": series_skip,
},
config_cls=config_cls,
**kw,
)

Expand All @@ -573,14 +579,6 @@ def __init__(
self.trials: Optional[List[Trial[ConfigClass]]] = None
self.make_all_trials()

def set_output_dir(self, path: Path):
super().set_output_dir(path)
output_dir_type = config_output_dir_type(
self.base_config, self.global_config.param_name_output_dir
)
if output_dir_type is not None:
self.base_config.output_dir = output_dir_type(path) # type: ignore

def validate_series_spec(self):
series_spec = self.series_spec

Expand All @@ -604,7 +602,7 @@ def only_list_nodes(d):
assert series_spec is None

@property
def base_config(self) -> ConfigClass:
def base_config(self) -> Dict[str, Any]:
return self.metadata.configuration["base_config"]

@property
Expand Down Expand Up @@ -704,7 +702,7 @@ def make_trial(self, **kw):
assert isinstance(additional_info, dict)
trial_metadata.additional_info.update(additional_info)

return Trial(trial_metadata)
return Trial(trial_metadata, config_cls=self.config_cls)

def make_all_trials(self):
if self.series_spec is None:
Expand All @@ -726,31 +724,19 @@ def make_all_trials(self):
self.trials = []

for i, trial_update in enumerate(self.get_trial_updates()):
trial_config_data: Dict[str, Any]

if isinstance(self.base_config, dict):
trial_config_data = deepcopy(self.base_config)
else:
trial_config_data = to_dict(self.base_config)
trial_configuration: Dict[str, Any] = deepcopy(self.base_config)

logger.debug("Base configuration: %s", str(trial_config_data))
logger.debug("Trial update: %s", str(trial_update))

nested_update(trial_config_data, trial_update)

trial_config: ConfigClass
if isinstance(self.base_config, dict):
trial_config = trial_config_data
else:
trial_config = from_dict(type(self.base_config), trial_config_data)
nested_update(trial_configuration, trial_update)

if i < self.series_skip:
status = "skipped"
else:
status = "pending"

trial = self.make_trial(
configuration=trial_config, additional_info={"trial_index": i}, status=status
configuration=trial_configuration,
additional_info={"trial_index": i},
status=status,
)
self.trials.append(trial)

Expand Down
2 changes: 1 addition & 1 deletion src/cordage/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def to_file(dataclass_instance, path: PathLike):


def config_output_dir_type(
config_cls: Any, param_name_output_dir
config_cls: Type["DataclassInstance"], param_name_output_dir: str
) -> Union[Type[str], Type[Path], None]:
for field in dataclasses.fields(config_cls):
if field.name == param_name_output_dir:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def func(config: SimpleConfig):

experiment = cordage.run(
func,
args=["--output-dir", str(tmp_path / "some_specific_output_dir")],
args=["--output_dir", str(tmp_path / "some_specific_output_dir")],
global_config=global_config,
)

Expand All @@ -31,7 +31,7 @@ def func(config: SimpleConfig, cordage_trial, output_dir): # noqa: ARG001
cordage.run(
func,
args=[
"--output-dir",
"--output_dir",
str(tmp_path / "some_specific_output_dir"),
str(resources_path / "series_simple.yaml"),
],
Expand Down
21 changes: 21 additions & 0 deletions tests/test_field_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,24 @@ def func(config: NonInitConfig):
# Invoking the command with b should make the parser exit
with pytest.raises(SystemExit):
cordage.run(func, args=["--a", "2", "--b", "3"], global_config=global_config)


def test_non_init_field_series(global_config, resources_path):
    """Run a series whose config dataclass has a non-init field.

    The field ``b`` is excluded from ``__init__`` and derived from ``a`` in
    ``__post_init__``; every trial of the series must see a consistent pair.
    """

    @dataclass
    class NonInitConfig:
        a: int = -1
        # Not an init argument: it is filled in by __post_init__ below.
        b: float = field(init=False)

        def __post_init__(self):
            self.b = float(self.a)

    def func(config: NonInitConfig):
        # b must have been derived from the a value of this very trial.
        assert config.a == int(config.b)

    series_file = resources_path / "series_simple.yaml"

    cordage.run(func, args=[str(series_file)], global_config=global_config)
2 changes: 1 addition & 1 deletion tests/test_misc_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def func(config: SimpleConfig):

assert isinstance(trial, Trial)

series = context.from_configuration(base_config=SimpleConfig(), series_spec={"a": [1, 2, 3]})
series = context.from_configuration(base_config={}, series_spec={"a": [1, 2, 3]})

assert isinstance(series, Series)

Expand Down
Loading

0 comments on commit 577fd4c

Please sign in to comment.