Skip to content

Commit

Permalink
add auto_remove and fix NodeNotAvailable error (#751)
Browse files Browse the repository at this point in the history
* add `auto_remove` and fix `NodeNotAvailable` error

* bugfix + indent
  • Loading branch information
PythonFZ authored Dec 19, 2023
1 parent 3d5b44f commit ac3938b
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 12 deletions.
4 changes: 2 additions & 2 deletions tests/integration/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_external(proj_path):
node = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
rev="890c714",
)

with zntrack.Project() as proj:
Expand All @@ -37,7 +37,7 @@ def test_external_grp(proj_path):
node = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
rev="890c714",
)

proj = zntrack.Project()
Expand Down
10 changes: 5 additions & 5 deletions tests/integration/test_from_rev.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def test_import_from_remote(proj_path):
node = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
rev="890c714",
)
assert node.max_number == 512
assert node.random_number == 123
assert node.name == "HelloWorld"
assert node.state.rev == "fbb6ada"
assert node.state.rev == "890c714"
assert node.state.remote == "https://github.com/PythonFZ/ZnTrackExamples.git"
assert node.state.results == NodeStatusResults.AVAILABLE
assert node.uuid == uuid.UUID("1d2d5eef-c42b-4ff4-aa1f-837638fdf090")
Expand All @@ -44,13 +44,13 @@ def test_connect_from_remote(proj_path):
node_a = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
rev="890c714",
)

node_b = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="35d35ff",
rev="369fe8f",
)

assert node_a.random_number == 123
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_two_nodes_connect_external(proj_path):
node_a = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
rev="890c714",
)

with zntrack.Project(automatic_node_names=True) as project:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_fs_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_patch_list(proj_path):
node = zntrack.from_rev(
"HelloWorld",
remote="https://github.com/PythonFZ/ZnTrackExamples.git",
rev="fbb6ada",
rev="890c714",
)

def func(self, path):
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,29 @@ def test_git_only_repo(proj_path, git_only_repo):
else:
# check if node-meta.json is not in the repo index
assert ("nodes/ParamsToOuts/node-meta.json", 0) not in repo.index.entries.keys()


def test_auto_remove(proj_path):
with zntrack.Project(automatic_node_names=True) as project:
n1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
n2 = zntrack.examples.ParamsToOuts(params="Dolor Sit")

project.run()

n1 = zntrack.examples.ParamsToOuts.from_rev(n1.name)
n2 = zntrack.examples.ParamsToOuts.from_rev(n2.name)
assert n1.outs == "Lorem Ipsum"
assert n2.outs == "Dolor Sit"

repo = git.Repo()
repo.git.add(".")
repo.index.commit("initial commit")

with zntrack.Project(automatic_node_names=True) as project:
n1 = zntrack.examples.ParamsToOuts(params="Hello World")

project.run(auto_remove=True)

n1 = zntrack.examples.ParamsToOuts.from_rev(n1.name)
with pytest.raises(zntrack.exceptions.NodeNotAvailableError):
n2 = zntrack.examples.ParamsToOuts.from_rev(n2.name)
6 changes: 6 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def load(self, lazy: bool = None, results: bool = True) -> None:
self._uuid = uuid.UUID(node_meta["uuid"])
self.state.results = NodeStatusResults.AVAILABLE
# TODO: documentation about _post_init and _post_load_ and when they are called

zntrack_config = json.loads(self.state.fs.read_text(config.files.zntrack))

if self.name not in zntrack_config:
raise exceptions.NodeNotAvailableError(self)

self._post_load_()

@classmethod
Expand Down
46 changes: 42 additions & 4 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from zntrack import exceptions
from zntrack.core.node import Node, get_dvc_cmd
from zntrack.utils import NodeName, config, run_dvc_cmd
from zntrack.utils.cli import get_groups

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -171,6 +172,38 @@ def group(self, *names: typing.List[str]):
if not node._external_:
node.__dict__["nwd"] = grp.nwd / node._name_.get_name_without_groups()

def auto_remove(self, remove_empty_dirs=True):
"""Remove all nodes from 'dvc.yaml' that are not in the graph."""
_, dvc_node_names = get_groups(None, None)
graph_node_names = [self.graph.nodes[x]["value"].name for x in self.graph.nodes]

nodes_to_remove = []

for node_name in dvc_node_names:
if node_name not in graph_node_names:
if "+" in node_name:
# currently there is no way to remove the zntrack.deps Nodes correctly
# so we check for the parent node, if that is not available, we remove
# the node
continue
else:
nodes_to_remove.append(node_name)

if len(nodes_to_remove):
zntrack_config = json.loads(config.files.zntrack.read_text())

for node_name in tqdm.tqdm(nodes_to_remove):
run_dvc_cmd(["remove", node_name, "--outs"])
_ = zntrack_config.pop(node_name, None)

config.files.zntrack.write_text(json.dumps(zntrack_config, indent=4))

if remove_empty_dirs:
# remove all empty directories inside "nodes"
for path in pathlib.Path("nodes").glob("**/*"):
if path.is_dir() and not any(path.iterdir()):
path.rmdir()

def run(
self,
eager=False,
Expand All @@ -179,6 +212,7 @@ def run(
save: bool = True,
environment: dict = None,
nodes: list = None,
auto_remove: bool = False,
):
"""Run the Project Graph.
Expand All @@ -200,6 +234,9 @@ def run(
A dictionary of environment variables for all nodes.
nodes : list, default = None
A list of node names to run. If None, run all nodes.
auto_remove : bool, default = False
If True, remove all nodes from 'dvc.yaml' that are not in the graph.
This is the same as calling 'project.auto_remove()'
"""
if not save and not eager:
raise ValueError("Save can only be false if eager is True")
Expand Down Expand Up @@ -258,11 +295,12 @@ def run(
self.repro()
# TODO should we load the nodes here? Maybe, if lazy loading is implemented.

def build(
self, environment: dict = None, optional: dict = None, nodes: list = None
) -> None:
if auto_remove:
self.auto_remove()

def build(self, **kwargs) -> None:
"""Build the project graph without running it."""
self.run(repro=False, environment=environment, optional=optional, nodes=nodes)
self.run(repro=False, **kwargs)

def repro(self) -> None:
"""Run dvc repro."""
Expand Down

0 comments on commit ac3938b

Please sign in to comment.