Spaces:
Running
Running
File size: 1,297 Bytes
52c1bfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from __future__ import annotations
import os
from pathlib import Path
from mace.calculators import MACECalculator
from mlip_arena.models.utils import get_freer_device
class MACE_OFF_Medium(MACECalculator):
    """MACECalculator preconfigured with the MACE-OFF23 "medium" checkpoint.

    The checkpoint is fetched from ``checkpoint`` on first use and stored
    under ``~/.cache/mace``; later instantiations reuse the cached file.
    """

    def __init__(
        self,
        checkpoint="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
        device: str | None = None,
        default_dtype="float32",
        **kwargs,
    ):
        """Download (if not cached) the checkpoint and initialise the calculator.

        Args:
            checkpoint: URL of the model file to download.
            device: Device string forwarded to ``MACECalculator``. When falsy,
                ``str(get_freer_device())`` is used instead.
            default_dtype: Forwarded unchanged to ``MACECalculator``.
            **kwargs: Extra keyword arguments forwarded to ``MACECalculator``.

        Raises:
            RuntimeError: If the server answered with an HTML page instead of
                a model file.
        """
        cache_dir = Path.home() / ".cache" / "mace"
        # Reduce the URL tail (query string included) to alphanumerics and
        # underscores so it is usable as a file name.
        checkpoint_url_name = "".join(
            c for c in os.path.basename(checkpoint) if c.isalnum() or c == "_"
        )
        cached_model_path = f"{cache_dir}/{checkpoint_url_name}"

        if not os.path.isfile(cached_model_path):
            # A bare ``import urllib`` does not load the ``request`` submodule.
            import urllib.request

            os.makedirs(cache_dir, exist_ok=True)
            # Download next to the final location and rename only on success,
            # so an interrupted or bogus download never poisons the cache.
            partial_path = f"{cached_model_path}.part"
            try:
                _, http_msg = urllib.request.urlretrieve(checkpoint, partial_path)
                # ``in`` on the headers object only tests header *names*, so
                # the content type has to be read explicitly.
                if http_msg.get_content_type() == "text/html":
                    raise RuntimeError(
                        f"Model download failed, please check the URL {checkpoint}"
                    )
                os.replace(partial_path, cached_model_path)
            finally:
                # After a successful rename the partial file no longer exists.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

        model = cached_model_path
        device = device or str(get_freer_device())
        super().__init__(
            model_paths=model, device=device, default_dtype=default_dtype, **kwargs
        )