From f90d955f92925704a287032c849062932964a4fc Mon Sep 17 00:00:00 2001 From: hyunp2 <42776897+hyunp2@users.noreply.github.com> Date: Wed, 18 Oct 2023 00:46:31 -0500 Subject: [PATCH] Update generator.py --- mofa/generator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mofa/generator.py b/mofa/generator.py index 8bedbd08..ec97d064 100644 --- a/mofa/generator.py +++ b/mofa/generator.py @@ -7,6 +7,7 @@ from difflinker_sample import sample_from_sdf from difflinker_train import get_args, main import yaml +import os def train_generator( # starting_model: str | Path, @@ -22,6 +23,8 @@ def train_generator( Returns: Path to the new model weights """ + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + args = get_args() if args.config: @@ -70,4 +73,6 @@ def run_generator( print("Saved XYZ files in mofa/output directory!") if __name__ == "__main__": - run_generator() + # run_generator() + train_generator() +