From 4847df2275734fc414ef93e3f028305266c68b31 Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 3 Nov 2023 02:54:04 -0700 Subject: [PATCH] autorunner params from config (#7175) allows setting AutoRunner params from config allows specifying number of folds in config --------- Signed-off-by: myron --- monai/apps/auto3dseg/auto_runner.py | 79 +++++++++++++++++++---------- tests/test_vis_gradcam.py | 3 +- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 80ae34180e..e4c2d908b7 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -214,22 +214,11 @@ def __init__( mlflow_tracking_uri: str | None = None, **kwargs: Any, ): - logger.info(f"AutoRunner using work directory {work_dir}") - os.makedirs(work_dir, exist_ok=True) - - self.work_dir = os.path.abspath(work_dir) - self.data_src_cfg = dict() - self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml") - self.algos = algos - self.templates_path_or_url = templates_path_or_url - self.allow_skip = allow_skip - self.mlflow_tracking_uri = mlflow_tracking_uri - self.kwargs = deepcopy(kwargs) - - if input is None and os.path.isfile(self.data_src_cfg_name): - input = self.data_src_cfg_name + if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")): + input = os.path.join(os.path.abspath(work_dir), "input.yaml") logger.info(f"Input config is not provided, using the default {input}") + self.data_src_cfg = dict() if isinstance(input, dict): self.data_src_cfg = input elif isinstance(input, str) and os.path.isfile(input): @@ -238,6 +227,51 @@ def __init__( else: raise ValueError(f"{input} is not a valid file or dict") + if "work_dir" in self.data_src_cfg: # override from config + work_dir = self.data_src_cfg["work_dir"] + self.work_dir = os.path.abspath(work_dir) + + logger.info(f"AutoRunner using work directory {self.work_dir}") + os.makedirs(self.work_dir, exist_ok=True) + self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml") + + self.algos = algos + self.templates_path_or_url = templates_path_or_url + self.allow_skip = allow_skip + + # cache.yaml + self.not_use_cache = not_use_cache + self.cache_filename = os.path.join(self.work_dir, "cache.yaml") + self.cache = self.read_cache() + self.export_cache() + + # determine if we need to analyze, algo_gen or train from cache, unless manually provided + self.analyze = not self.cache["analyze"] if analyze is None else analyze + self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen + self.train = train + self.ensemble = ensemble # last step, no need to check + self.hpo = hpo and has_nni + self.hpo_backend = hpo_backend + self.mlflow_tracking_uri = mlflow_tracking_uri + self.kwargs = deepcopy(kwargs) + + # parse input config for AutoRunner param overrides + for param in [ + "analyze", + "algo_gen", + "train", + "hpo", + "ensemble", + "not_use_cache", + "allow_skip", + ]: # override from config + if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool): + setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"] + + for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config + if param in self.data_src_cfg: + setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"] + missing_keys = {"dataroot", "datalist", "modality"}.difference(self.data_src_cfg.keys()) if len(missing_keys) > 0: raise ValueError(f"Config keys are missing {missing_keys}") @@ -256,6 +290,8 @@ def __init__( # inspect and update folds num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) + if "num_fold" in self.data_src_cfg: + num_fold = int(self.data_src_cfg["num_fold"]) # override from config self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input ConfigParser.export_config_file( @@ -266,17 +302,6 @@ def __init__( self.datastats_filename = os.path.join(self.work_dir, "datastats.yaml") self.datalist_filename = datalist_filename - self.not_use_cache = not_use_cache - self.cache_filename = os.path.join(self.work_dir, "cache.yaml") - self.cache = self.read_cache() - self.export_cache() - - # determine if we need to analyze, algo_gen or train from cache, unless manually provided - self.analyze = not self.cache["analyze"] if analyze is None else analyze - self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen - self.train = train - self.ensemble = ensemble # last step, no need to check - self.set_training_params() self.set_device_info() self.set_prediction_params() @@ -288,9 +313,9 @@ def __init__( self.gpu_customization_specs: dict[str, Any] = {} # hpo - if hpo_backend.lower() != "nni": + if self.hpo_backend.lower() != "nni": raise NotImplementedError("HPOGen backend only supports NNI") - self.hpo = hpo and has_nni + self.hpo = self.hpo and has_nni self.set_hpo_params() self.search_space: dict[str, dict[str, Any]] = {} self.hpo_tasks = 0 diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index f5ba188082..4b554de0aa 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -20,7 +20,7 @@ from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 from monai.visualize import GradCAM, GradCAMpp -from tests.utils import assert_allclose +from tests.utils import assert_allclose, skip_if_quick class DenseNetAdjoint(DenseNet121): @@ -147,6 +147,7 @@ def __call__(self, x, adjoint_info): TESTS_ILL.append([cam]) +@skip_if_quick class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand(TESTS) def test_shape(self, cam_class, input_data, expected_shape):