From c195465cef0c46877f06c7c3bb5b9e534c42bb5f Mon Sep 17 00:00:00 2001 From: hyunp2 <42776897+hyunp2@users.noreply.github.com> Date: Wed, 18 Oct 2023 00:34:32 -0500 Subject: [PATCH] Update generator.py --- mofa/generator.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mofa/generator.py b/mofa/generator.py index ee0e801d..1adbe9aa 100644 --- a/mofa/generator.py +++ b/mofa/generator.py @@ -5,6 +5,8 @@ from model import MOFRecord from difflinker_sample import sample_from_sdf +from difflinker_train import get_args, main +import yaml def train_generator( starting_model: str | Path, @@ -20,8 +22,22 @@ def train_generator( Returns: Path to the new model weights """ - raise NotImplementedError() + args = get_args() + if args.config: + config_dict = yaml.load(args.config, Loader=yaml.FullLoader) + arg_dict = args.__dict__ + for key, value in config_dict.items(): + if isinstance(value, list) and key != 'normalize_factors': + for v in value: + arg_dict[key].append(v) + else: + arg_dict[key] = value + args.config = args.config.name + else: + config_dict = {} + + main(args=args) def run_generator( node: str='CuCu',