From 2e6dc9cefd0473b1f17b7917ebae5ad4b4b960ec Mon Sep 17 00:00:00 2001 From: hyunp2 <42776897+hyunp2@users.noreply.github.com> Date: Wed, 18 Oct 2023 00:21:04 -0500 Subject: [PATCH] Update generator.py --- mofa/generator.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/mofa/generator.py b/mofa/generator.py index 00dac2b7..87258b04 100644 --- a/mofa/generator.py +++ b/mofa/generator.py @@ -4,7 +4,7 @@ import ase from model import MOFRecord -from difflinker_sample import sampler +from difflinker_sample import sample_from_sdf def train_generator( starting_model: str | Path, @@ -24,18 +24,29 @@ def train_generator( def run_generator( - model: str | Path, - molecule_sizes: list[int], - num_samples: int + node: str='CuCu', + n_atoms: int|str=8, + input_path: str|Path=f"mofa/data/fragments_all/CuCu/hMOF_frag_frag.sdf", + model: str|Path="mofa/models/geom_difflinker.ckpt", + n_samples: int=1, + n_steps: int=None ) -> list[ase.Atoms]: """ Args: model: Path to the starting weights - molecule_sizes: Number of heavy atoms in the linker molecules to generate - num_samples: Number of samples of molecules to generate + n_atoms: Number of heavy atoms in the linker molecules to generate + n_samples: Number of samples of molecules to generate Returns: 3D geometries of the generated linkers """ - sampler(nodes: List[str]=['CuCu'], n_atoms_list) - - raise NotImplementedError() + assert node in input_path, "node must be in input_path name" + sample_from_sdf(node=node, + n_atoms=n_atoms, + input_path=input_path, + model=model, + n_samples=n_samples, + n_steps=n_steps + ) + +if __name__ == "__main__": + run_generator()