diff --git a/mofa/generator.py b/mofa/generator.py index 8bedbd08..ec97d064 100644 --- a/mofa/generator.py +++ b/mofa/generator.py @@ -7,6 +7,7 @@ from difflinker_sample import sample_from_sdf from difflinker_train import get_args, main import yaml +import os def train_generator( # starting_model: str | Path, @@ -22,6 +23,8 @@ def train_generator( Returns: Path to the new model weights """ + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + args = get_args() if args.config: @@ -70,4 +73,6 @@ def run_generator( print("Saved XYZ files in mofa/output directory!") if __name__ == "__main__": - run_generator() + # run_generator() + train_generator() +