File size: 2,703 Bytes
52c1bfb
 
 
 
fdb2bc5
52c1bfb
 
 
 
 
 
fdb2bc5
 
52c1bfb
 
 
 
fdb2bc5
52c1bfb
 
 
 
 
 
 
22f0dbc
52c1bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8396dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from __future__ import annotations

from pathlib import Path

import yaml
import requests
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator

from mlip_arena.models.utils import get_freer_device

# Model registry shared by every calculator in this module; it is read from
# the ``registry.yaml`` located one directory above this file.
REGISTRY = yaml.safe_load(
    (Path(__file__).parents[1] / "registry.yaml").read_text(encoding="utf-8")
)

class ORB(ORBCalculator):
    """ASE-style calculator backed by the pretrained ORB v1 force field.

    The checkpoint is cached under ``~/.cache/orb`` and downloaded from the
    Orbital Materials public bucket the first time it is needed.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["ORB"]["checkpoint"],
        device=None,
        **kwargs,
    ):
        """Load (downloading if necessary) the ORB v1 checkpoint.

        Args:
            checkpoint: Checkpoint file name, used both as the cache file name
                and as the object name in the public bucket. Defaults to the
                registry entry for ``ORB``.
            device: Device to run on; when falsy, ``get_freer_device()`` is
                used.
            **kwargs: Forwarded unchanged to ``ORBCalculator``.

        Raises:
            RuntimeError: If the checkpoint download fails.
        """
        device = device or get_freer_device()

        cache_dir = Path.home() / ".cache" / "orb"
        cache_dir.mkdir(parents=True, exist_ok=True)
        ckpt_path = cache_dir / checkpoint

        url = f"https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/{checkpoint}"

        if not ckpt_path.exists():
            print(f"Downloading ORB model from {url} to {ckpt_path}...")
            # Stream into a sibling temp file and rename only on success, so an
            # interrupted/failed download never leaves a truncated checkpoint
            # at ckpt_path (the exists() check above would otherwise accept it
            # on the next run).
            tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
            try:
                # The context manager releases the streamed connection.
                with requests.get(url, stream=True, timeout=120) as response:
                    response.raise_for_status()
                    with open(tmp_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                tmp_path.replace(ckpt_path)  # atomic on the same filesystem
            except requests.exceptions.RequestException as e:
                tmp_path.unlink(missing_ok=True)
                raise RuntimeError("Failed to download ORB model.") from e
            except BaseException:
                # Disk errors, KeyboardInterrupt, ...: clean up, then propagate.
                tmp_path.unlink(missing_ok=True)
                raise
            print("Download completed.")

        orbff = pretrained.orb_v1(weights_path=ckpt_path, device=device)
        super().__init__(orbff, device=device, **kwargs)

class ORBv2(ORBCalculator):
    """ASE-style calculator backed by the pretrained ORB v2 force field.

    The checkpoint is cached under ``~/.cache/orb`` and downloaded from the
    Orbital Materials public bucket the first time it is needed.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["ORBv2"]["checkpoint"],
        device=None,
        **kwargs,
    ):
        """Load (downloading if necessary) the ORB v2 checkpoint.

        Args:
            checkpoint: Checkpoint file name, used both as the cache file name
                and as the object name in the public bucket. Defaults to the
                registry entry for ``ORBv2``.
            device: Device to run on; when falsy, ``get_freer_device()`` is
                used.
            **kwargs: Forwarded unchanged to ``ORBCalculator``.

        Raises:
            RuntimeError: If the checkpoint download fails.
        """
        device = device or get_freer_device()

        cache_dir = Path.home() / ".cache" / "orb"
        cache_dir.mkdir(parents=True, exist_ok=True)
        ckpt_path = cache_dir / checkpoint

        url = f"https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/{checkpoint}"

        if not ckpt_path.exists():
            print(f"Downloading ORB model from {url} to {ckpt_path}...")
            # Stream into a sibling temp file and rename only on success, so an
            # interrupted/failed download never leaves a truncated checkpoint
            # at ckpt_path (the exists() check above would otherwise accept it
            # on the next run).
            tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
            try:
                # The context manager releases the streamed connection.
                with requests.get(url, stream=True, timeout=120) as response:
                    response.raise_for_status()
                    with open(tmp_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                tmp_path.replace(ckpt_path)  # atomic on the same filesystem
            except requests.exceptions.RequestException as e:
                tmp_path.unlink(missing_ok=True)
                raise RuntimeError("Failed to download ORB model.") from e
            except BaseException:
                # Disk errors, KeyboardInterrupt, ...: clean up, then propagate.
                tmp_path.unlink(missing_ok=True)
                raise
            print("Download completed.")

        orbff = pretrained.orb_v2(weights_path=ckpt_path, device=device)
        super().__init__(orbff, device=device, **kwargs)