Yuan (Cyrus) Chiang commited on
Commit
e59bc30
·
unverified ·
1 Parent(s): 587c7e5

Add neb task (#38)

Browse files

* add neb task

* return neb tool fit object

* add run from endpoints

* add neb test

* fix typo

* try parallel pytest

* add pytest-xdist

.github/workflows/test.yaml CHANGED
@@ -26,7 +26,7 @@ jobs:
26
  - name: Install uv
27
  uses: astral-sh/setup-uv@v4
28
  with:
29
- enable-cahce: true
30
  cache-dependency-glob: "pyproject.toml"
31
 
32
  - name: Set up Python ${{ matrix.python-version }}
@@ -63,4 +63,4 @@ jobs:
63
  PREFECT_API_KEY: ${{ secrets.PREFECT_API_KEY }}
64
  PREFECT_API_URL: ${{ secrets.PREFECT_API_URL }}
65
  run: |
66
- pytest -vvra tests
 
26
  - name: Install uv
27
  uses: astral-sh/setup-uv@v4
28
  with:
29
+ enable-cache: true
30
  cache-dependency-glob: "pyproject.toml"
31
 
32
  - name: Set up Python ${{ matrix.python-version }}
 
63
  PREFECT_API_KEY: ${{ secrets.PREFECT_API_KEY }}
64
  PREFECT_API_URL: ${{ secrets.PREFECT_API_URL }}
65
  run: |
66
+ pytest -vra tests -n 5
mlip_arena/tasks/neb.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines nudged elastic band (NEB) task
3
+
4
+ This module has been modified from MatCalc
5
+ https://github.com/materialsvirtuallab/matcalc/blob/main/src/matcalc/neb.py
6
+
7
+ https://github.com/materialsvirtuallab/matcalc/blob/main/LICENSE
8
+
9
+ BSD 3-Clause License
10
+
11
+ Copyright (c) 2023, Materials Virtual Lab
12
+
13
+ Redistribution and use in source and binary forms, with or without
14
+ modification, are permitted provided that the following conditions are met:
15
+
16
+ 1. Redistributions of source code must retain the above copyright notice, this
17
+ list of conditions and the following disclaimer.
18
+
19
+ 2. Redistributions in binary form must reproduce the above copyright notice,
20
+ this list of conditions and the following disclaimer in the documentation
21
+ and/or other materials provided with the distribution.
22
+
23
+ 3. Neither the name of the copyright holder nor the names of its
24
+ contributors may be used to endorse or promote products derived from
25
+ this software without specific prior written permission.
26
+
27
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37
+ """
38
+
39
+ from __future__ import annotations
40
+
41
+ from pathlib import Path
42
+ from typing import TYPE_CHECKING, Any, Literal
43
+
44
+ from prefect import task
45
+ from prefect.cache_policies import INPUTS, TASK_SOURCE
46
+ from prefect.runtime import task_run
47
+ from prefect.states import State
48
+
49
+ from ase import Atoms
50
+ from ase.filters import * # type: ignore
51
+ from ase.mep.neb import NEB, NEBTools
52
+ from ase.optimize import * # type: ignore
53
+ from ase.optimize.optimize import Optimizer
54
+ from ase.utils.forcecurve import fit_images
55
+ from mlip_arena.models import MLIPEnum
56
+ from mlip_arena.tasks.optimize import run as OPT
57
+ from mlip_arena.tasks.utils import get_calculator
58
+ from pymatgen.io.ase import AseAtomsAdaptor
59
+
60
+ if TYPE_CHECKING:
61
+ pass
62
+
63
+ _valid_optimizers: dict[str, Optimizer] = {
64
+ "MDMin": MDMin,
65
+ "FIRE": FIRE,
66
+ "FIRE2": FIRE2,
67
+ "LBFGS": LBFGS,
68
+ "LBFGSLineSearch": LBFGSLineSearch,
69
+ "BFGS": BFGS,
70
+ # "BFGSLineSearch": BFGSLineSearch, # NEB does not support BFGSLineSearch
71
+ "QuasiNewton": QuasiNewton,
72
+ "GPMin": GPMin,
73
+ "CellAwareBFGS": CellAwareBFGS,
74
+ "ODE12r": ODE12r,
75
+ } # type: ignore
76
+
77
+
78
+ def _generate_task_run_name():
79
+ task_name = task_run.task_name
80
+ parameters = task_run.parameters
81
+
82
+ if "images" in parameters:
83
+ atoms = parameters["images"][0]
84
+ elif "start" in parameters:
85
+ atoms = parameters["start"]
86
+ else:
87
+ raise ValueError("No images or start atoms found in parameters")
88
+
89
+ calculator_name = parameters["calculator_name"]
90
+
91
+ return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
92
+
93
+
94
+ @task(
95
+ name="NEB from images",
96
+ task_run_name=_generate_task_run_name,
97
+ cache_policy=TASK_SOURCE + INPUTS,
98
+ )
99
+ def run(
100
+ images: list[Atoms],
101
+ calculator_name: str | MLIPEnum,
102
+ calculator_kwargs: dict | None = None,
103
+ dispersion: str | None = None,
104
+ dispersion_kwargs: dict | None = None,
105
+ device: str | None = None,
106
+ optimizer: Optimizer | str = "MDMin", # type: ignore
107
+ optimizer_kwargs: dict | None = None,
108
+ criterion: dict | None = None,
109
+ interpolation: Literal["linear", "idpp"] = "idpp",
110
+ climb: bool = True,
111
+ traj_file: str | Path | None = None,
112
+ ) -> dict[str, Any] | State:
113
+ """Run the nudged elastic band (NEB) calculation.
114
+
115
+ Args:
116
+ images (list[Atoms]): The images.
117
+ calculator_name (str | MLIPEnum): The calculator name.
118
+ calculator_kwargs (dict, optional): The calculator kwargs. Defaults to None.
119
+ dispersion (str, optional): The dispersion. Defaults to None.
120
+ dispersion_kwargs (dict, optional): The dispersion kwargs. Defaults to None.
121
+ device (str, optional): The device. Defaults to None.
122
+ optimizer (Optimizer | str, optional): The optimizer. Defaults to "BFGSLineSearch".
123
+ optimizer_kwargs (dict, optional): The optimizer kwargs. Defaults to None.
124
+ criterion (dict, optional): The criterion. Defaults to None.
125
+ interpolation (Literal['linear', 'idpp'], optional): The interpolation method. Defaults to "idpp".
126
+ climb (bool, optional): Whether to use the climbing image. Defaults to True.
127
+ traj_file (str | Path, optional): The trajectory file. Defaults to None.
128
+
129
+ Returns:
130
+ dict[str, Any] | State: The energy barrier.
131
+ """
132
+
133
+ calc = get_calculator(
134
+ calculator_name,
135
+ calculator_kwargs,
136
+ dispersion=dispersion,
137
+ dispersion_kwargs=dispersion_kwargs,
138
+ device=device,
139
+ )
140
+
141
+ for image in images:
142
+ assert isinstance(image, Atoms)
143
+ image.calc = calc
144
+
145
+ neb = NEB(images, climb=climb, allow_shared_calculator=True)
146
+
147
+ neb.interpolate(method=interpolation)
148
+
149
+ if isinstance(optimizer, str):
150
+ if optimizer not in _valid_optimizers:
151
+ raise ValueError(f"Invalid optimizer: {optimizer}")
152
+ optimizer = _valid_optimizers[optimizer]
153
+
154
+ optimizer_kwargs = optimizer_kwargs or {}
155
+ criterion = criterion or {}
156
+
157
+ optimizer_instance = optimizer(neb, trajectory=traj_file, **optimizer_kwargs) # type: ignore
158
+
159
+ optimizer_instance.run(**criterion)
160
+
161
+ neb_tool = NEBTools(neb.images)
162
+
163
+ return {
164
+ "barrier": neb_tool.get_barrier(),
165
+ "images": neb.images,
166
+ "forcefit": fit_images(neb.images),
167
+ }
168
+
169
+
170
+ @task(
171
+ name="NEB from end points",
172
+ task_run_name=_generate_task_run_name,
173
+ cache_policy=TASK_SOURCE + INPUTS,
174
+ )
175
+ def run_from_end_points(
176
+ start: Atoms,
177
+ end: Atoms,
178
+ n_images: int,
179
+ calculator_name: str | MLIPEnum,
180
+ calculator_kwargs: dict | None = None,
181
+ dispersion: str | None = None,
182
+ dispersion_kwargs: dict | None = None,
183
+ device: str | None = None,
184
+ optimizer: Optimizer | str = "BFGS", # type: ignore
185
+ optimizer_kwargs: dict | None = None,
186
+ criterion: dict | None = None,
187
+ relax_end_points: bool = True,
188
+ interpolation: Literal["linear", "idpp"] = "idpp",
189
+ climb: bool = True,
190
+ traj_file: str | Path | None = None,
191
+ ) -> dict[str, Any] | State:
192
+ """Run the nudged elastic band (NEB) calculation from end points.
193
+
194
+ Args:
195
+ start (Atoms): The start image.
196
+ end (Atoms): The end image.
197
+ n_images (int): The number of images.
198
+ calculator_name (str | MLIPEnum): The calculator name.
199
+ calculator_kwargs (dict, optional): The calculator kwargs. Defaults to None.
200
+ dispersion (str, optional): The dispersion. Defaults to None.
201
+ dispersion_kwargs (dict, optional): The dispersion kwargs. Defaults to None.
202
+ device (str, optional): The device. Defaults to None.
203
+ optimizer (Optimizer | str, optional): The optimizer. Defaults to "BFGSLineSearch".
204
+ optimizer_kwargs (dict, optional): The optimizer kwargs. Defaults to None.
205
+ criterion (dict, optional): The criterion. Defaults to None.
206
+ interpolation (Literal['linear', 'idpp'], optional): The interpolation method. Defaults to "idpp".
207
+ climb (bool, optional): Whether to use the climbing image. Defaults to True.
208
+ traj_file (str | Path, optional): The trajectory file. Defaults to None.
209
+
210
+ Returns:
211
+ dict[str, Any] | State: The energy barrier.
212
+ """
213
+
214
+ if relax_end_points:
215
+ relax = OPT(
216
+ atoms=start.copy(),
217
+ calculator_name=calculator_name,
218
+ calculator_kwargs=calculator_kwargs,
219
+ dispersion=dispersion,
220
+ dispersion_kwargs=dispersion_kwargs,
221
+ device=device,
222
+ optimizer=optimizer,
223
+ optimizer_kwargs=optimizer_kwargs,
224
+ criterion=criterion,
225
+ )
226
+ start = relax["atoms"]
227
+
228
+ relax = OPT(
229
+ atoms=end.copy(),
230
+ calculator_name=calculator_name,
231
+ calculator_kwargs=calculator_kwargs,
232
+ dispersion=dispersion,
233
+ dispersion_kwargs=dispersion_kwargs,
234
+ device=device,
235
+ optimizer=optimizer,
236
+ optimizer_kwargs=optimizer_kwargs,
237
+ criterion=criterion,
238
+ )
239
+ end = relax["atoms"]
240
+
241
+ path = (
242
+ AseAtomsAdaptor()
243
+ .get_structure(start)
244
+ .interpolate(
245
+ AseAtomsAdaptor().get_structure(end),
246
+ nimages=n_images - 1,
247
+ interpolate_lattices=False,
248
+ pbc=False,
249
+ autosort_tol=0.5,
250
+ )
251
+ )
252
+
253
+ images = [s.to_ase_atoms() for s in path]
254
+
255
+ return run(
256
+ images,
257
+ calculator_name,
258
+ calculator_kwargs=calculator_kwargs,
259
+ dispersion=dispersion,
260
+ dispersion_kwargs=dispersion_kwargs,
261
+ device=device,
262
+ optimizer=optimizer,
263
+ optimizer_kwargs=optimizer_kwargs,
264
+ criterion=criterion,
265
+ interpolation=interpolation,
266
+ climb=climb,
267
+ traj_file=traj_file,
268
+ )
mlip_arena/tasks/utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for MLIP models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
6
+
7
+ from ase.calculators.calculator import Calculator
8
+ from ase.calculators.mixing import SumCalculator
9
+ from ase.filters import * # type: ignore
10
+ from ase.optimize import * # type: ignore
11
+ from mlip_arena.models import MLIPEnum
12
+ from mlip_arena.models.utils import get_freer_device
13
+
14
+
15
+ def get_calculator(
16
+ calculator_name: str | MLIPEnum,
17
+ calculator_kwargs: dict | None,
18
+ dispersion: str | None = None,
19
+ dispersion_kwargs: dict | None = None,
20
+ device: str | None = None,
21
+ ) -> Calculator:
22
+ device = device or str(get_freer_device())
23
+
24
+ print(f"Using device: {device}")
25
+
26
+ calculator_kwargs = calculator_kwargs or {}
27
+
28
+ if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
29
+ assert issubclass(calculator_name.value, Calculator)
30
+ calc = calculator_name.value(**calculator_kwargs)
31
+ elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
32
+ calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
33
+ else:
34
+ raise ValueError(f"Invalid calculator: {calculator_name}")
35
+
36
+ print(f"Using calculator: {calc}")
37
+
38
+ dispersion_kwargs = dispersion_kwargs or {}
39
+
40
+ dispersion_kwargs.update({"device": device})
41
+
42
+ if dispersion is not None:
43
+ disp_calc = TorchDFTD3Calculator(
44
+ **dispersion_kwargs,
45
+ )
46
+ calc = SumCalculator([calc, disp_calc])
47
+
48
+ print(f"Using dispersion: {dispersion}")
49
+
50
+ assert isinstance(calc, Calculator)
51
+ return calc
pyproject.toml CHANGED
@@ -64,6 +64,7 @@ test = [
64
  "alignn==2024.5.27",
65
  "mattersim==1.0.0rc9",
66
  "pytest",
 
67
  "prefect>=3.0.4,<3.1.0",
68
  ]
69
  mace = [
 
64
  "alignn==2024.5.27",
65
  "mattersim==1.0.0rc9",
66
  "pytest",
67
+ "pytest-xdist",
68
  "prefect>=3.0.4,<3.1.0",
69
  ]
70
  mace = [
tests/test_neb.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import pytest
4
+ from mlip_arena.models import MLIPEnum
5
+ from mlip_arena.tasks.neb import run as NEB
6
+ from mlip_arena.tasks.neb import run_from_end_points as NEB
7
+ from prefect.testing.utilities import prefect_test_harness
8
+
9
+ from ase.spacegroup import crystal
10
+
11
+ pristine = crystal(
12
+ "Al", [(0, 0, 0)], spacegroup=225, cellpar=[4.05, 4.05, 4.05, 90, 90, 90]
13
+ ) * (3, 3, 3)
14
+
15
+ atoms = pristine.copy()
16
+ del atoms[0]
17
+ start = atoms.copy()
18
+
19
+ atoms = pristine.copy()
20
+ del atoms[1]
21
+ end = atoms.copy()
22
+
23
+
24
+ @pytest.mark.skipif(
25
+ sys.version_info[:2] != (3, 11),
26
+ reason="avoid prefect race condition on concurrent tasks",
27
+ )
28
+ @pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
29
+ def test_neb(model: MLIPEnum):
30
+ """
31
+ Test NEB prefect workflow with a simple cubic lattice.
32
+ """
33
+
34
+ with prefect_test_harness():
35
+ result = NEB(
36
+ start=start.copy(),
37
+ end=end.copy(),
38
+ n_images=5,
39
+ calculator_name=model.name,
40
+ optimizer="FIRE2",
41
+ )
42
+
43
+ assert isinstance(result, dict)
44
+ assert isinstance(result["barrier"][0], float)
45
+ assert isinstance(result["barrier"][1], float)