cyrusyc's picture
page configurations
056d8d3
raw
history blame
2.66 kB
import os
from pathlib import Path
import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch_geometric.data import Data
with open(Path(__file__).parent / "registry.yaml") as f:
REGISTRY = yaml.safe_load(f)
class MLIP(
nn.Module,
PyTorchModelHubMixin,
tags=["atomistic-simulation", "MLIP"],
):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
class ModuleMLIP(MLIP):
def __init__(self, model: nn.Module, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.add_module("model", model)
def forward(self, x):
print("Forwarding...")
out = self.model(x)
print("Forwarded!")
return out
class MLIPCalculator(Calculator):
name: str
# device: torch.device
# model: MLIP
implemented_properties: list[str] = ["energy", "forces", "stress"]
def __init__(
self,
# ASE Calculator
restart=None,
atoms=None,
directory=".",
**kwargs,
):
super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
# self.name: str = self.__class__.__name__
# self.device = device or torch.device(
# "cuda" if torch.cuda.is_available() else "cpu"
# )
# self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
# self.implemented_properties = ["energy", "forces", "stress"]
def calculate(
self, atoms: Atoms, properties: list[str], system_changes: list = all_changes,
):
"""Calculate energies and forces for the given Atoms object"""
super().calculate(atoms, properties, system_changes)
output = self.forward(atoms)
self.results = {}
if "energy" in properties:
self.results["energy"] = output["energy"].squeeze().item()
if "forces" in properties:
self.results["forces"] = output["forces"].squeeze().cpu().detach().numpy()
if "stress" in properties:
self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()
def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
"""Implement data conversion, graph creation, and model forward pass
Example implementation:
1. Use `ase.neighborlist.NeighborList` to get neighbor list
2. Create `torch_geometric.data.Data` object and copy the data
3. Pass the `Data` object to the model and return the output
"""
raise NotImplementedError