Spaces:
Running
Running
from __future__ import annotations | |
from pathlib import Path | |
import yaml | |
from ase import Atoms | |
from fairchem.core import OCPCalculator | |
# Model registry: parsed once at import time from the ``registry.yaml`` file
# located one directory above the directory containing this module.
REGISTRY = yaml.safe_load(
    (Path(__file__).parents[1] / "registry.yaml").read_text(encoding="utf-8")
)
class EquiformerV2(OCPCalculator):
    """``OCPCalculator`` whose default checkpoint is the registry's
    ``"EquiformerV2(OC22)"`` entry.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Forward the configuration to ``OCPCalculator.__init__``.

        Args:
            checkpoint: Passed to the parent as ``model_name``. Defaults to
                the ``checkpoint`` value of the ``"EquiformerV2(OC22)"``
                registry entry.
            local_cache: Passed through as ``local_cache``.
            cpu: Passed through as ``cpu``.
            seed: Passed through as ``seed``.
            **kwargs: Any further keyword arguments for the parent
                constructor.
        """
        parent_options = {
            "model_name": checkpoint,
            "local_cache": local_cache,
            "cpu": cpu,
            "seed": seed,
        }
        super().__init__(**parent_options, **kwargs)

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the parent calculation, then add a ``"force"`` result entry.

        Args:
            atoms: Structure to evaluate.
            properties: Forwarded unchanged to the parent ``calculate``.
            system_changes: Forwarded unchanged to the parent ``calculate``.
        """
        super().calculate(atoms, properties, system_changes)

        # NOTE(review): the value is read back through ``atoms.get_forces()``,
        # which presumably assumes this calculator is attached to ``atoms``
        # -- confirm against callers. The key is the singular "force".
        results = self.results
        results["force"] = atoms.get_forces()
class EquiformerV2OC20(OCPCalculator):
    """``OCPCalculator`` whose default checkpoint is the registry's
    ``"EquiformerV2(OC20)"`` entry.
    """

    def __init__(
        self,
        # Prefer the OC20 registry entry, matching this class's name; the
        # default previously pointed at the OC22 entry. The OC22 entry is
        # kept as a fallback so that a registry without an OC20 entry still
        # imports and behaves as before.
        # NOTE(review): confirm registry.yaml defines "EquiformerV2(OC20)".
        checkpoint=REGISTRY.get(
            "EquiformerV2(OC20)", REGISTRY["EquiformerV2(OC22)"]
        )["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Forward the configuration to ``OCPCalculator.__init__``.

        Args:
            checkpoint: Passed to the parent as ``model_name``. Defaults to
                the ``checkpoint`` value of the ``"EquiformerV2(OC20)"``
                registry entry, or of ``"EquiformerV2(OC22)"`` when the
                registry has no OC20 entry.
            local_cache: Passed through as ``local_cache``.
            cpu: Passed through as ``cpu``.
            seed: Passed through as ``seed``.
            **kwargs: Any further keyword arguments for the parent
                constructor.
        """
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )