cyrusyc commited on
Commit
a787930
·
1 Parent(s): f791dcc

improve caching, unify calculator init

Browse files
mlip_arena/models/externals/mace-mp.py CHANGED
@@ -11,7 +11,7 @@ from mlip_arena.models.utils import get_freer_device
11
  class MACE_MP_Medium(MACECalculator):
12
  def __init__(
13
  self,
14
- checkpoint="http://tinyurl.com/5yyxdm76",
15
  device: str | None = None,
16
  default_dtype="float32",
17
  **kwargs,
 
11
  class MACE_MP_Medium(MACECalculator):
12
  def __init__(
13
  self,
14
+ checkpoint="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
15
  device: str | None = None,
16
  default_dtype="float32",
17
  **kwargs,
mlip_arena/models/externals/mattersim.py CHANGED
@@ -5,11 +5,14 @@ from pathlib import Path
5
  import yaml
6
  from mattersim.forcefield import MatterSimCalculator
7
 
 
8
  from mlip_arena.models.utils import get_freer_device
 
9
 
10
  with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
11
  REGISTRY = yaml.safe_load(f)
12
 
 
13
  class MatterSim(MatterSimCalculator):
14
  def __init__(
15
  self,
@@ -18,7 +21,19 @@ class MatterSim(MatterSimCalculator):
18
  **kwargs,
19
  ):
20
  super().__init__(
21
- load_path=checkpoint,
22
- device=str(device or get_freer_device()),
23
- **kwargs
24
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import yaml
6
  from mattersim.forcefield import MatterSimCalculator
7
 
8
+ from ase import Atoms
9
  from mlip_arena.models.utils import get_freer_device
10
+ # from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms
11
 
12
  with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
13
  REGISTRY = yaml.safe_load(f)
14
 
15
+
16
  class MatterSim(MatterSimCalculator):
17
  def __init__(
18
  self,
 
21
  **kwargs,
22
  ):
23
  super().__init__(
24
+ load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
25
+ )
26
+
27
+ def calculate(
28
+ self,
29
+ atoms: Atoms | None = None,
30
+ properties: list | None = None,
31
+ system_changes: list | None = None,
32
+ ):
33
+ super().calculate(atoms, properties, system_changes)
34
+
35
+ # # convert unpicklizable atoms back to picklizable atoms to avoid prefect pickling error
36
+ # if isinstance(self.atoms, MSONAtoms):
37
+ # atoms = self.atoms.copy()
38
+ # strucutre = AseAtomsAdaptor().get_structure(atoms)
39
+ # self.atoms = AseAtomsAdaptor().get_atoms(strucutre, msonable=False)
mlip_arena/tasks/elasticity.py CHANGED
@@ -90,7 +90,7 @@ def run(
90
  normal_strains: list[float] | np.ndarray | None = np.linspace(-0.01, 0.01, 4),
91
  shear_strains: list[float] | np.ndarray | None = np.linspace(-0.06, 0.06, 4),
92
  persist_opt: bool = True,
93
- cache_opt: bool = True,
94
  ) -> dict[str, Any] | State:
95
  """
96
  Compute the elastic tensor for the given structure and calculator.
 
90
  normal_strains: list[float] | np.ndarray | None = np.linspace(-0.01, 0.01, 4),
91
  shear_strains: list[float] | np.ndarray | None = np.linspace(-0.06, 0.06, 4),
92
  persist_opt: bool = True,
93
+ cache_opt: bool = False,
94
  ) -> dict[str, Any] | State:
95
  """
96
  Compute the elastic tensor for the given structure and calculator.
mlip_arena/tasks/eos.py CHANGED
@@ -17,8 +17,6 @@ from prefect.runtime import task_run
17
  from prefect.states import State
18
 
19
  from ase import Atoms
20
- from ase.filters import * # type: ignore
21
- from ase.optimize import * # type: ignore
22
  from ase.optimize.optimize import Optimizer
23
  from mlip_arena.models import MLIPEnum
24
  from mlip_arena.tasks.optimize import run as OPT
@@ -54,6 +52,7 @@ def run(
54
  max_abs_strain: float = 0.1,
55
  npoints: int = 11,
56
  concurrent: bool = True,
 
57
  ) -> dict[str, Any] | State:
58
  """
59
  Compute the equation of state (EOS) for the given atoms and calculator.
@@ -78,7 +77,12 @@ def run(
78
  A dictionary containing the EOS data, bulk modulus, equilibrium volume, and equilibrium energy if successful. Otherwise, a prefect state object.
79
  """
80
 
81
- state = OPT(
 
 
 
 
 
82
  atoms=atoms,
83
  calculator_name=calculator_name,
84
  calculator_kwargs=calculator_kwargs,
@@ -112,7 +116,7 @@ def run(
112
  atoms = relaxed.copy()
113
  atoms.set_cell(c0 * f, scale_atoms=True)
114
 
115
- future = OPT.submit(
116
  atoms=atoms,
117
  calculator_name=calculator_name,
118
  calculator_kwargs=calculator_kwargs,
@@ -138,7 +142,7 @@ def run(
138
  atoms = relaxed.copy()
139
  atoms.set_cell(c0 * f, scale_atoms=True)
140
 
141
- state = OPT(
142
  atoms=atoms,
143
  calculator_name=calculator_name,
144
  calculator_kwargs=calculator_kwargs,
 
17
  from prefect.states import State
18
 
19
  from ase import Atoms
 
 
20
  from ase.optimize.optimize import Optimizer
21
  from mlip_arena.models import MLIPEnum
22
  from mlip_arena.tasks.optimize import run as OPT
 
52
  max_abs_strain: float = 0.1,
53
  npoints: int = 11,
54
  concurrent: bool = True,
55
+ cache_opt: bool = False,
56
  ) -> dict[str, Any] | State:
57
  """
58
  Compute the equation of state (EOS) for the given atoms and calculator.
 
77
  A dictionary containing the EOS data, bulk modulus, equilibrium volume, and equilibrium energy if successful. Otherwise, a prefect state object.
78
  """
79
 
80
+ OPT_ = OPT.with_options(
81
+ refresh_cache=not cache_opt,
82
+ persist_result=cache_opt,
83
+ )
84
+
85
+ state = OPT_(
86
  atoms=atoms,
87
  calculator_name=calculator_name,
88
  calculator_kwargs=calculator_kwargs,
 
116
  atoms = relaxed.copy()
117
  atoms.set_cell(c0 * f, scale_atoms=True)
118
 
119
+ future = OPT_.submit(
120
  atoms=atoms,
121
  calculator_name=calculator_name,
122
  calculator_kwargs=calculator_kwargs,
 
142
  atoms = relaxed.copy()
143
  atoms.set_cell(c0 * f, scale_atoms=True)
144
 
145
+ state = OPT_(
146
  atoms=atoms,
147
  calculator_name=calculator_name,
148
  calculator_kwargs=calculator_kwargs,
mlip_arena/tasks/md.py CHANGED
@@ -65,12 +65,9 @@ from prefect.cache_policies import INPUTS, TASK_SOURCE
65
  from prefect.runtime import task_run
66
  from scipy.interpolate import interp1d
67
  from scipy.linalg import schur
68
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
69
  from tqdm.auto import tqdm
70
 
71
  from ase import Atoms, units
72
- from ase.calculators.calculator import Calculator
73
- from ase.calculators.mixing import SumCalculator
74
  from ase.io import read
75
  from ase.io.trajectory import Trajectory
76
  from ase.md.andersen import Andersen
@@ -86,7 +83,7 @@ from ase.md.velocitydistribution import (
86
  )
87
  from ase.md.verlet import VelocityVerlet
88
  from mlip_arena.models import MLIPEnum
89
- from mlip_arena.models.utils import get_freer_device
90
 
91
  _valid_dynamics: dict[str, tuple[str, ...]] = {
92
  "nve": ("velocityverlet",),
@@ -201,14 +198,12 @@ def _generate_task_run_name():
201
  name="MD",
202
  task_run_name=_generate_task_run_name,
203
  cache_policy=TASK_SOURCE + INPUTS
204
- # cache_key_fn=task_input_hash,
205
- # cache_expiration=timedelta(days=1)
206
  )
207
  def run(
208
  atoms: Atoms,
209
  calculator_name: str | MLIPEnum,
210
  calculator_kwargs: dict | None,
211
- dispersion: str | None = None,
212
  dispersion_kwargs: dict | None = None,
213
  device: str | None = None,
214
  ensemble: Literal["nve", "nvt", "npt"] = "nvt",
@@ -225,37 +220,14 @@ def run(
225
  traj_interval: int = 1,
226
  restart: bool = True,
227
  ):
228
- device = device or str(get_freer_device())
229
-
230
- print(f"Using device: {device}")
231
-
232
- calculator_kwargs = calculator_kwargs or {}
233
-
234
- if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
235
- assert issubclass(calculator_name.value, Calculator)
236
- calc = calculator_name.value(**calculator_kwargs)
237
- elif (
238
- isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
239
- ):
240
- calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
241
- else:
242
- raise ValueError(f"Invalid calculator: {calculator_name}")
243
-
244
- print(f"Using calculator: {calc}")
245
-
246
- dispersion_kwargs = dispersion_kwargs or {}
247
-
248
- dispersion_kwargs.update({"device": device})
249
-
250
- if dispersion is not None:
251
- disp_calc = TorchDFTD3Calculator(
252
- **dispersion_kwargs,
253
- )
254
- calc = SumCalculator([calc, disp_calc])
255
-
256
- print(f"Using dispersion: {dispersion}")
257
-
258
- atoms.calc = calc
259
 
260
  if time_step is None:
261
  # If a structure contains an isotope of hydrogen, set default `time_step`
 
65
  from prefect.runtime import task_run
66
  from scipy.interpolate import interp1d
67
  from scipy.linalg import schur
 
68
  from tqdm.auto import tqdm
69
 
70
  from ase import Atoms, units
 
 
71
  from ase.io import read
72
  from ase.io.trajectory import Trajectory
73
  from ase.md.andersen import Andersen
 
83
  )
84
  from ase.md.verlet import VelocityVerlet
85
  from mlip_arena.models import MLIPEnum
86
+ from mlip_arena.tasks.utils import get_calculator
87
 
88
  _valid_dynamics: dict[str, tuple[str, ...]] = {
89
  "nve": ("velocityverlet",),
 
198
  name="MD",
199
  task_run_name=_generate_task_run_name,
200
  cache_policy=TASK_SOURCE + INPUTS
 
 
201
  )
202
  def run(
203
  atoms: Atoms,
204
  calculator_name: str | MLIPEnum,
205
  calculator_kwargs: dict | None,
206
+ dispersion: bool = False,
207
  dispersion_kwargs: dict | None = None,
208
  device: str | None = None,
209
  ensemble: Literal["nve", "nvt", "npt"] = "nvt",
 
220
  traj_interval: int = 1,
221
  restart: bool = True,
222
  ):
223
+
224
+ atoms.calc = get_calculator(
225
+ calculator_name=calculator_name,
226
+ calculator_kwargs=calculator_kwargs,
227
+ dispersion=dispersion,
228
+ dispersion_kwargs=dispersion_kwargs,
229
+ device=device,
230
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if time_step is None:
233
  # If a structure contains an isotope of hydrogen, set default `time_step`
mlip_arena/tasks/neb.py CHANGED
@@ -57,6 +57,7 @@ from mlip_arena.tasks.optimize import run as OPT
57
  from mlip_arena.tasks.utils import get_calculator
58
  from pymatgen.io.ase import AseAtomsAdaptor
59
 
 
60
  if TYPE_CHECKING:
61
  pass
62
 
@@ -100,7 +101,7 @@ def run(
100
  images: list[Atoms],
101
  calculator_name: str | MLIPEnum,
102
  calculator_kwargs: dict | None = None,
103
- dispersion: str | None = None,
104
  dispersion_kwargs: dict | None = None,
105
  device: str | None = None,
106
  optimizer: Optimizer | str = "MDMin", # type: ignore
@@ -159,11 +160,16 @@ def run(
159
  optimizer_instance.run(**criterion)
160
 
161
  neb_tool = NEBTools(neb.images)
 
 
 
 
 
162
 
163
  return {
164
- "barrier": neb_tool.get_barrier(),
165
- "images": neb.images,
166
- "forcefit": fit_images(neb.images),
167
  }
168
 
169
 
@@ -188,6 +194,7 @@ def run_from_end_points(
188
  interpolation: Literal["linear", "idpp"] = "idpp",
189
  climb: bool = True,
190
  traj_file: str | Path | None = None,
 
191
  ) -> dict[str, Any] | State:
192
  """Run the nudged elastic band (NEB) calculation from end points.
193
 
@@ -212,7 +219,9 @@ def run_from_end_points(
212
  """
213
 
214
  if relax_end_points:
215
- relax = OPT(
 
 
216
  atoms=start.copy(),
217
  calculator_name=calculator_name,
218
  calculator_kwargs=calculator_kwargs,
@@ -225,7 +234,9 @@ def run_from_end_points(
225
  )
226
  start = relax["atoms"]
227
 
228
- relax = OPT(
 
 
229
  atoms=end.copy(),
230
  calculator_name=calculator_name,
231
  calculator_kwargs=calculator_kwargs,
@@ -252,7 +263,9 @@ def run_from_end_points(
252
 
253
  images = [s.to_ase_atoms() for s in path]
254
 
255
- return run(
 
 
256
  images,
257
  calculator_name,
258
  calculator_kwargs=calculator_kwargs,
 
57
  from mlip_arena.tasks.utils import get_calculator
58
  from pymatgen.io.ase import AseAtomsAdaptor
59
 
60
+
61
  if TYPE_CHECKING:
62
  pass
63
 
 
101
  images: list[Atoms],
102
  calculator_name: str | MLIPEnum,
103
  calculator_kwargs: dict | None = None,
104
+ dispersion: bool = False,
105
  dispersion_kwargs: dict | None = None,
106
  device: str | None = None,
107
  optimizer: Optimizer | str = "MDMin", # type: ignore
 
160
  optimizer_instance.run(**criterion)
161
 
162
  neb_tool = NEBTools(neb.images)
163
+ barrier = neb_tool.get_barrier()
164
+
165
+ forcefit = fit_images(neb.images)
166
+
167
+ images = neb.images
168
 
169
  return {
170
+ "barrier": barrier,
171
+ "images": images,
172
+ "forcefit": forcefit,
173
  }
174
 
175
 
 
194
  interpolation: Literal["linear", "idpp"] = "idpp",
195
  climb: bool = True,
196
  traj_file: str | Path | None = None,
197
+ cache_subtasks: bool = False,
198
  ) -> dict[str, Any] | State:
199
  """Run the nudged elastic band (NEB) calculation from end points.
200
 
 
219
  """
220
 
221
  if relax_end_points:
222
+ relax = OPT.with_options(
223
+ refresh_cache=not cache_subtasks,
224
+ )(
225
  atoms=start.copy(),
226
  calculator_name=calculator_name,
227
  calculator_kwargs=calculator_kwargs,
 
234
  )
235
  start = relax["atoms"]
236
 
237
+ relax = OPT.with_options(
238
+ refresh_cache=not cache_subtasks,
239
+ )(
240
  atoms=end.copy(),
241
  calculator_name=calculator_name,
242
  calculator_kwargs=calculator_kwargs,
 
263
 
264
  images = [s.to_ase_atoms() for s in path]
265
 
266
+ return run.with_options(
267
+ refresh_cache=not cache_subtasks,
268
+ )(
269
  images,
270
  calculator_name,
271
  calculator_kwargs=calculator_kwargs,
mlip_arena/tasks/optimize.py CHANGED
@@ -7,18 +7,15 @@ from __future__ import annotations
7
  from prefect import task
8
  from prefect.cache_policies import INPUTS, TASK_SOURCE
9
  from prefect.runtime import task_run
10
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
11
 
12
  from ase import Atoms
13
- from ase.calculators.calculator import Calculator
14
- from ase.calculators.mixing import SumCalculator
15
  from ase.constraints import FixSymmetry
16
  from ase.filters import * # type: ignore
17
  from ase.filters import Filter
18
  from ase.optimize import * # type: ignore
19
  from ase.optimize.optimize import Optimizer
20
  from mlip_arena.models import MLIPEnum
21
- from mlip_arena.models.utils import get_freer_device
22
 
23
  _valid_filters: dict[str, Filter] = {
24
  "Filter": Filter,
@@ -54,17 +51,13 @@ def _generate_task_run_name():
54
 
55
 
56
  @task(
57
- name="OPT",
58
- task_run_name=_generate_task_run_name,
59
- cache_policy=TASK_SOURCE + INPUTS
60
- # cache_key_fn=task_input_hash,
61
- # cache_expiration=timedelta(days=1)
62
  )
63
  def run(
64
  atoms: Atoms,
65
  calculator_name: str | MLIPEnum,
66
  calculator_kwargs: dict | None = None,
67
- dispersion: str | None = None,
68
  dispersion_kwargs: dict | None = None,
69
  device: str | None = None,
70
  optimizer: Optimizer | str = BFGSLineSearch,
@@ -74,37 +67,13 @@ def run(
74
  criterion: dict | None = None,
75
  symmetry: bool = False,
76
  ):
77
- device = device or str(get_freer_device())
78
-
79
- print(f"Using device: {device}")
80
-
81
- calculator_kwargs = calculator_kwargs or {}
82
-
83
- if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
84
- assert issubclass(calculator_name.value, Calculator)
85
- calc = calculator_name.value(**calculator_kwargs)
86
- elif (
87
- isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
88
- ):
89
- calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
90
- else:
91
- raise ValueError(f"Invalid calculator: {calculator_name}")
92
-
93
- print(f"Using calculator: {calc}")
94
-
95
- dispersion_kwargs = dispersion_kwargs or {}
96
-
97
- dispersion_kwargs.update({"device": device})
98
-
99
- if dispersion is not None:
100
- disp_calc = TorchDFTD3Calculator(
101
- **dispersion_kwargs,
102
- )
103
- calc = SumCalculator([calc, disp_calc])
104
-
105
- print(f"Using dispersion: {dispersion}")
106
-
107
- atoms.calc = calc
108
 
109
  if isinstance(filter, str):
110
  if filter not in _valid_filters:
 
7
  from prefect import task
8
  from prefect.cache_policies import INPUTS, TASK_SOURCE
9
  from prefect.runtime import task_run
 
10
 
11
  from ase import Atoms
 
 
12
  from ase.constraints import FixSymmetry
13
  from ase.filters import * # type: ignore
14
  from ase.filters import Filter
15
  from ase.optimize import * # type: ignore
16
  from ase.optimize.optimize import Optimizer
17
  from mlip_arena.models import MLIPEnum
18
+ from mlip_arena.tasks.utils import get_calculator
19
 
20
  _valid_filters: dict[str, Filter] = {
21
  "Filter": Filter,
 
51
 
52
 
53
  @task(
54
+ name="OPT", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
 
 
 
 
55
  )
56
  def run(
57
  atoms: Atoms,
58
  calculator_name: str | MLIPEnum,
59
  calculator_kwargs: dict | None = None,
60
+ dispersion: bool = False,
61
  dispersion_kwargs: dict | None = None,
62
  device: str | None = None,
63
  optimizer: Optimizer | str = BFGSLineSearch,
 
67
  criterion: dict | None = None,
68
  symmetry: bool = False,
69
  ):
70
+ atoms.calc = get_calculator(
71
+ calculator_name=calculator_name,
72
+ calculator_kwargs=calculator_kwargs,
73
+ dispersion=dispersion,
74
+ dispersion_kwargs=dispersion_kwargs,
75
+ device=device,
76
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  if isinstance(filter, str):
79
  if filter not in _valid_filters:
mlip_arena/tasks/utils.py CHANGED
@@ -4,24 +4,33 @@ from __future__ import annotations
4
 
5
  from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
6
 
 
7
  from ase.calculators.calculator import Calculator
8
  from ase.calculators.mixing import SumCalculator
9
- from ase.filters import * # type: ignore
10
- from ase.optimize import * # type: ignore
11
  from mlip_arena.models import MLIPEnum
12
  from mlip_arena.models.utils import get_freer_device
13
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_calculator(
16
  calculator_name: str | MLIPEnum,
17
  calculator_kwargs: dict | None,
18
- dispersion: str | None = None,
19
  dispersion_kwargs: dict | None = None,
20
  device: str | None = None,
21
- ) -> Calculator:
 
22
  device = device or str(get_freer_device())
23
 
24
- print(f"Using device: {device}")
25
 
26
  calculator_kwargs = calculator_kwargs or {}
27
 
@@ -33,19 +42,25 @@ def get_calculator(
33
  else:
34
  raise ValueError(f"Invalid calculator: {calculator_name}")
35
 
36
- print(f"Using calculator: {calc}")
 
 
37
 
38
- dispersion_kwargs = dispersion_kwargs or {}
 
 
39
 
40
  dispersion_kwargs.update({"device": device})
41
 
42
- if dispersion is not None:
43
  disp_calc = TorchDFTD3Calculator(
44
  **dispersion_kwargs,
45
  )
46
  calc = SumCalculator([calc, disp_calc])
47
 
48
- print(f"Using dispersion: {dispersion}")
 
 
49
 
50
- assert isinstance(calc, Calculator)
51
  return calc
 
4
 
5
  from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
6
 
7
+ from ase import units
8
  from ase.calculators.calculator import Calculator
9
  from ase.calculators.mixing import SumCalculator
 
 
10
  from mlip_arena.models import MLIPEnum
11
  from mlip_arena.models.utils import get_freer_device
12
 
13
+ try:
14
+ from prefect.logging import get_run_logger
15
+
16
+ logger = get_run_logger()
17
+ except (ImportError, RuntimeError):
18
+ from loguru import logger
19
+
20
+ from pprint import pformat
21
+
22
 
23
  def get_calculator(
24
  calculator_name: str | MLIPEnum,
25
  calculator_kwargs: dict | None,
26
+ dispersion: bool = False,
27
  dispersion_kwargs: dict | None = None,
28
  device: str | None = None,
29
+ ) -> Calculator | SumCalculator:
30
+ """Get a calculator with optional dispersion correction."""
31
  device = device or str(get_freer_device())
32
 
33
+ logger.info(f"Using device: {device}")
34
 
35
  calculator_kwargs = calculator_kwargs or {}
36
 
 
42
  else:
43
  raise ValueError(f"Invalid calculator: {calculator_name}")
44
 
45
+ logger.info(f"Using calculator: {calc}")
46
+ if calculator_kwargs:
47
+ logger.info(pformat(calculator_kwargs))
48
 
49
+ dispersion_kwargs = dispersion_kwargs or dict(
50
+ damping="bj", xc="pbe", cutoff=40.0 * units.Bohr
51
+ )
52
 
53
  dispersion_kwargs.update({"device": device})
54
 
55
+ if dispersion:
56
  disp_calc = TorchDFTD3Calculator(
57
  **dispersion_kwargs,
58
  )
59
  calc = SumCalculator([calc, disp_calc])
60
 
61
+ logger.info(f"Using dispersion: {disp_calc}")
62
+ if dispersion_kwargs:
63
+ logger.info(pformat(dispersion_kwargs))
64
 
65
+ assert isinstance(calc, Calculator) or isinstance(calc, SumCalculator)
66
  return calc