Skip to content

Commit

Permalink
clean calculator (#26)
Browse files Browse the repository at this point in the history
* clean calculator

* clean code

* update README

* fix bug

* update readme

* fix load method

---------

Co-authored-by: Xixian <v-xixianliu@microsoft.com>
Co-authored-by: Han Yang <hanyang@microsoft.com>
  • Loading branch information
3 people authored Nov 28, 2024
1 parent df61597 commit 39a550e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,19 @@ python setup.py build_ext --inplace
```python
import torch
from ase.build import bulk
from mattersim.forcefield.potential import Potential
from mattersim.datasets.utils.build import build_dataloader
from ase.units import GPa
from mattersim.forcefield import MatterSimCalculator

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Running MatterSim on {device}")

potential = Potential.load()
si = bulk("Si", "diamond", a=5.43)
dataloader = build_dataloader([si], only_inference=True)

predictions = potential.predict_properties(dataloader, include_forces=True, include_stresses=True)
print(predictions)
si.calc = MatterSimCalculator()
print(f"Energy (eV) = {si.get_potential_energy()}")
print(f"Energy per atom (eV/atom) = {si.get_potential_energy()/len(si)}")
print(f"Forces of first atom (eV/A) = {si.get_forces()[0]}")
print(f"Stress[0][0] (eV/A^3) = {si.get_stress(voigt=False)[0][0]}")
print(f"Stress[0][0] (GPa) = {si.get_stress(voigt=False)[0][0] / GPa}")
```


Expand Down
4 changes: 4 additions & 0 deletions src/mattersim/forcefield/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
from .potential import MatterSimCalculator, Potential

__all__ = ["MatterSimCalculator", "Potential"]
38 changes: 35 additions & 3 deletions src/mattersim/forcefield/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def batch_to_dict(graph_batch, model_type="m3gnet", device="cuda"):
return input


class DeepCalculator(Calculator):
class MatterSimCalculator(Calculator):
"""
Deep calculator based on ase Calculator
"""
Expand All @@ -999,7 +999,7 @@ class DeepCalculator(Calculator):

def __init__(
self,
potential: Potential,
potential: Potential = None,
args_dict: dict = {},
compute_stress: bool = True,
stress_weight: float = GPa,
Expand All @@ -1014,12 +1014,44 @@ def __init__(
**kwargs:
"""
super().__init__(**kwargs)
self.potential = potential
if potential is None:
self.potential = Potential.load()
else:
self.potential = potential
self.compute_stress = compute_stress
self.stress_weight = stress_weight
self.args_dict = args_dict
self.device = device

@staticmethod
def load(
load_path: str = None,
*,
model_name: str = "m3gnet",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
args: Dict = None,
load_training_state: bool = True,
args_dict: dict = {},
compute_stress: bool = True,
stress_weight: float = GPa,
**kwargs,
):
potential = Potential.load(
load_path=load_path,
model_name=model_name,
device=device,
args=args,
load_training_state=load_training_state,
)
return MatterSimCalculator(
potential=potential,
args_dict=args_dict,
compute_stress=compute_stress,
stress_weight=stress_weight,
device=device,
**kwargs,
)

def calculate(
self,
atoms: Optional[Atoms] = None,
Expand Down

0 comments on commit 39a550e

Please sign in to comment.