Skip to content

Commit

Permalink
Update generator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunp2 authored Oct 18, 2023
1 parent f2d1936 commit f90d955
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mofa/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from difflinker_sample import sample_from_sdf
from difflinker_train import get_args, main
import yaml
import os

def train_generator(
# starting_model: str | Path,
Expand All @@ -22,6 +23,8 @@ def train_generator(
Returns:
Path to the new model weights
"""
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

args = get_args()

if args.config:
Expand Down Expand Up @@ -70,4 +73,6 @@ def run_generator(
print("Saved XYZ files in mofa/output directory!")

if __name__ == "__main__":
run_generator()
# run_generator()
train_generator()

0 comments on commit f90d955

Please sign in to comment.