Spaces:
Running
Running
from __future__ import annotations | |
import os | |
from pathlib import Path | |
from mace.calculators import MACECalculator | |
from mlip_arena.models.utils import get_freer_device | |
class MACE_OFF_Medium(MACECalculator):
    """``MACECalculator`` preloaded with the MACE-OFF23 medium checkpoint.

    The checkpoint is downloaded on first use into ``~/.cache/mace`` and the
    cached copy is reused by later instantiations.

    Args:
        checkpoint: URL of the model checkpoint to download.
        device: Device string forwarded to ``MACECalculator``. When falsy,
            ``str(get_freer_device())`` is used instead.
        default_dtype: Forwarded to ``MACECalculator`` as ``default_dtype``.
        **kwargs: Forwarded unchanged to ``MACECalculator.__init__``.

    Raises:
        RuntimeError: If the download responds with an HTML page (typically an
            error/redirect page) instead of the model file.
    """

    def __init__(
        self,
        checkpoint="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
        device: str | None = None,
        default_dtype="float32",
        **kwargs,
    ):
        cache_dir = Path.home() / ".cache" / "mace"
        # Build a filesystem-safe cache filename from the URL's last path
        # component by keeping only alphanumerics and underscores, e.g.
        # "MACE-OFF23_medium.model?raw=true" -> "MACEOFF23_mediummodelrawtrue".
        checkpoint_url_name = "".join(
            c for c in os.path.basename(checkpoint) if c.isalnum() or c in "_"
        )
        cached_model_path = cache_dir / checkpoint_url_name

        if not cached_model_path.is_file():
            # `import urllib` alone does not guarantee the `request` submodule
            # is loaded, so import it explicitly.
            import tempfile
            import urllib.request

            cache_dir.mkdir(parents=True, exist_ok=True)
            # Download into a uniquely named temp file in the cache directory
            # and move it into place only on success. An interrupted or
            # rejected download therefore never leaves a corrupt file at the
            # cache path that later runs would load without re-downloading.
            fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".part")
            os.close(fd)
            try:
                _, http_msg = urllib.request.urlretrieve(checkpoint, tmp_path)
                # `http_msg` is an email.message.Message; `in` on it tests
                # header *names*, so inspect the parsed content type instead.
                if http_msg.get_content_type() == "text/html":
                    raise RuntimeError(
                        f"Model download failed, please check the URL {checkpoint}"
                    )
                os.replace(tmp_path, cached_model_path)
            finally:
                # After a successful os.replace the temp path no longer exists.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        device = device or str(get_freer_device())
        super().__init__(
            model_paths=str(cached_model_path),
            device=device,
            default_dtype=default_dtype,
            **kwargs,
        )