From d6f96e9793d33b916c055c927c20058f9fc3125d Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Tue, 14 May 2024 16:41:19 +0200 Subject: [PATCH] add `zntrack.apply` (#798) * add "zntrack.apply" * rename class name * update docstrings * fix eager --- tests/integration/test_apply.py | 28 ++++++++++++++++++++++++++++ zntrack/__init__.py | 2 ++ zntrack/cli/__init__.py | 18 ++++++++++++++++-- zntrack/core/node.py | 8 +++++++- zntrack/examples/__init__.py | 4 ++++ zntrack/project/zntrack_project.py | 5 ++++- zntrack/utils/apply.py | 23 +++++++++++++++++++++++ 7 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 tests/integration/test_apply.py create mode 100644 zntrack/utils/apply.py diff --git a/tests/integration/test_apply.py b/tests/integration/test_apply.py new file mode 100644 index 00000000..fc9a953e --- /dev/null +++ b/tests/integration/test_apply.py @@ -0,0 +1,28 @@ +"""Test the apply function.""" + +import pytest + +import zntrack.examples + + +@pytest.mark.parametrize("eager", [True, False]) +def test_apply(proj_path, eager) -> None: + """Test the "zntrack.apply" function.""" + project = zntrack.Project() + + JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") + + with project: + a = zntrack.examples.ParamsToOuts(params=["a", "b"]) + b = JoinedParamsToOuts(params=["a", "b"]) + c = zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"]) + + project.run(eager=eager) + + a.load() + b.load() + c.load() + + assert a.outs == ["a", "b"] + assert b.outs == "a-b" + assert c.outs == "a-b-c" diff --git a/zntrack/__init__.py b/zntrack/__init__.py index fc4c6d28..99d4c665 100644 --- a/zntrack/__init__.py +++ b/zntrack/__init__.py @@ -24,6 +24,7 @@ ) from zntrack.project import Project from zntrack.utils import config +from zntrack.utils.apply import apply from zntrack.utils.node_wd import nwd __version__ = importlib.metadata.version("zntrack") @@ -45,6 +46,7 @@ "exceptions", "from_rev", "get_nodes", + "apply", ] __all__ += [ diff --git a/zntrack/cli/__init__.py b/zntrack/cli/__init__.py index f5e8b219..fe62aeb1 100644 --- a/zntrack/cli/__init__.py +++ b/zntrack/cli/__init__.py @@ -47,10 +47,23 @@ def main( @app.command() -def run(node: str, name: str = None, meta_only: bool = False) -> None: +def run( + node: str, name: str = None, meta_only: bool = False, method: str = "run" +) -> None: """Execute a ZnTrack Node. Use as 'zntrack run module.Node --name node_name'. + + Arguments: + --------- + node : str + The node to run. + name : str + The name of the node. + meta_only : bool + Save only the metadata. + method : str, default 'run' + The method to run on the node. """ env_file = pathlib.Path("env.yaml") if env_file.exists(): @@ -80,7 +93,8 @@ def run(node: str, name: str = None, meta_only: bool = False) -> None: node: Node = cls.from_rev(name=name, results=False) node.save(meta_only=True) if not meta_only: - node.run() + # dynamic version of node.run() + getattr(node, method)() node.save(parameter=False) else: raise ValueError(f"Node {node} is not a ZnTrack Node.") diff --git a/zntrack/core/node.py b/zntrack/core/node.py index e05a90fc..4b8ff292 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -443,7 +443,13 @@ def get_dvc_cmd( cmd += ["--outs", f"{(get_nwd(node) /'node-meta.json').as_posix()}"] module = module_handler(node.__class__) - cmd += [f"zntrack run {module}.{node.__class__.__name__} --name {node.name}"] + + zntrack_run = f"zntrack run {module}.{node.__class__.__name__} --name {node.name}" + if hasattr(node, "_method"): + zntrack_run += f" --method {node._method}" + + cmd += [zntrack_run] + optionals = [x for x in optionals if x] # remove empty entries [] return [cmd] + optionals diff --git a/zntrack/examples/__init__.py b/zntrack/examples/__init__.py index d5b848c1..aab5b42e 100644 --- a/zntrack/examples/__init__.py +++ b/zntrack/examples/__init__.py @@ -23,6 +23,10 @@ def run(self) -> None: """Save params to outs.""" self.outs = self.params + def join(self) -> None: + """Join the results.""" + self.outs = "-".join(self.params) + class ParamsToMetrics(zntrack.Node): """Save params to metrics.""" diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index bfd824d8..e1338583 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -287,7 +287,10 @@ def run( # update connectors log.info(f"Running node {node}") self.graph._update_node_attributes(node, UpdateConnectors()) - node.run() + if hasattr(node, "_method"): + getattr(node, node._method)() + else: + node.run() if save: node.save() node.state.loaded = True diff --git a/zntrack/utils/apply.py b/zntrack/utils/apply.py new file mode 100644 index 00000000..ebb91110 --- /dev/null +++ b/zntrack/utils/apply.py @@ -0,0 +1,23 @@ +"""Zntrack apply module for custom "run" methods.""" + +import typing as t + +o = t.TypeVar("o") + + +def apply(obj: o, method: str) -> o: + """Return a new object like "o" which has the method string attached.""" + + class MockInheritanceClass(obj): + """Copy of the original class with the new method attribute. + + We can not set the method directly on the original class, because + it would be used by all the other instances of the class as well. + """ + + _method = method + + MockInheritanceClass.__module__ = obj.__module__ + MockInheritanceClass.__name__ = obj.__name__ + + return MockInheritanceClass