diff --git a/pyproject.toml b/pyproject.toml index 0770e3f..0ce5fd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ dependencies = [ "azure-identity", "mp-api", "emmet-core<0.84", - "pydantic==2.9.2" + "pydantic==2.9.2", + "deprecated" ] [project.optional-dependencies] diff --git a/src/mattersim/forcefield/__init__.py b/src/mattersim/forcefield/__init__.py index 8af166d..4dc9769 100644 --- a/src/mattersim/forcefield/__init__.py +++ b/src/mattersim/forcefield/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- -from .potential import MatterSimCalculator, Potential +from .potential import DeepCalculator, MatterSimCalculator, Potential -__all__ = ["MatterSimCalculator", "Potential"] +__all__ = ["MatterSimCalculator", "Potential", "DeepCalculator"] diff --git a/src/mattersim/forcefield/potential.py b/src/mattersim/forcefield/potential.py index 6923c55..5478459 100644 --- a/src/mattersim/forcefield/potential.py +++ b/src/mattersim/forcefield/potential.py @@ -18,6 +18,7 @@ from ase.calculators.calculator import Calculator from ase.constraints import full_3x3_to_voigt_6_stress from ase.units import GPa +from deprecated import deprecated from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR from torch_ema import ExponentialMovingAverage @@ -990,6 +991,107 @@ def batch_to_dict(graph_batch, model_type="m3gnet", device="cuda"): return input +@deprecated(version="1.0.0", reason="Please use MatterSimCalculator instead.") +class DeepCalculator(Calculator): + """ + Deep calculator based on ase Calculator + """ + + implemented_properties = ["energy", "free_energy", "forces", "stress"] + + def __init__( + self, + potential: Potential, + args_dict: dict = {}, + compute_stress: bool = True, + stress_weight: float = 1.0, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + **kwargs, + ): + """ + Args: + potential (Potential): m3gnet.models.Potential + compute_stress (bool): whether to calculate the stress + stress_weight (float): the stress weight. + **kwargs: + """ + super().__init__(**kwargs) + self.potential = potential + self.compute_stress = compute_stress + self.stress_weight = stress_weight + self.args_dict = args_dict + self.device = device + + def calculate( + self, + atoms: Optional[Atoms] = None, + properties: Optional[list] = None, + system_changes: Optional[list] = None, + ): + """ + Args: + atoms (ase.Atoms): ase Atoms object + properties (list): list of properties to calculate + system_changes (list): monitor which properties of atoms were + changed for new calculation. If not, the previous calculation + results will be loaded. + Returns: + """ + + all_changes = [ + "positions", + "numbers", + "cell", + "pbc", + "initial_charges", + "initial_magmoms", + ] + + properties = properties or ["energy"] + system_changes = system_changes or all_changes + super().calculate( + atoms=atoms, properties=properties, system_changes=system_changes + ) + + self.args_dict["batch_size"] = 1 + self.args_dict["only_inference"] = 1 + dataloader = build_dataloader( + [atoms], model_type=self.potential.model_name, **self.args_dict + ) + for graph_batch in dataloader: + # Resemble input dictionary + if ( + self.potential.model_name == "graphormer" + or self.potential.model_name == "geomformer" + ): + raise NotImplementedError + else: + graph_batch = graph_batch.to(self.device) + input = batch_to_dict(graph_batch) + + result = self.potential.forward( + input, include_forces=True, include_stresses=self.compute_stress + ) + if ( + self.potential.model_name == "graphormer" + or self.potential.model_name == "geomformer" + ): + raise NotImplementedError + else: + self.results.update( + energy=result["total_energy"].detach().cpu().numpy()[0], + free_energy=result["total_energy"].detach().cpu().numpy()[0], + forces=result["forces"].detach().cpu().numpy(), + ) + if self.compute_stress: + self.results.update( + stress=self.stress_weight + * full_3x3_to_voigt_6_stress( + result["stresses"].detach().cpu().numpy()[0] + ) + ) + + class MatterSimCalculator(Calculator): """ Deep calculator based on ase Calculator