Skip to content

Commit

Permalink
Update generator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunp2 authored Oct 18, 2023
1 parent 520c97f commit c195465
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion mofa/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from model import MOFRecord
from difflinker_sample import sample_from_sdf
from difflinker_train import get_args, main
import yaml

def train_generator(
starting_model: str | Path,
Expand All @@ -20,8 +22,22 @@ def train_generator(
Returns:
Path to the new model weights
"""
raise NotImplementedError()
args = get_args()

if args.config:
config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
arg_dict = args.__dict__
for key, value in config_dict.items():
if isinstance(value, list) and key != 'normalize_factors':
for v in value:
arg_dict[key].append(v)
else:
arg_dict[key] = value
args.config = args.config.name
else:
config_dict = {}

main(args=args)

def run_generator(
node: str='CuCu',
Expand Down

0 comments on commit c195465

Please sign in to comment.