Spaces:
Running
Running
improve caching, unify calculator init
Browse files- mlip_arena/models/externals/mace-mp.py +1 -1
- mlip_arena/models/externals/mattersim.py +19 -4
- mlip_arena/tasks/elasticity.py +1 -1
- mlip_arena/tasks/eos.py +9 -5
- mlip_arena/tasks/md.py +10 -38
- mlip_arena/tasks/neb.py +20 -7
- mlip_arena/tasks/optimize.py +10 -41
- mlip_arena/tasks/utils.py +25 -10
mlip_arena/models/externals/mace-mp.py
CHANGED
@@ -11,7 +11,7 @@ from mlip_arena.models.utils import get_freer_device
|
|
11 |
class MACE_MP_Medium(MACECalculator):
|
12 |
def __init__(
|
13 |
self,
|
14 |
-
checkpoint="
|
15 |
device: str | None = None,
|
16 |
default_dtype="float32",
|
17 |
**kwargs,
|
|
|
11 |
class MACE_MP_Medium(MACECalculator):
|
12 |
def __init__(
|
13 |
self,
|
14 |
+
checkpoint="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
|
15 |
device: str | None = None,
|
16 |
default_dtype="float32",
|
17 |
**kwargs,
|
mlip_arena/models/externals/mattersim.py
CHANGED
@@ -5,11 +5,14 @@ from pathlib import Path
|
|
5 |
import yaml
|
6 |
from mattersim.forcefield import MatterSimCalculator
|
7 |
|
|
|
8 |
from mlip_arena.models.utils import get_freer_device
|
|
|
9 |
|
10 |
with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
|
11 |
REGISTRY = yaml.safe_load(f)
|
12 |
|
|
|
13 |
class MatterSim(MatterSimCalculator):
|
14 |
def __init__(
|
15 |
self,
|
@@ -18,7 +21,19 @@ class MatterSim(MatterSimCalculator):
|
|
18 |
**kwargs,
|
19 |
):
|
20 |
super().__init__(
|
21 |
-
load_path=checkpoint,
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import yaml
|
6 |
from mattersim.forcefield import MatterSimCalculator
|
7 |
|
8 |
+
from ase import Atoms
|
9 |
from mlip_arena.models.utils import get_freer_device
|
10 |
+
# from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms
|
11 |
|
12 |
with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
|
13 |
REGISTRY = yaml.safe_load(f)
|
14 |
|
15 |
+
|
16 |
class MatterSim(MatterSimCalculator):
|
17 |
def __init__(
|
18 |
self,
|
|
|
21 |
**kwargs,
|
22 |
):
|
23 |
super().__init__(
|
24 |
+
load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
|
25 |
+
)
|
26 |
+
|
27 |
+
def calculate(
|
28 |
+
self,
|
29 |
+
atoms: Atoms | None = None,
|
30 |
+
properties: list | None = None,
|
31 |
+
system_changes: list | None = None,
|
32 |
+
):
|
33 |
+
super().calculate(atoms, properties, system_changes)
|
34 |
+
|
35 |
+
# # convert unpicklizable atoms back to picklizable atoms to avoid prefect pickling error
|
36 |
+
# if isinstance(self.atoms, MSONAtoms):
|
37 |
+
# atoms = self.atoms.copy()
|
38 |
+
# strucutre = AseAtomsAdaptor().get_structure(atoms)
|
39 |
+
# self.atoms = AseAtomsAdaptor().get_atoms(strucutre, msonable=False)
|
mlip_arena/tasks/elasticity.py
CHANGED
@@ -90,7 +90,7 @@ def run(
|
|
90 |
normal_strains: list[float] | np.ndarray | None = np.linspace(-0.01, 0.01, 4),
|
91 |
shear_strains: list[float] | np.ndarray | None = np.linspace(-0.06, 0.06, 4),
|
92 |
persist_opt: bool = True,
|
93 |
-
cache_opt: bool =
|
94 |
) -> dict[str, Any] | State:
|
95 |
"""
|
96 |
Compute the elastic tensor for the given structure and calculator.
|
|
|
90 |
normal_strains: list[float] | np.ndarray | None = np.linspace(-0.01, 0.01, 4),
|
91 |
shear_strains: list[float] | np.ndarray | None = np.linspace(-0.06, 0.06, 4),
|
92 |
persist_opt: bool = True,
|
93 |
+
cache_opt: bool = False,
|
94 |
) -> dict[str, Any] | State:
|
95 |
"""
|
96 |
Compute the elastic tensor for the given structure and calculator.
|
mlip_arena/tasks/eos.py
CHANGED
@@ -17,8 +17,6 @@ from prefect.runtime import task_run
|
|
17 |
from prefect.states import State
|
18 |
|
19 |
from ase import Atoms
|
20 |
-
from ase.filters import * # type: ignore
|
21 |
-
from ase.optimize import * # type: ignore
|
22 |
from ase.optimize.optimize import Optimizer
|
23 |
from mlip_arena.models import MLIPEnum
|
24 |
from mlip_arena.tasks.optimize import run as OPT
|
@@ -54,6 +52,7 @@ def run(
|
|
54 |
max_abs_strain: float = 0.1,
|
55 |
npoints: int = 11,
|
56 |
concurrent: bool = True,
|
|
|
57 |
) -> dict[str, Any] | State:
|
58 |
"""
|
59 |
Compute the equation of state (EOS) for the given atoms and calculator.
|
@@ -78,7 +77,12 @@ def run(
|
|
78 |
A dictionary containing the EOS data, bulk modulus, equilibrium volume, and equilibrium energy if successful. Otherwise, a prefect state object.
|
79 |
"""
|
80 |
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
82 |
atoms=atoms,
|
83 |
calculator_name=calculator_name,
|
84 |
calculator_kwargs=calculator_kwargs,
|
@@ -112,7 +116,7 @@ def run(
|
|
112 |
atoms = relaxed.copy()
|
113 |
atoms.set_cell(c0 * f, scale_atoms=True)
|
114 |
|
115 |
-
future =
|
116 |
atoms=atoms,
|
117 |
calculator_name=calculator_name,
|
118 |
calculator_kwargs=calculator_kwargs,
|
@@ -138,7 +142,7 @@ def run(
|
|
138 |
atoms = relaxed.copy()
|
139 |
atoms.set_cell(c0 * f, scale_atoms=True)
|
140 |
|
141 |
-
state =
|
142 |
atoms=atoms,
|
143 |
calculator_name=calculator_name,
|
144 |
calculator_kwargs=calculator_kwargs,
|
|
|
17 |
from prefect.states import State
|
18 |
|
19 |
from ase import Atoms
|
|
|
|
|
20 |
from ase.optimize.optimize import Optimizer
|
21 |
from mlip_arena.models import MLIPEnum
|
22 |
from mlip_arena.tasks.optimize import run as OPT
|
|
|
52 |
max_abs_strain: float = 0.1,
|
53 |
npoints: int = 11,
|
54 |
concurrent: bool = True,
|
55 |
+
cache_opt: bool = False,
|
56 |
) -> dict[str, Any] | State:
|
57 |
"""
|
58 |
Compute the equation of state (EOS) for the given atoms and calculator.
|
|
|
77 |
A dictionary containing the EOS data, bulk modulus, equilibrium volume, and equilibrium energy if successful. Otherwise, a prefect state object.
|
78 |
"""
|
79 |
|
80 |
+
OPT_ = OPT.with_options(
|
81 |
+
refresh_cache=not cache_opt,
|
82 |
+
persist_result=cache_opt,
|
83 |
+
)
|
84 |
+
|
85 |
+
state = OPT_(
|
86 |
atoms=atoms,
|
87 |
calculator_name=calculator_name,
|
88 |
calculator_kwargs=calculator_kwargs,
|
|
|
116 |
atoms = relaxed.copy()
|
117 |
atoms.set_cell(c0 * f, scale_atoms=True)
|
118 |
|
119 |
+
future = OPT_.submit(
|
120 |
atoms=atoms,
|
121 |
calculator_name=calculator_name,
|
122 |
calculator_kwargs=calculator_kwargs,
|
|
|
142 |
atoms = relaxed.copy()
|
143 |
atoms.set_cell(c0 * f, scale_atoms=True)
|
144 |
|
145 |
+
state = OPT_(
|
146 |
atoms=atoms,
|
147 |
calculator_name=calculator_name,
|
148 |
calculator_kwargs=calculator_kwargs,
|
mlip_arena/tasks/md.py
CHANGED
@@ -65,12 +65,9 @@ from prefect.cache_policies import INPUTS, TASK_SOURCE
|
|
65 |
from prefect.runtime import task_run
|
66 |
from scipy.interpolate import interp1d
|
67 |
from scipy.linalg import schur
|
68 |
-
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
69 |
from tqdm.auto import tqdm
|
70 |
|
71 |
from ase import Atoms, units
|
72 |
-
from ase.calculators.calculator import Calculator
|
73 |
-
from ase.calculators.mixing import SumCalculator
|
74 |
from ase.io import read
|
75 |
from ase.io.trajectory import Trajectory
|
76 |
from ase.md.andersen import Andersen
|
@@ -86,7 +83,7 @@ from ase.md.velocitydistribution import (
|
|
86 |
)
|
87 |
from ase.md.verlet import VelocityVerlet
|
88 |
from mlip_arena.models import MLIPEnum
|
89 |
-
from mlip_arena.
|
90 |
|
91 |
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
92 |
"nve": ("velocityverlet",),
|
@@ -201,14 +198,12 @@ def _generate_task_run_name():
|
|
201 |
name="MD",
|
202 |
task_run_name=_generate_task_run_name,
|
203 |
cache_policy=TASK_SOURCE + INPUTS
|
204 |
-
# cache_key_fn=task_input_hash,
|
205 |
-
# cache_expiration=timedelta(days=1)
|
206 |
)
|
207 |
def run(
|
208 |
atoms: Atoms,
|
209 |
calculator_name: str | MLIPEnum,
|
210 |
calculator_kwargs: dict | None,
|
211 |
-
dispersion:
|
212 |
dispersion_kwargs: dict | None = None,
|
213 |
device: str | None = None,
|
214 |
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
@@ -225,37 +220,14 @@ def run(
|
|
225 |
traj_interval: int = 1,
|
226 |
restart: bool = True,
|
227 |
):
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
calc = calculator_name.value(**calculator_kwargs)
|
237 |
-
elif (
|
238 |
-
isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
|
239 |
-
):
|
240 |
-
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
241 |
-
else:
|
242 |
-
raise ValueError(f"Invalid calculator: {calculator_name}")
|
243 |
-
|
244 |
-
print(f"Using calculator: {calc}")
|
245 |
-
|
246 |
-
dispersion_kwargs = dispersion_kwargs or {}
|
247 |
-
|
248 |
-
dispersion_kwargs.update({"device": device})
|
249 |
-
|
250 |
-
if dispersion is not None:
|
251 |
-
disp_calc = TorchDFTD3Calculator(
|
252 |
-
**dispersion_kwargs,
|
253 |
-
)
|
254 |
-
calc = SumCalculator([calc, disp_calc])
|
255 |
-
|
256 |
-
print(f"Using dispersion: {dispersion}")
|
257 |
-
|
258 |
-
atoms.calc = calc
|
259 |
|
260 |
if time_step is None:
|
261 |
# If a structure contains an isotope of hydrogen, set default `time_step`
|
|
|
65 |
from prefect.runtime import task_run
|
66 |
from scipy.interpolate import interp1d
|
67 |
from scipy.linalg import schur
|
|
|
68 |
from tqdm.auto import tqdm
|
69 |
|
70 |
from ase import Atoms, units
|
|
|
|
|
71 |
from ase.io import read
|
72 |
from ase.io.trajectory import Trajectory
|
73 |
from ase.md.andersen import Andersen
|
|
|
83 |
)
|
84 |
from ase.md.verlet import VelocityVerlet
|
85 |
from mlip_arena.models import MLIPEnum
|
86 |
+
from mlip_arena.tasks.utils import get_calculator
|
87 |
|
88 |
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
89 |
"nve": ("velocityverlet",),
|
|
|
198 |
name="MD",
|
199 |
task_run_name=_generate_task_run_name,
|
200 |
cache_policy=TASK_SOURCE + INPUTS
|
|
|
|
|
201 |
)
|
202 |
def run(
|
203 |
atoms: Atoms,
|
204 |
calculator_name: str | MLIPEnum,
|
205 |
calculator_kwargs: dict | None,
|
206 |
+
dispersion: bool = False,
|
207 |
dispersion_kwargs: dict | None = None,
|
208 |
device: str | None = None,
|
209 |
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
|
|
220 |
traj_interval: int = 1,
|
221 |
restart: bool = True,
|
222 |
):
|
223 |
+
|
224 |
+
atoms.calc = get_calculator(
|
225 |
+
calculator_name=calculator_name,
|
226 |
+
calculator_kwargs=calculator_kwargs,
|
227 |
+
dispersion=dispersion,
|
228 |
+
dispersion_kwargs=dispersion_kwargs,
|
229 |
+
device=device,
|
230 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
if time_step is None:
|
233 |
# If a structure contains an isotope of hydrogen, set default `time_step`
|
mlip_arena/tasks/neb.py
CHANGED
@@ -57,6 +57,7 @@ from mlip_arena.tasks.optimize import run as OPT
|
|
57 |
from mlip_arena.tasks.utils import get_calculator
|
58 |
from pymatgen.io.ase import AseAtomsAdaptor
|
59 |
|
|
|
60 |
if TYPE_CHECKING:
|
61 |
pass
|
62 |
|
@@ -100,7 +101,7 @@ def run(
|
|
100 |
images: list[Atoms],
|
101 |
calculator_name: str | MLIPEnum,
|
102 |
calculator_kwargs: dict | None = None,
|
103 |
-
dispersion:
|
104 |
dispersion_kwargs: dict | None = None,
|
105 |
device: str | None = None,
|
106 |
optimizer: Optimizer | str = "MDMin", # type: ignore
|
@@ -159,11 +160,16 @@ def run(
|
|
159 |
optimizer_instance.run(**criterion)
|
160 |
|
161 |
neb_tool = NEBTools(neb.images)
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
return {
|
164 |
-
"barrier":
|
165 |
-
"images":
|
166 |
-
"forcefit":
|
167 |
}
|
168 |
|
169 |
|
@@ -188,6 +194,7 @@ def run_from_end_points(
|
|
188 |
interpolation: Literal["linear", "idpp"] = "idpp",
|
189 |
climb: bool = True,
|
190 |
traj_file: str | Path | None = None,
|
|
|
191 |
) -> dict[str, Any] | State:
|
192 |
"""Run the nudged elastic band (NEB) calculation from end points.
|
193 |
|
@@ -212,7 +219,9 @@ def run_from_end_points(
|
|
212 |
"""
|
213 |
|
214 |
if relax_end_points:
|
215 |
-
relax = OPT(
|
|
|
|
|
216 |
atoms=start.copy(),
|
217 |
calculator_name=calculator_name,
|
218 |
calculator_kwargs=calculator_kwargs,
|
@@ -225,7 +234,9 @@ def run_from_end_points(
|
|
225 |
)
|
226 |
start = relax["atoms"]
|
227 |
|
228 |
-
relax = OPT(
|
|
|
|
|
229 |
atoms=end.copy(),
|
230 |
calculator_name=calculator_name,
|
231 |
calculator_kwargs=calculator_kwargs,
|
@@ -252,7 +263,9 @@ def run_from_end_points(
|
|
252 |
|
253 |
images = [s.to_ase_atoms() for s in path]
|
254 |
|
255 |
-
return run(
|
|
|
|
|
256 |
images,
|
257 |
calculator_name,
|
258 |
calculator_kwargs=calculator_kwargs,
|
|
|
57 |
from mlip_arena.tasks.utils import get_calculator
|
58 |
from pymatgen.io.ase import AseAtomsAdaptor
|
59 |
|
60 |
+
|
61 |
if TYPE_CHECKING:
|
62 |
pass
|
63 |
|
|
|
101 |
images: list[Atoms],
|
102 |
calculator_name: str | MLIPEnum,
|
103 |
calculator_kwargs: dict | None = None,
|
104 |
+
dispersion: bool = False,
|
105 |
dispersion_kwargs: dict | None = None,
|
106 |
device: str | None = None,
|
107 |
optimizer: Optimizer | str = "MDMin", # type: ignore
|
|
|
160 |
optimizer_instance.run(**criterion)
|
161 |
|
162 |
neb_tool = NEBTools(neb.images)
|
163 |
+
barrier = neb_tool.get_barrier()
|
164 |
+
|
165 |
+
forcefit = fit_images(neb.images)
|
166 |
+
|
167 |
+
images = neb.images
|
168 |
|
169 |
return {
|
170 |
+
"barrier": barrier,
|
171 |
+
"images": images,
|
172 |
+
"forcefit": forcefit,
|
173 |
}
|
174 |
|
175 |
|
|
|
194 |
interpolation: Literal["linear", "idpp"] = "idpp",
|
195 |
climb: bool = True,
|
196 |
traj_file: str | Path | None = None,
|
197 |
+
cache_subtasks: bool = False,
|
198 |
) -> dict[str, Any] | State:
|
199 |
"""Run the nudged elastic band (NEB) calculation from end points.
|
200 |
|
|
|
219 |
"""
|
220 |
|
221 |
if relax_end_points:
|
222 |
+
relax = OPT.with_options(
|
223 |
+
refresh_cache=not cache_subtasks,
|
224 |
+
)(
|
225 |
atoms=start.copy(),
|
226 |
calculator_name=calculator_name,
|
227 |
calculator_kwargs=calculator_kwargs,
|
|
|
234 |
)
|
235 |
start = relax["atoms"]
|
236 |
|
237 |
+
relax = OPT.with_options(
|
238 |
+
refresh_cache=not cache_subtasks,
|
239 |
+
)(
|
240 |
atoms=end.copy(),
|
241 |
calculator_name=calculator_name,
|
242 |
calculator_kwargs=calculator_kwargs,
|
|
|
263 |
|
264 |
images = [s.to_ase_atoms() for s in path]
|
265 |
|
266 |
+
return run.with_options(
|
267 |
+
refresh_cache=not cache_subtasks,
|
268 |
+
)(
|
269 |
images,
|
270 |
calculator_name,
|
271 |
calculator_kwargs=calculator_kwargs,
|
mlip_arena/tasks/optimize.py
CHANGED
@@ -7,18 +7,15 @@ from __future__ import annotations
|
|
7 |
from prefect import task
|
8 |
from prefect.cache_policies import INPUTS, TASK_SOURCE
|
9 |
from prefect.runtime import task_run
|
10 |
-
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
11 |
|
12 |
from ase import Atoms
|
13 |
-
from ase.calculators.calculator import Calculator
|
14 |
-
from ase.calculators.mixing import SumCalculator
|
15 |
from ase.constraints import FixSymmetry
|
16 |
from ase.filters import * # type: ignore
|
17 |
from ase.filters import Filter
|
18 |
from ase.optimize import * # type: ignore
|
19 |
from ase.optimize.optimize import Optimizer
|
20 |
from mlip_arena.models import MLIPEnum
|
21 |
-
from mlip_arena.
|
22 |
|
23 |
_valid_filters: dict[str, Filter] = {
|
24 |
"Filter": Filter,
|
@@ -54,17 +51,13 @@ def _generate_task_run_name():
|
|
54 |
|
55 |
|
56 |
@task(
|
57 |
-
name="OPT",
|
58 |
-
task_run_name=_generate_task_run_name,
|
59 |
-
cache_policy=TASK_SOURCE + INPUTS
|
60 |
-
# cache_key_fn=task_input_hash,
|
61 |
-
# cache_expiration=timedelta(days=1)
|
62 |
)
|
63 |
def run(
|
64 |
atoms: Atoms,
|
65 |
calculator_name: str | MLIPEnum,
|
66 |
calculator_kwargs: dict | None = None,
|
67 |
-
dispersion:
|
68 |
dispersion_kwargs: dict | None = None,
|
69 |
device: str | None = None,
|
70 |
optimizer: Optimizer | str = BFGSLineSearch,
|
@@ -74,37 +67,13 @@ def run(
|
|
74 |
criterion: dict | None = None,
|
75 |
symmetry: bool = False,
|
76 |
):
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
assert issubclass(calculator_name.value, Calculator)
|
85 |
-
calc = calculator_name.value(**calculator_kwargs)
|
86 |
-
elif (
|
87 |
-
isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
|
88 |
-
):
|
89 |
-
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
90 |
-
else:
|
91 |
-
raise ValueError(f"Invalid calculator: {calculator_name}")
|
92 |
-
|
93 |
-
print(f"Using calculator: {calc}")
|
94 |
-
|
95 |
-
dispersion_kwargs = dispersion_kwargs or {}
|
96 |
-
|
97 |
-
dispersion_kwargs.update({"device": device})
|
98 |
-
|
99 |
-
if dispersion is not None:
|
100 |
-
disp_calc = TorchDFTD3Calculator(
|
101 |
-
**dispersion_kwargs,
|
102 |
-
)
|
103 |
-
calc = SumCalculator([calc, disp_calc])
|
104 |
-
|
105 |
-
print(f"Using dispersion: {dispersion}")
|
106 |
-
|
107 |
-
atoms.calc = calc
|
108 |
|
109 |
if isinstance(filter, str):
|
110 |
if filter not in _valid_filters:
|
|
|
7 |
from prefect import task
|
8 |
from prefect.cache_policies import INPUTS, TASK_SOURCE
|
9 |
from prefect.runtime import task_run
|
|
|
10 |
|
11 |
from ase import Atoms
|
|
|
|
|
12 |
from ase.constraints import FixSymmetry
|
13 |
from ase.filters import * # type: ignore
|
14 |
from ase.filters import Filter
|
15 |
from ase.optimize import * # type: ignore
|
16 |
from ase.optimize.optimize import Optimizer
|
17 |
from mlip_arena.models import MLIPEnum
|
18 |
+
from mlip_arena.tasks.utils import get_calculator
|
19 |
|
20 |
_valid_filters: dict[str, Filter] = {
|
21 |
"Filter": Filter,
|
|
|
51 |
|
52 |
|
53 |
@task(
|
54 |
+
name="OPT", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
def run(
|
57 |
atoms: Atoms,
|
58 |
calculator_name: str | MLIPEnum,
|
59 |
calculator_kwargs: dict | None = None,
|
60 |
+
dispersion: bool = False,
|
61 |
dispersion_kwargs: dict | None = None,
|
62 |
device: str | None = None,
|
63 |
optimizer: Optimizer | str = BFGSLineSearch,
|
|
|
67 |
criterion: dict | None = None,
|
68 |
symmetry: bool = False,
|
69 |
):
|
70 |
+
atoms.calc = get_calculator(
|
71 |
+
calculator_name=calculator_name,
|
72 |
+
calculator_kwargs=calculator_kwargs,
|
73 |
+
dispersion=dispersion,
|
74 |
+
dispersion_kwargs=dispersion_kwargs,
|
75 |
+
device=device,
|
76 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
if isinstance(filter, str):
|
79 |
if filter not in _valid_filters:
|
mlip_arena/tasks/utils.py
CHANGED
@@ -4,24 +4,33 @@ from __future__ import annotations
|
|
4 |
|
5 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
6 |
|
|
|
7 |
from ase.calculators.calculator import Calculator
|
8 |
from ase.calculators.mixing import SumCalculator
|
9 |
-
from ase.filters import * # type: ignore
|
10 |
-
from ase.optimize import * # type: ignore
|
11 |
from mlip_arena.models import MLIPEnum
|
12 |
from mlip_arena.models.utils import get_freer_device
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def get_calculator(
|
16 |
calculator_name: str | MLIPEnum,
|
17 |
calculator_kwargs: dict | None,
|
18 |
-
dispersion:
|
19 |
dispersion_kwargs: dict | None = None,
|
20 |
device: str | None = None,
|
21 |
-
) -> Calculator:
|
|
|
22 |
device = device or str(get_freer_device())
|
23 |
|
24 |
-
|
25 |
|
26 |
calculator_kwargs = calculator_kwargs or {}
|
27 |
|
@@ -33,19 +42,25 @@ def get_calculator(
|
|
33 |
else:
|
34 |
raise ValueError(f"Invalid calculator: {calculator_name}")
|
35 |
|
36 |
-
|
|
|
|
|
37 |
|
38 |
-
dispersion_kwargs = dispersion_kwargs or
|
|
|
|
|
39 |
|
40 |
dispersion_kwargs.update({"device": device})
|
41 |
|
42 |
-
if dispersion
|
43 |
disp_calc = TorchDFTD3Calculator(
|
44 |
**dispersion_kwargs,
|
45 |
)
|
46 |
calc = SumCalculator([calc, disp_calc])
|
47 |
|
48 |
-
|
|
|
|
|
49 |
|
50 |
-
assert isinstance(calc, Calculator)
|
51 |
return calc
|
|
|
4 |
|
5 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
6 |
|
7 |
+
from ase import units
|
8 |
from ase.calculators.calculator import Calculator
|
9 |
from ase.calculators.mixing import SumCalculator
|
|
|
|
|
10 |
from mlip_arena.models import MLIPEnum
|
11 |
from mlip_arena.models.utils import get_freer_device
|
12 |
|
13 |
+
try:
|
14 |
+
from prefect.logging import get_run_logger
|
15 |
+
|
16 |
+
logger = get_run_logger()
|
17 |
+
except (ImportError, RuntimeError):
|
18 |
+
from loguru import logger
|
19 |
+
|
20 |
+
from pprint import pformat
|
21 |
+
|
22 |
|
23 |
def get_calculator(
|
24 |
calculator_name: str | MLIPEnum,
|
25 |
calculator_kwargs: dict | None,
|
26 |
+
dispersion: bool = False,
|
27 |
dispersion_kwargs: dict | None = None,
|
28 |
device: str | None = None,
|
29 |
+
) -> Calculator | SumCalculator:
|
30 |
+
"""Get a calculator with optional dispersion correction."""
|
31 |
device = device or str(get_freer_device())
|
32 |
|
33 |
+
logger.info(f"Using device: {device}")
|
34 |
|
35 |
calculator_kwargs = calculator_kwargs or {}
|
36 |
|
|
|
42 |
else:
|
43 |
raise ValueError(f"Invalid calculator: {calculator_name}")
|
44 |
|
45 |
+
logger.info(f"Using calculator: {calc}")
|
46 |
+
if calculator_kwargs:
|
47 |
+
logger.info(pformat(calculator_kwargs))
|
48 |
|
49 |
+
dispersion_kwargs = dispersion_kwargs or dict(
|
50 |
+
damping="bj", xc="pbe", cutoff=40.0 * units.Bohr
|
51 |
+
)
|
52 |
|
53 |
dispersion_kwargs.update({"device": device})
|
54 |
|
55 |
+
if dispersion:
|
56 |
disp_calc = TorchDFTD3Calculator(
|
57 |
**dispersion_kwargs,
|
58 |
)
|
59 |
calc = SumCalculator([calc, disp_calc])
|
60 |
|
61 |
+
logger.info(f"Using dispersion: {disp_calc}")
|
62 |
+
if dispersion_kwargs:
|
63 |
+
logger.info(pformat(dispersion_kwargs))
|
64 |
|
65 |
+
assert isinstance(calc, Calculator) or isinstance(calc, SumCalculator)
|
66 |
return calc
|