Skip to content

Commit

Permalink
Update generator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunp2 authored Oct 18, 2023
1 parent 4116a48 commit 2e6dc9c
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions mofa/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ase

from model import MOFRecord
from difflinker_sample import sampler
from difflinker_sample import sample_from_sdf

def train_generator(
starting_model: str | Path,
Expand All @@ -24,18 +24,29 @@ def train_generator(


def run_generator(
model: str | Path,
molecule_sizes: list[int],
num_samples: int
node: str='CuCu',
n_atoms: int|str=8,
input_path: str|Path=f"mofa/data/fragments_all/CuCu/hMOF_frag_frag.sdf",
model: str|Path="mofa/models/geom_difflinker.ckpt",
n_samples: int=1,
n_steps: int=None
) -> list[ase.Atoms]:
"""
Args:
model: Path to the starting weights
molecule_sizes: Number of heavy atoms in the linker molecules to generate
num_samples: Number of samples of molecules to generate
n_atoms: Number of heavy atoms in the linker molecules to generate
n_samples: Number of samples of molecules to generate
Returns:
3D geometries of the generated linkers
"""
sampler(nodes: List[str]=['CuCu'], n_atoms_list)

raise NotImplementedError()
assert node in input_path, "node must be in input_path name"
sample_from_sdf(node=node,
n_atoms=n_atoms,
input_path=input_path,
model=model,
n_samples=n_samples,
n_steps=n_steps
)

if __name__ == "__main__":
run_generator()

0 comments on commit 2e6dc9c

Please sign in to comment.