diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 12995dd4..952da8ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-added-large-files - id: check-case-conflict @@ -21,7 +21,7 @@ repos: - id: sort-simple-yaml - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.1.1 + rev: 24.4.2 hooks: - id: black additional_dependencies: [".[jupyter]"] @@ -31,13 +31,13 @@ repos: hooks: - id: isort - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell additional_dependencies: ["tomli"] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.1.14' + rev: 'v0.4.8' hooks: - id: ruff args: ['--fix'] diff --git a/tests/integration/function_wrapper/test_single_function.py b/tests/integration/function_wrapper/test_single_function.py index a473fc5c..b70138ad 100644 --- a/tests/integration/function_wrapper/test_single_function.py +++ b/tests/integration/function_wrapper/test_single_function.py @@ -24,18 +24,20 @@ def test_example_func(proj_path): def test_example_func_dry_run(proj_path): script = example_func(dry_run=True) - assert " ".join(script) == " ".join([ - "stage", - "add", - "-n", - "example_func", - "--force", - "--params", - "params.yaml:example_func", - "--outs", - "test.txt", - "zntrack run test_single_function.example_func", - ]) + assert " ".join(script) == " ".join( + [ + "stage", + "add", + "-n", + "example_func", + "--force", + "--params", + "params.yaml:example_func", + "--outs", + "test.txt", + "zntrack run test_single_function.example_func", + ] + ) @zntrack.nodify(outs=[pathlib.Path("test.txt")], params={"text": "Lorem Ipsum"}) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index 439a7f63..5041dda8 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -141,16 +141,18 @@ def test_list_groups(proj_path, runner): "ParamsToOuts", "ParamsToOuts_1", ], - "nested": [{ - "GRP1": [ - "ParamsToOuts -> nested_GRP1_ParamsToOuts", - "ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1", - ], - "GRP2": [ - "ParamsToOuts -> nested_GRP2_ParamsToOuts", - "ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1", - ], - }], + "nested": [ + { + "GRP1": [ + "ParamsToOuts -> nested_GRP1_ParamsToOuts", + "ParamsToOuts_1 -> nested_GRP1_ParamsToOuts_1", + ], + "GRP2": [ + "ParamsToOuts -> nested_GRP2_ParamsToOuts", + "ParamsToOuts_1 -> nested_GRP2_ParamsToOuts_1", + ], + } + ], } groups, _ = utils.cli.get_groups(remote=proj_path, rev=None) diff --git a/zntrack/cli/__init__.py b/zntrack/cli/__init__.py index fe62aeb1..9693620a 100644 --- a/zntrack/cli/__init__.py +++ b/zntrack/cli/__init__.py @@ -64,6 +64,7 @@ def run( Save only the metadata. method : str, default 'run' The method to run on the node. + """ env_file = pathlib.Path("env.yaml") if env_file.exists(): diff --git a/zntrack/core/load.py b/zntrack/core/load.py index 56697a1e..1938e496 100644 --- a/zntrack/core/load.py +++ b/zntrack/core/load.py @@ -58,6 +58,7 @@ def _import_from_tempfile(package_and_module: str, remote, rev): If the module could not be found. FileNotFoundError If the file could not be found. + """ file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py") fs = dvc.api.DVCFileSystem(url=remote, rev=rev) @@ -93,6 +94,7 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T: ------- Node The loaded node. + """ if isinstance(name, Node): name = name.name diff --git a/zntrack/core/node.py b/zntrack/core/node.py index 4b8ff292..bb875e9b 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -59,6 +59,7 @@ class NodeStatus: The temporary path used for loading the data. This is only set within the context manager 'use_tmp_path'. If neither 'remote' nor 'rev' are set, tmp_path will not be used. + """ loaded: bool @@ -202,6 +203,7 @@ class Node(zninit.ZnInit, znflow.Node): information about the state of the Node. nwd : pathlib.Path the node working directory. + """ _state: NodeStatus = None @@ -235,6 +237,7 @@ def convert_notebook(cls, nb_name: str = None): ---------- nb_name: str Notebook name when not using config.nb_name (this is not recommended) + """ # TODO this should not be a class method, but a function. jupyter_class_to_file(nb_name=nb_name, module_name=cls.__name__) @@ -324,6 +327,7 @@ def load(self, lazy: bool = None, results: bool = True) -> None: Whether to load the node lazily. If None, the value from the config is used. results : bool, default = True Whether to load the results. If False, only the parameters are loaded. + """ from zntrack.fields.field import Field, FieldGroup diff --git a/zntrack/core/nodify.py b/zntrack/core/nodify.py index 201869a2..c2a1c26f 100644 --- a/zntrack/core/nodify.py +++ b/zntrack/core/nodify.py @@ -35,6 +35,7 @@ class DVCRunOptions: References ---------- https://dvc.org/doc/command-reference/run#options. + """ no_commit: bool @@ -51,6 +52,7 @@ def dvc_args(self) -> list: ------- list: A list of strings for the subprocess call, e.g.: ["--no-commit", "--external"]. + """ out = [] for datacls_field in dataclasses.fields(self): @@ -97,6 +99,7 @@ def prepare_dvc_script( ------- list[str] The list to be passed to the subprocess call. + """ script = ["stage", "add", "-n", node_name] script += dvc_run_option.dvc_args @@ -134,6 +137,7 @@ def check_type( accept None even if not in types. allow_dict: allow for {key: types} + """ if isinstance(obj, (list, tuple, set)) and allow_iterable: for value in obj: @@ -254,6 +258,7 @@ def save_node_config_to_files(cfg: NodeConfig, node_name: str): The NodeConfig object which should be serialized to zntrack.json / params.yaml node_name: str The name of the node, usually func.__name__. + """ for value_name, value in dataclasses.asdict(cfg).items(): if value_name == "params": @@ -339,6 +344,7 @@ def nodify( References ---------- https://dvc.org/doc/command-reference/run#options + """ cfg_ = NodeConfig( outs=outs, diff --git a/zntrack/exceptions/__init__.py b/zntrack/exceptions/__init__.py index ac06962f..e82d5880 100644 --- a/zntrack/exceptions/__init__.py +++ b/zntrack/exceptions/__init__.py @@ -11,6 +11,7 @@ def __init__(self, arg): ---------- arg : str|Node Custom Error message or Node that is not available. + """ if isinstance(arg, str): super().__init__(arg) @@ -33,6 +34,7 @@ def __init__(self, node, field, instance): The 'zn.nodes' field instance : Node The node that contains the 'zn.nodes' field + """ msg = ( f"Can not set '{field.name}' of Node<'{instance.name}'> to" @@ -59,6 +61,7 @@ def __init__(self, node): ---------- node: Node The node that is already on the graph. + """ msg = ( f"Node name '{node.name}' is already used in the graph. Please use" diff --git a/zntrack/fields/dependency.py b/zntrack/fields/dependency.py index 1d49cbb6..2fdf2a1f 100644 --- a/zntrack/fields/dependency.py +++ b/zntrack/fields/dependency.py @@ -96,6 +96,7 @@ def _get_nodes_on_off_graph(self, instance) -> t.Tuple[list, list]: The nodes that are on the graph. off_graph : list The nodes that are off the graph. + """ values = getattr(instance, self.name) # TODO use IterableHandler? diff --git a/zntrack/fields/dvc/options.py b/zntrack/fields/dvc/options.py index 3cb93f47..31515848 100644 --- a/zntrack/fields/dvc/options.py +++ b/zntrack/fields/dvc/options.py @@ -111,6 +111,7 @@ def get_data(self, instance: "Node") -> any: ------- any The value of the field from the configuration file. + """ zntrack_dict = json.loads( instance.state.fs.read_text("zntrack.json"), diff --git a/zntrack/fields/field.py b/zntrack/fields/field.py index 3650a895..ae305a54 100644 --- a/zntrack/fields/field.py +++ b/zntrack/fields/field.py @@ -36,6 +36,7 @@ class Field(zninit.Descriptor, abc.ABC): ---------- dvc_option : str The dvc command option for this field. + """ dvc_option: str = None @@ -49,6 +50,7 @@ def save(self, instance: "Node"): ---------- instance : Node The Node instance to save the field for. + """ raise NotImplementedError @@ -70,6 +72,7 @@ def get_files(self, instance: "Node") -> list: ------- list The affected files. + """ raise NotImplementedError @@ -83,6 +86,7 @@ def load(self, instance: "Node", lazy: bool = None): lazy : bool, optional Whether to load the field lazily. This only applies to 'LazyField' classes. + """ try: instance.__dict__[self.name] = self.get_data(instance) @@ -103,6 +107,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]: ------- typing.List[tuple] The stage add argument for this field. + """ return [ (f"--{self.dvc_option}", pathlib.Path(x).as_posix()) @@ -127,6 +132,7 @@ def get_optional_dvc_cmd( ------- typing.List[str] The optional dvc commands. + """ return [] @@ -173,6 +179,7 @@ def get_value_except_lazy(self, instance): ------ DataIsLazyError If the value is lazy. + """ with contextlib.suppress(KeyError): if instance.__dict__[self.name] is LazyOption: @@ -198,6 +205,7 @@ def load(self, instance: "Node", lazy: bool = None): The Node instance to load the field for. lazy : bool, optional Whether to load the field lazily, by default 'zntrack.config.lazy'. + """ if lazy in {None, True} and config.lazy: instance.__dict__[self.name] = LazyOption @@ -226,6 +234,7 @@ def __init__( ---------- use_global_plots : bool Save the plots config not in 'stages' but in 'plots' in the dvc.yaml file. + """ super().__init__(*args, **kwargs) self.plots_options = {} diff --git a/zntrack/fields/fields.py b/zntrack/fields/fields.py index ab235a80..b7096bea 100644 --- a/zntrack/fields/fields.py +++ b/zntrack/fields/fields.py @@ -17,6 +17,7 @@ def outs(): The object is serialized and deserialized by ZnTrack and stored in the node working directory. see https://dvc.org/doc/command-reference/stage/add#-o + """ return Output(dvc_option="outs", use_repr=False) @@ -49,6 +50,7 @@ def params(*args, **kwargs): see https://dvc.org/doc/command-reference/stage/add#-p kwargs: dict Additional keyword arguments. + """ return Params(*args, **kwargs) @@ -63,6 +65,7 @@ def deps(*data): This can either be a Node or an attribute of a Node. It can not be an object that is not part of the Node graph. see https://dvc.org/doc/command-reference/stage/add#-d + """ return Dependency(*data) @@ -132,6 +135,7 @@ def params_path(*args, **kwargs): see https://dvc.org/doc/command-reference/stage/add#-p kwargs: dict Additional keyword arguments. + """ return DVCOption(*args, dvc_option="params", **kwargs) @@ -163,5 +167,6 @@ def plots_path(*args, dvc_option="plots", **kwargs): The DVC option to use for this field. kwargs: dict Additional keyword arguments that are used for plotting. + """ return PlotsOption(*args, dvc_option=dvc_option, **kwargs) diff --git a/zntrack/fields/zn/options.py b/zntrack/fields/zn/options.py index edb9dcc5..d0aaf9a5 100644 --- a/zntrack/fields/zn/options.py +++ b/zntrack/fields/zn/options.py @@ -103,6 +103,7 @@ class Params(Field): ---------- dvc_option: str The DVC option to use. Default is "params". + """ dvc_option: str = "params" @@ -115,6 +116,7 @@ def get_files(self, instance: "Node") -> list: ------- list A list of file paths. + """ return [config.files.params] @@ -125,6 +127,7 @@ def save(self, instance: "Node"): ---------- instance : Node The node instance associated with this field. + """ file = self.get_files(instance)[0] @@ -161,6 +164,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]: ------- list A list of tuples containing the DVC option and the file path. + """ file = self.get_files(instance)[0] return [(f"--{self.dvc_option}", f"{file}:{instance.name}")] @@ -180,6 +184,7 @@ def __init__(self, dvc_option: str, **kwargs): The DVC option used to specify the output file. **kwargs Additional arguments to pass to the parent constructor. + """ self.dvc_option = dvc_option super().__init__(**kwargs) @@ -196,6 +201,7 @@ def get_files(self, instance) -> list: ------- list A list containing the path of the file. + """ return [get_nwd(instance) / f"{self.name}.json"] @@ -206,6 +212,7 @@ def save(self, instance: "Node"): ---------- instance : Node The node instance. + """ try: value = self.get_value_except_lazy(instance) @@ -236,6 +243,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]: ------- list A list containing the DVC command for this field. + """ file = self.get_files(instance)[0] return [(f"--{self.dvc_option}", file.as_posix())] diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index e1338583..1b6bc97b 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -87,6 +87,7 @@ class Project: node as the node name. E.g. `node = Node()` will result in a node name of 'node'. If used within a group, the group name will be added to the node name. E.g. `group.name = Grp1` and `model = Node()` will result in a name of 'Grp1_model'. + """ graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False) @@ -111,6 +112,7 @@ def __post_init__(self): remove_existing_graph : bool, default = False If True, remove 'dvc.yaml', 'zntrack.json' and 'params.yaml' before writing new nodes. + """ self.graph.project = self if self.initialize: @@ -157,6 +159,7 @@ def group(self, *names: typing.List[str]): The name of the group. If None, the group will be named 'GroupX' where X is the number of groups + 1. If more than one name is given, the groups will be nested to 'nwd = name[0]/name[1]/.../name[-1]' + """ if not names: name = "Group1" @@ -248,6 +251,7 @@ def run( auto_remove : bool, default = False If True, remove all nodes from 'dvc.yaml' that are not in the graph. This is the same as calling 'project.auto_remove()' + """ if not save and not eager: raise ValueError("Save can only be false if eager is True") diff --git a/zntrack/tools/__init__.py b/zntrack/tools/__init__.py index c0ed6439..6e5ee0ce 100644 --- a/zntrack/tools/__init__.py +++ b/zntrack/tools/__init__.py @@ -21,6 +21,7 @@ def timeit(field: str): field : str The field to store the time in. The value is stored as {func_name: time} or {func_name: [time1, time2, ...]} + """ def decorator(func): diff --git a/zntrack/utils/__init__.py b/zntrack/utils/__init__.py index 4bb823be..6ca24687 100644 --- a/zntrack/utils/__init__.py +++ b/zntrack/utils/__init__.py @@ -39,6 +39,7 @@ def __init__(self) -> None: ------ NotImplementedError: This class is not meant to be instantiated. + """ raise NotImplementedError("This class is not meant to be instantiated.") @@ -60,6 +61,7 @@ def module_handler(obj) -> str: ---------- obj: Any object that implements __module__ + """ if config.nb_name: try: @@ -111,6 +113,7 @@ def run_dvc_cmd(script, stdout=None): ------ DVCProcessError: if the dvc cli command fails. + """ dvc_short_string = " ".join(script[:5]) if len(script) > 5: @@ -177,6 +180,7 @@ class NodeStatusResults(enum.Enum): the Node instance has failed to run. AVAILABLE : int the Node instance was loaded and results are available. + """ UNKNOWN = 0 @@ -202,6 +206,7 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory: ------- temp_dir: The temporary directory file. Close with temp_dir.cleanup() at the end. + """ temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with # add ignore_cleanup_errors=True in Py3.10? diff --git a/zntrack/utils/cli.py b/zntrack/utils/cli.py index 76265859..e2692af6 100644 --- a/zntrack/utils/cli.py +++ b/zntrack/utils/cli.py @@ -40,6 +40,7 @@ def check_empty(self): Raises ------ typer.Exit: if the directory is not empty and force is false + """ is_empty = not any(pathlib.Path(".").iterdir()) if not is_empty and not self.force: @@ -92,6 +93,7 @@ def get_groups(remote, rev) -> (dict, list): values. Contains "short-name -> long-name" if inside a group. node_names: list A list of all node names in the project. + """ fs = DVCFileSystem(url=remote, rev=rev) with fs.open("zntrack.json") as f: diff --git a/zntrack/utils/config.py b/zntrack/utils/config.py index 4ceae6cd..10beb896 100644 --- a/zntrack/utils/config.py +++ b/zntrack/utils/config.py @@ -15,6 +15,7 @@ class Files: Notes ----- Currently frozen because changing the value is not tested. + """ zntrack: Path = Path("zntrack.json") @@ -55,6 +56,7 @@ class Config: Use the `dvc.cli.main` function instead of subprocess disable_operating_directory: bool, default = False Global config to disable operating directory context manager. + """ nb_name: str = None @@ -86,6 +88,7 @@ def updated_config(self, **kwargs) -> None: Yields ------ Environment with temporarily changed config. + """ state = {} for key, value in kwargs.items(): diff --git a/zntrack/utils/file_io.py b/zntrack/utils/file_io.py index 9fb89c5a..011476ea 100644 --- a/zntrack/utils/file_io.py +++ b/zntrack/utils/file_io.py @@ -28,6 +28,7 @@ def read_file(file: pathlib.Path) -> dict: ------- dict: Content of the json/yaml file + """ if file.suffix in [".yaml", ".yml"]: file_content = yaml.safe_load(file.read_text()) @@ -51,6 +52,7 @@ def write_file(file: pathlib.Path, value: dict, mkdir: bool = True): Any serializable data to save mkdir: bool Create a parent directory if necessary + """ if mkdir: file.parent.mkdir(exist_ok=True, parents=True) @@ -72,6 +74,7 @@ def clear_config_file(file: typing.Union[pathlib.Path, str], node_name: str): The file to read from, e.g. params.yaml / zntrack.json node_name: str The name of the Node + """ file = pathlib.Path(file) try: @@ -107,6 +110,7 @@ def update_config_file( be {node_name: value}. value: The value to write to the file + """ # Read file if node_name is None and value_name is None: