# Yuan (Cyrus) Chiang
# Add `eqV2_86M_omat_mp_salex` model (#14)
# 52c1bfb unverified
from __future__ import annotations
from pathlib import Path
import yaml
from ase import Atoms
from fairchem.core import OCPCalculator
from huggingface_hub import hf_hub_download
# Parse ``registry.yaml`` (located in the parent of this module's directory)
# once at import time; the calculator classes below read their default
# checkpoint names from the resulting mapping.
REGISTRY = yaml.safe_load(
    (Path(__file__).parents[1] / "registry.yaml").read_text(encoding="utf-8")
)
class eqV2(OCPCalculator):
    """``OCPCalculator`` that loads an eqV2 checkpoint from the
    ``fairchem/OMAT24`` Hugging Face repository."""

    def __init__(
        self,
        checkpoint: str = REGISTRY["eqV2(OMat)"]["checkpoint"],
        cache_dir: str | Path | None = None,
        cpu: bool = False,  # TODO: cannot assign device
        seed: int = 0,
        **kwargs,
    ) -> None:
        """
        Initialize an eqV2 calculator.

        Parameters
        ----------
        checkpoint : str
            File name of the checkpoint inside the ``fairchem/OMAT24``
            repository. Defaults to the ``checkpoint`` value of the
            ``eqV2(OMat)`` entry in ``registry.yaml``.
        cache_dir : str or Path, optional
            Forwarded to ``hf_hub_download`` as ``cache_dir``. With the
            default ``None`` the cache location is left to
            ``huggingface_hub``.
        cpu : bool, default=False
            Whether to run the model on CPU or GPU; forwarded to
            ``OCPCalculator``.
        seed : int, default=0
            The random seed for the model; forwarded to ``OCPCalculator``.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        # https://huggingface.co/fairchem/OMAT24/resolve/main/eqV2_86M_omat_mp_salex.pt
        # The revision is pinned to a fixed commit so that every run fetches
        # the same file regardless of later changes on the repository's main
        # branch.
        checkpoint_path = hf_hub_download(
            "fairchem/OMAT24",
            filename=checkpoint,
            revision="bf92f9671cb9d5b5c77ecb4aa8b317ff10b882ce",
            cache_dir=cache_dir,
        )
        # Hand the downloaded file to OCPCalculator by path.
        super().__init__(
            checkpoint_path=checkpoint_path,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )
class EquiformerV2(OCPCalculator):
    """``OCPCalculator`` preset whose default model name comes from the
    ``EquiformerV2(OC22)`` entry of ``registry.yaml``."""

    def __init__(
        self,
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize the calculator.

        Parameters
        ----------
        checkpoint
            Passed to ``OCPCalculator`` as ``model_name``.
        local_cache
            Passed to ``OCPCalculator`` as ``local_cache``.
        cpu
            Passed to ``OCPCalculator`` as ``cpu``.
        seed
            Passed to ``OCPCalculator`` as ``seed``.
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        base_options = {
            "model_name": checkpoint,
            "local_cache": local_cache,
            "cpu": cpu,
            "seed": seed,
        }
        super().__init__(**base_options, **kwargs)

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the superclass calculation, then store ``atoms.get_forces()``
        in ``self.results`` under the additional key ``"force"``."""
        super().calculate(atoms, properties, system_changes)
        # Grab the results mapping before querying the forces, then add the
        # singular "force" entry to it.
        results = self.results
        results["force"] = atoms.get_forces()
class EquiformerV2OC20(OCPCalculator):
    """``OCPCalculator`` preset intended for the EquiformerV2 OC20 model."""

    def __init__(
        self,
        # The default used to be read from the "EquiformerV2(OC22)" entry,
        # which made this class identical to ``EquiformerV2``. Prefer the
        # "EquiformerV2(OC20)" entry and fall back to the OC22 one (the
        # previous behaviour) when the registry has no OC20 entry.
        # NOTE(review): confirm registry.yaml defines "EquiformerV2(OC20)".
        checkpoint=REGISTRY.get(
            "EquiformerV2(OC20)", REGISTRY["EquiformerV2(OC22)"]
        )["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize the calculator.

        Parameters
        ----------
        checkpoint
            Passed to ``OCPCalculator`` as ``model_name``. Defaults to the
            ``checkpoint`` value of the ``EquiformerV2(OC20)`` registry
            entry, or of ``EquiformerV2(OC22)`` if the former is missing.
        local_cache
            Passed to ``OCPCalculator`` as ``local_cache``.
        cpu
            Passed to ``OCPCalculator`` as ``cpu``.
        seed
            Passed to ``OCPCalculator`` as ``seed``.
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )
class eSCN(OCPCalculator):
    """``OCPCalculator`` preset with a hard-coded default eSCN model name."""

    def __init__(
        self,
        checkpoint="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD",  # TODO: import from registry
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize the calculator.

        Parameters
        ----------
        checkpoint
            Passed to ``OCPCalculator`` as ``model_name``.
        local_cache
            Passed to ``OCPCalculator`` as ``local_cache``.
        cpu
            Passed to ``OCPCalculator`` as ``cpu``.
        seed
            Passed to ``OCPCalculator`` as ``seed``.
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        forwarded = dict(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
        )
        super().__init__(**forwarded, **kwargs)

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Delegate to the superclass, then add ``atoms.get_forces()`` to
        ``self.results`` under the extra key ``"force"``."""
        super().calculate(atoms, properties, system_changes)
        self.results.update({"force": atoms.get_forces()})