Yuan (Cyrus) Chiang commited on
Commit
51638da
1 Parent(s): aeaedb4

refactor eos into task (#28)

Browse files

* refactor eos into task

* task name and cache

* avoid duplicate test on PR branches

.github/README.md CHANGED
@@ -53,7 +53,7 @@ streamlit run serve/app.py
53
  > The following are some tasks implemented:
54
  > - [Prefect structure optimization (OPT)](../mlip_arena/tasks/optimize.py)
55
  > - [Prefect molecular dynamics (MD)](../mlip_arena/tasks/md.py)
56
- > - [Prefect equation of states (EOS)](../mlip_arena/tasks/eos/run.py)
57
 
58
  1. Follow the task template to implement the task class and upload the script along with metadata to the MLIP Arena [here](../mlip_arena/tasks/README.md).
59
  2. Code a benchmark script to evaluate the performance of your model on the task. The script should be able to load the model and the dataset, and output the evaluation metrics.
 
53
  > The following are some tasks implemented:
54
  > - [Prefect structure optimization (OPT)](../mlip_arena/tasks/optimize.py)
55
  > - [Prefect molecular dynamics (MD)](../mlip_arena/tasks/md.py)
56
+ > - [Prefect equation of states (EOS)](../mlip_arena/tasks/eos.py)
57
 
58
  1. Follow the task template to implement the task class and upload the script along with metadata to the MLIP Arena [here](../mlip_arena/tasks/README.md).
59
  2. Code a benchmark script to evaluate the performance of your model on the task. The script should be able to load the model and the dataset, and output the evaluation metrics.
.github/workflows/test.yaml CHANGED
@@ -1,6 +1,10 @@
1
  name: Python Test
2
 
3
- on: [push, pull_request]
 
 
 
 
4
 
5
  # env:
6
  # UV_SYSTEM_PYTHON: 1
 
1
  name: Python Test
2
 
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
 
9
  # env:
10
  # UV_SYSTEM_PYTHON: 1
mlip_arena/tasks/{eos/run.py → eos.py} RENAMED
@@ -1,5 +1,5 @@
1
  """
2
- Define equation of state flows.
3
 
4
  https://github.com/materialsvirtuallab/matcalc/blob/main/matcalc/eos.py
5
  """
@@ -9,36 +9,25 @@ from __future__ import annotations
9
  from typing import TYPE_CHECKING
10
 
11
  import numpy as np
 
 
 
 
 
12
  from ase import Atoms
13
  from ase.filters import * # type: ignore
14
  from ase.optimize import * # type: ignore
15
  from ase.optimize.optimize import Optimizer
16
- from prefect import flow
17
- from prefect.futures import wait
18
- from prefect.runtime import flow_run, task_run
19
- from pymatgen.analysis.eos import BirchMurnaghan
20
-
21
  from mlip_arena.models import MLIPEnum
22
  from mlip_arena.tasks.optimize import run as OPT
 
23
 
24
  if TYPE_CHECKING:
25
  from ase.filters import Filter
26
 
27
 
28
- def generate_flow_run_name():
29
- flow_name = flow_run.flow_name
30
-
31
- parameters = flow_run.parameters
32
-
33
- atoms = parameters["atoms"]
34
- calculator_name = parameters["calculator_name"]
35
-
36
- return f"{flow_name}: {atoms.get_chemical_formula()} - {calculator_name}"
37
-
38
-
39
- def generate_task_run_name():
40
  task_name = task_run.task_name
41
-
42
  parameters = task_run.parameters
43
 
44
  atoms = parameters["atoms"]
@@ -47,10 +36,12 @@ def generate_task_run_name():
47
  return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
48
 
49
 
50
- # https://docs.prefect.io/3.0/develop/write-tasks#custom-retry-behavior
51
- # @task(task_run_name=generate_task_run_name)
52
- @flow(flow_run_name=generate_flow_run_name, validate_parameters=False)
53
- def fit(
 
 
54
  atoms: Atoms,
55
  calculator_name: str | MLIPEnum,
56
  calculator_kwargs: dict | None,
 
1
  """
2
+ Define equation of state task.
3
 
4
  https://github.com/materialsvirtuallab/matcalc/blob/main/matcalc/eos.py
5
  """
 
9
  from typing import TYPE_CHECKING
10
 
11
  import numpy as np
12
+ from prefect import task
13
+ from prefect.futures import wait
14
+ from prefect.runtime import task_run
15
+ from prefect.tasks import task_input_hash
16
+
17
  from ase import Atoms
18
  from ase.filters import * # type: ignore
19
  from ase.optimize import * # type: ignore
20
  from ase.optimize.optimize import Optimizer
 
 
 
 
 
21
  from mlip_arena.models import MLIPEnum
22
  from mlip_arena.tasks.optimize import run as OPT
23
+ from pymatgen.analysis.eos import BirchMurnaghan
24
 
25
  if TYPE_CHECKING:
26
  from ase.filters import Filter
27
 
28
 
29
+ def _generate_task_run_name():
 
 
 
 
 
 
 
 
 
 
 
30
  task_name = task_run.task_name
 
31
  parameters = task_run.parameters
32
 
33
  atoms = parameters["atoms"]
 
36
  return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
37
 
38
 
39
+ @task(
40
+ name="EOS",
41
+ task_run_name=_generate_task_run_name,
42
+ cache_key_fn=task_input_hash,
43
+ )
44
+ def run(
45
  atoms: Atoms,
46
  calculator_name: str | MLIPEnum,
47
  calculator_kwargs: dict | None,
mlip_arena/tasks/eos/__init__.py DELETED
File without changes
mlip_arena/tasks/eos/run.ipynb DELETED
@@ -1,600 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 2,
6
- "metadata": {},
7
- "outputs": [
8
- {
9
- "name": "stdout",
10
- "output_type": "stream",
11
- "text": [
12
- "No module named 'deepmd'\n"
13
- ]
14
- },
15
- {
16
- "data": {
17
- "text/plain": [
18
- "True"
19
- ]
20
- },
21
- "execution_count": 2,
22
- "metadata": {},
23
- "output_type": "execute_result"
24
- }
25
- ],
26
- "source": [
27
- "import os\n",
28
- "# from mp_api.client import MPRester\n",
29
- "from dask.distributed import Client\n",
30
- "from dask_jobqueue import SLURMCluster\n",
31
- "from prefect import task, flow\n",
32
- "from prefect.task_runners import ThreadPoolTaskRunner\n",
33
- "from prefect_dask import DaskTaskRunner\n",
34
- "from pymatgen.core.structure import Structure\n",
35
- "from dotenv import load_dotenv\n",
36
- "from ase import Atoms\n",
37
- "from ase.io import write, read\n",
38
- "from pathlib import Path\n",
39
- "import pandas as pd\n",
40
- "from prefect.futures import wait\n",
41
- "\n",
42
- "from mlip_arena.tasks.eos.run import fit as EOS\n",
43
- "from mlip_arena.models import MLIPEnum\n",
44
- "\n",
45
- "load_dotenv()\n",
46
- "\n",
47
- "# MP_API_KEY = os.environ.get(\"MP_API_KEY\", None)"
48
- ]
49
- },
50
- {
51
- "cell_type": "code",
52
- "execution_count": 2,
53
- "metadata": {},
54
- "outputs": [
55
- {
56
- "name": "stdout",
57
- "output_type": "stream",
58
- "text": [
59
- "MP Database version: 2023.11.1\n"
60
- ]
61
- },
62
- {
63
- "data": {
64
- "application/vnd.jupyter.widget-view+json": {
65
- "model_id": "bb6c1969c89840888c556f8fa59b4a67",
66
- "version_major": 2,
67
- "version_minor": 0
68
- },
69
- "text/plain": [
70
- "Retrieving SummaryDoc documents: 0%| | 0/5135 [00:00<?, ?it/s]"
71
- ]
72
- },
73
- "metadata": {},
74
- "output_type": "display_data"
75
- }
76
- ],
77
- "source": [
78
- "\n",
79
- "with MPRester(MP_API_KEY) as mpr:\n",
80
- " print(\"MP Database version:\", mpr.get_database_version())\n",
81
- "\n",
82
- " summary_docs = mpr.materials.summary.search(\n",
83
- " num_elements=(1, 2),\n",
84
- " is_stable=True,\n",
85
- " fields=[\"material_id\", \"structure\", \"formula_pretty\"]\n",
86
- " )\n"
87
- ]
88
- },
89
- {
90
- "cell_type": "code",
91
- "execution_count": 3,
92
- "metadata": {},
93
- "outputs": [],
94
- "source": [
95
- "\n",
96
- "atoms_list = []\n",
97
- "\n",
98
- "for doc in summary_docs:\n",
99
- "\n",
100
- " structure = doc.structure\n",
101
- " assert isinstance(structure, Structure)\n",
102
- "\n",
103
- " atoms = structure.to_ase_atoms()\n",
104
- "\n",
105
- " atoms_list.append(atoms)\n"
106
- ]
107
- },
108
- {
109
- "cell_type": "code",
110
- "execution_count": 4,
111
- "metadata": {},
112
- "outputs": [],
113
- "source": [
114
- "write(\"all.extxyz\", atoms_list)"
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": 2,
120
- "metadata": {
121
- "tags": []
122
- },
123
- "outputs": [],
124
- "source": [
125
- "atoms_list = read(\"all.extxyz\", index=':')"
126
- ]
127
- },
128
- {
129
- "cell_type": "code",
130
- "execution_count": 3,
131
- "metadata": {},
132
- "outputs": [
133
- {
134
- "name": "stdout",
135
- "output_type": "stream",
136
- "text": [
137
- "#!/bin/bash\n",
138
- "\n",
139
- "#SBATCH -A matgen\n",
140
- "#SBATCH --mem=0\n",
141
- "#SBATCH -t 00:30:00\n",
142
- "#SBATCH -N 1\n",
143
- "#SBATCH -G 4\n",
144
- "#SBATCH -q debug\n",
145
- "#SBATCH -C gpu\n",
146
- "#SBATCH -J eos\n",
147
- "source ~/.bashrc\n",
148
- "module load python\n",
149
- "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\n",
150
- "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/bin/python -m distributed.cli.dask_worker tcp://128.55.64.49:36289 --name dummy-name --nthreads 1 --memory-limit 59.60GiB --nanny --death-timeout 60\n",
151
- "\n"
152
- ]
153
- }
154
- ],
155
- "source": [
156
- "nodes_per_alloc = 1\n",
157
- "gpus_per_alloc = 4\n",
158
- "ntasks = 1\n",
159
- "\n",
160
- "cluster_kwargs = {\n",
161
- " \"cores\": 1,\n",
162
- " \"memory\": \"64 GB\",\n",
163
- " \"shebang\": \"#!/bin/bash\",\n",
164
- " \"account\": \"matgen\",\n",
165
- " \"walltime\": \"00:30:00\",\n",
166
- " \"job_mem\": \"0\",\n",
167
- " \"job_script_prologue\": [\n",
168
- " \"source ~/.bashrc\",\n",
169
- " \"module load python\",\n",
170
- " \"source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\",\n",
171
- " ],\n",
172
- " \"job_directives_skip\": [\"-n\", \"--cpus-per-task\", \"-J\"],\n",
173
- " \"job_extra_directives\": [f\"-N {nodes_per_alloc}\", f\"-G {gpus_per_alloc}\", \"-q debug\", \"-C gpu\", \"-J eos\"],\n",
174
- "}\n",
175
- "cluster = SLURMCluster(**cluster_kwargs)\n",
176
- "\n",
177
- "print(cluster.job_script())\n",
178
- "cluster.adapt(minimum_jobs=2, maximum_jobs=2)\n",
179
- "client = Client(cluster)\n"
180
- ]
181
- },
182
- {
183
- "cell_type": "code",
184
- "execution_count": 4,
185
- "metadata": {
186
- "tags": []
187
- },
188
- "outputs": [],
189
- "source": [
190
- "from prefect.concurrency.sync import concurrency\n",
191
- "from prefect.runtime import flow_run, task_run\n",
192
- "\n",
193
- "def postprocess(output, model: str, formula: str):\n",
194
- " row = {\n",
195
- " \"formula\": formula,\n",
196
- " \"method\": model,\n",
197
- " \"volumes\": output[\"eos\"][\"volumes\"],\n",
198
- " \"energies\": output[\"eos\"][\"energies\"],\n",
199
- " \"K\": output[\"K\"],\n",
200
- " }\n",
201
- "\n",
202
- " fpath = Path(REGISTRY[model][\"family\"]) / f\"{model}.parquet\"\n",
203
- "\n",
204
- " if not fpath.exists():\n",
205
- " fpath.parent.mkdir(parents=True, exist_ok=True)\n",
206
- " df = pd.DataFrame([row]) # Convert the dictionary to a DataFrame with a list\n",
207
- " else:\n",
208
- " df = pd.read_parquet(fpath)\n",
209
- " new_row = pd.DataFrame([row]) # Convert dictionary to DataFrame with a list\n",
210
- " df = pd.concat([df, new_row], ignore_index=True)\n",
211
- "\n",
212
- " df.drop_duplicates(subset=[\"formula\", \"method\"], keep='last', inplace=True)\n",
213
- " df.to_parquet(fpath)\n",
214
- "\n",
215
- "\n",
216
- "\n",
217
- "task_runner = DaskTaskRunner(address=client.scheduler.address)\n",
218
- "EOS = EOS.with_options(\n",
219
- " # task_runner=task_runner, \n",
220
- " log_prints=True,\n",
221
- " timeout_seconds=120, \n",
222
- " # result_storage=None\n",
223
- ")\n",
224
- "\n",
225
- "from prefect import get_client\n",
226
- "\n",
227
- "async with get_client() as client:\n",
228
- " limit_id = await client.create_concurrency_limit(\n",
229
- " tag=\"bottleneck\", \n",
230
- " concurrency_limit=2\n",
231
- " )\n",
232
- "\n",
233
- "def generate_task_run_name():\n",
234
- " task_name = task_run.task_name\n",
235
- "\n",
236
- " parameters = task_run.parameters\n",
237
- "\n",
238
- " atoms = parameters[\"atoms\"]\n",
239
- " \n",
240
- " return f\"{task_name}: {atoms.get_chemical_formula()}\"\n",
241
- "\n",
242
- "@task(task_run_name=generate_task_run_name, tags=[\"bottleneck\"], timeout_seconds=150)\n",
243
- "def fit_one(atoms: Atoms, model: str):\n",
244
- " \n",
245
- " eos = EOS(\n",
246
- " atoms=atoms,\n",
247
- " calculator_name=model,\n",
248
- " calculator_kwargs={},\n",
249
- " device=None,\n",
250
- " optimizer=\"QuasiNewton\",\n",
251
- " optimizer_kwargs=None,\n",
252
- " filter=\"FrechetCell\",\n",
253
- " filter_kwargs=None,\n",
254
- " criterion=dict(\n",
255
- " fmax=0.1,\n",
256
- " ),\n",
257
- " max_abs_strain=0.1,\n",
258
- " npoints=7,\n",
259
- " )\n",
260
- " if isinstance(eos, dict):\n",
261
- " postprocess(output=eos, model=model, formula=atoms.get_chemical_formula())\n",
262
- " eos[\"method\"] = model\n",
263
- " \n",
264
- " return eos\n",
265
- " \n",
266
- "#https://docs-3.prefect.io/3.0/develop/task-runners#use-multiple-task-runners\n",
267
- "# @flow(task_runner=ThreadPoolTaskRunner(max_workers=50), log_prints=True)\n",
268
- "@flow(task_runner=task_runner, log_prints=True)\n",
269
- "def fit_all(atoms_list: list[Atoms]):\n",
270
- " \n",
271
- " futures = []\n",
272
- " for atoms in atoms_list:\n",
273
- " futures_per_atoms = []\n",
274
- " for model in MLIPEnum:\n",
275
- " \n",
276
- " # with concurrency(\"bottleneck\", occupy=2):\n",
277
- " future = fit_one.submit(atoms, model.name)\n",
278
- " # if not futures_per_atoms:\n",
279
- " # if not futures:\n",
280
- " # future = fit_one.submit(atoms, model.name)\n",
281
- " # else:\n",
282
- " # future = fit_one.submit(atoms, model.name, wait_for=[futures[-1]]) \n",
283
- " # else:\n",
284
- " # future = fit_one.submit(atoms, model.name, wait_for=[future])\n",
285
- " futures_per_atoms.append(future)\n",
286
- " \n",
287
- " futures.extend(futures_per_atoms)\n",
288
- "\n",
289
- " return [f.result() for f in futures]\n",
290
- "\n",
291
- "\n",
292
- "# @task(task_run_name=generate_task_run_name, result_storage=None)\n",
293
- "# def fit_one(atoms: Atoms):\n",
294
- " \n",
295
- "# outputs = []\n",
296
- "# for model in MLIPEnum:\n",
297
- "# try:\n",
298
- "# eos = EOS(\n",
299
- "# atoms=atoms,\n",
300
- "# calculator_name=model.name,\n",
301
- "# calculator_kwargs={},\n",
302
- "# device=None,\n",
303
- "# optimizer=\"QuasiNewton\",\n",
304
- "# optimizer_kwargs=None,\n",
305
- "# filter=\"FrechetCell\",\n",
306
- "# filter_kwargs=None,\n",
307
- "# criterion=dict(\n",
308
- "# fmax=0.1,\n",
309
- "# ),\n",
310
- "# max_abs_strain=0.1,\n",
311
- "# npoints=7,\n",
312
- "# )\n",
313
- "# if isinstance(eos, dict):\n",
314
- "# postprocess(output=eos, model=model.name, formula=atoms.get_chemical_formula())\n",
315
- "# eos[\"method\"] = model.name\n",
316
- "# outputs.append(eos)\n",
317
- "# except:\n",
318
- "# continue\n",
319
- " \n",
320
- "# return outputs\n",
321
- "\n",
322
- "# # https://orion-docs.prefect.io/latest/concepts/task-runners/#using-multiple-task-runners\n",
323
- "# @flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True, result_storage=None)\n",
324
- "# def fit_all(atoms_list: list[Atoms]):\n",
325
- " \n",
326
- "# futures = []\n",
327
- "# for atoms in atoms_list:\n",
328
- "# future = fit_one.submit(atoms)\n",
329
- "# futures.append(future)\n",
330
- " \n",
331
- "# wait(futures)\n",
332
- " \n",
333
- "# return [f.result(raise_on_failure=False) for f in futures]"
334
- ]
335
- },
336
- {
337
- "cell_type": "code",
338
- "execution_count": null,
339
- "metadata": {
340
- "scrolled": true,
341
- "tags": []
342
- },
343
- "outputs": [
344
- {
345
- "data": {
346
- "text/html": [
347
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">18:53:47.335 | <span style=\"color: #008080; text-decoration-color: #008080\">INFO</span> | prefect.engine - Created flow run<span style=\"color: #800080; text-decoration-color: #800080\"> 'vengeful-malkoha'</span> for flow<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> 'fit-all'</span>\n",
348
- "</pre>\n"
349
- ],
350
- "text/plain": [
351
- "18:53:47.335 | \u001b[36mINFO\u001b[0m | prefect.engine - Created flow run\u001b[35m 'vengeful-malkoha'\u001b[0m for flow\u001b[1;35m 'fit-all'\u001b[0m\n"
352
- ]
353
- },
354
- "metadata": {},
355
- "output_type": "display_data"
356
- },
357
- {
358
- "data": {
359
- "text/html": [
360
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">18:53:47.341 | <span style=\"color: #008080; text-decoration-color: #008080\">INFO</span> | prefect.engine - View at <span style=\"color: #0000ff; text-decoration-color: #0000ff\">https://app.prefect.cloud/account/f7d40474-9362-4bfa-8950-ee6a43ec00f3/workspace/d4bb0913-5f5e-49f7-bfc5-06509088baeb/runs/flow-run/909d2bc4-695f-4eeb-8b7c-7660397a0692</span>\n",
361
- "</pre>\n"
362
- ],
363
- "text/plain": [
364
- "18:53:47.341 | \u001b[36mINFO\u001b[0m | prefect.engine - View at \u001b[94mhttps://app.prefect.cloud/account/f7d40474-9362-4bfa-8950-ee6a43ec00f3/workspace/d4bb0913-5f5e-49f7-bfc5-06509088baeb/runs/flow-run/909d2bc4-695f-4eeb-8b7c-7660397a0692\u001b[0m\n"
365
- ]
366
- },
367
- "metadata": {},
368
- "output_type": "display_data"
369
- },
370
- {
371
- "data": {
372
- "text/html": [
373
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">18:53:47.654 | <span style=\"color: #008080; text-decoration-color: #008080\">INFO</span> | prefect.task_runner.dask - Connecting to existing Dask cluster SLURMCluster(df8c3d55, 'tcp://128.55.64.49:36289', workers=0, threads=0, memory=0 B)\n",
374
- "</pre>\n"
375
- ],
376
- "text/plain": [
377
- "18:53:47.654 | \u001b[36mINFO\u001b[0m | prefect.task_runner.dask - Connecting to existing Dask cluster SLURMCluster(df8c3d55, 'tcp://128.55.64.49:36289', workers=0, threads=0, memory=0 B)\n"
378
- ]
379
- },
380
- "metadata": {},
381
- "output_type": "display_data"
382
- }
383
- ],
384
- "source": [
385
- "fit_all(atoms_list)"
386
- ]
387
- },
388
- {
389
- "cell_type": "markdown",
390
- "metadata": {},
391
- "source": [
392
- "```\n",
393
- "Note that, because the DaskTaskRunner uses multiprocessing, calls to flows in scripts must be guarded with if __name__ == \"__main__\": or you will encounter warnings and errors.\n",
394
- "```"
395
- ]
396
- },
397
- {
398
- "cell_type": "code",
399
- "execution_count": 9,
400
- "metadata": {
401
- "tags": []
402
- },
403
- "outputs": [],
404
- "source": [
405
- "# import os\n",
406
- "# import tempfile\n",
407
- "# import shutil\n",
408
- "# from contextlib import contextmanager\n",
409
- "\n",
410
- "# @contextmanager\n",
411
- "# def twd():\n",
412
- " \n",
413
- "# pwd = os.getcwd()\n",
414
- "# temp_dir = tempfile.mkdtemp()\n",
415
- " \n",
416
- "# try:\n",
417
- "# os.chdir(temp_dir)\n",
418
- "# yield\n",
419
- "# finally:\n",
420
- "# os.chdir(pwd)\n",
421
- "# shutil.rmtree(temp_dir)\n",
422
- "\n",
423
- "# with twd():\n",
424
- "\n",
425
- "# fit_all(atoms_list)"
426
- ]
427
- },
428
- {
429
- "cell_type": "code",
430
- "execution_count": 10,
431
- "metadata": {
432
- "tags": []
433
- },
434
- "outputs": [],
435
- "source": [
436
- "import pandas as pd\n",
437
- "\n",
438
- "df = pd.read_parquet('mace-mp/MACE-MP(M).parquet')"
439
- ]
440
- },
441
- {
442
- "cell_type": "code",
443
- "execution_count": 11,
444
- "metadata": {},
445
- "outputs": [
446
- {
447
- "data": {
448
- "text/html": [
449
- "<div>\n",
450
- "<style scoped>\n",
451
- " .dataframe tbody tr th:only-of-type {\n",
452
- " vertical-align: middle;\n",
453
- " }\n",
454
- "\n",
455
- " .dataframe tbody tr th {\n",
456
- " vertical-align: top;\n",
457
- " }\n",
458
- "\n",
459
- " .dataframe thead th {\n",
460
- " text-align: right;\n",
461
- " }\n",
462
- "</style>\n",
463
- "<table border=\"1\" class=\"dataframe\">\n",
464
- " <thead>\n",
465
- " <tr style=\"text-align: right;\">\n",
466
- " <th></th>\n",
467
- " <th>formula</th>\n",
468
- " <th>method</th>\n",
469
- " <th>volumes</th>\n",
470
- " <th>energies</th>\n",
471
- " <th>K</th>\n",
472
- " </tr>\n",
473
- " </thead>\n",
474
- " <tbody>\n",
475
- " <tr>\n",
476
- " <th>1</th>\n",
477
- " <td>Ac2O3</td>\n",
478
- " <td>MACE-MP(M)</td>\n",
479
- " <td>[82.36010147441682, 85.41047560309894, 88.4608...</td>\n",
480
- " <td>[-39.47541427612305, -39.65580749511719, -39.7...</td>\n",
481
- " <td>95.755459</td>\n",
482
- " </tr>\n",
483
- " <tr>\n",
484
- " <th>2</th>\n",
485
- " <td>Ac6In2</td>\n",
486
- " <td>MACE-MP(M)</td>\n",
487
- " <td>[278.3036976131417, 288.61124196918433, 298.91...</td>\n",
488
- " <td>[-31.21324348449707, -31.40914535522461, -31.5...</td>\n",
489
- " <td>33.370214</td>\n",
490
- " </tr>\n",
491
- " <tr>\n",
492
- " <th>3</th>\n",
493
- " <td>Ac6Tl2</td>\n",
494
- " <td>MACE-MP(M)</td>\n",
495
- " <td>[278.30267000598286, 288.6101763025008, 298.91...</td>\n",
496
- " <td>[-29.572534561157227, -29.833026885986328, -30...</td>\n",
497
- " <td>29.065081</td>\n",
498
- " </tr>\n",
499
- " <tr>\n",
500
- " <th>4</th>\n",
501
- " <td>Ac3Sn</td>\n",
502
- " <td>MACE-MP(M)</td>\n",
503
- " <td>[135.293532345587, 140.30440391394214, 145.315...</td>\n",
504
- " <td>[-17.135194778442383, -17.228239059448242, -17...</td>\n",
505
- " <td>30.622045</td>\n",
506
- " </tr>\n",
507
- " <tr>\n",
508
- " <th>5</th>\n",
509
- " <td>AcAg</td>\n",
510
- " <td>MACE-MP(M)</td>\n",
511
- " <td>[55.376437498321394, 57.4274166649259, 59.4783...</td>\n",
512
- " <td>[-7.274301528930664, -7.346108913421631, -7.39...</td>\n",
513
- " <td>40.212164</td>\n",
514
- " </tr>\n",
515
- " <tr>\n",
516
- " <th>6</th>\n",
517
- " <td>Ac4</td>\n",
518
- " <td>MACE-MP(M)</td>\n",
519
- " <td>[166.09086069175856, 172.2423740507126, 178.39...</td>\n",
520
- " <td>[-16.326059341430664, -16.406923294067383, -16...</td>\n",
521
- " <td>25.409891</td>\n",
522
- " </tr>\n",
523
- " <tr>\n",
524
- " <th>7</th>\n",
525
- " <td>Ac16S24</td>\n",
526
- " <td>MACE-MP(M)</td>\n",
527
- " <td>[1006.5670668063424, 1043.84732853991, 1081.12...</td>\n",
528
- " <td>[-249.4179229736328, -250.7970733642578, -251....</td>\n",
529
- " <td>61.734158</td>\n",
530
- " </tr>\n",
531
- " </tbody>\n",
532
- "</table>\n",
533
- "</div>"
534
- ],
535
- "text/plain": [
536
- " formula method volumes \\\n",
537
- "1 Ac2O3 MACE-MP(M) [82.36010147441682, 85.41047560309894, 88.4608... \n",
538
- "2 Ac6In2 MACE-MP(M) [278.3036976131417, 288.61124196918433, 298.91... \n",
539
- "3 Ac6Tl2 MACE-MP(M) [278.30267000598286, 288.6101763025008, 298.91... \n",
540
- "4 Ac3Sn MACE-MP(M) [135.293532345587, 140.30440391394214, 145.315... \n",
541
- "5 AcAg MACE-MP(M) [55.376437498321394, 57.4274166649259, 59.4783... \n",
542
- "6 Ac4 MACE-MP(M) [166.09086069175856, 172.2423740507126, 178.39... \n",
543
- "7 Ac16S24 MACE-MP(M) [1006.5670668063424, 1043.84732853991, 1081.12... \n",
544
- "\n",
545
- " energies K \n",
546
- "1 [-39.47541427612305, -39.65580749511719, -39.7... 95.755459 \n",
547
- "2 [-31.21324348449707, -31.40914535522461, -31.5... 33.370214 \n",
548
- "3 [-29.572534561157227, -29.833026885986328, -30... 29.065081 \n",
549
- "4 [-17.135194778442383, -17.228239059448242, -17... 30.622045 \n",
550
- "5 [-7.274301528930664, -7.346108913421631, -7.39... 40.212164 \n",
551
- "6 [-16.326059341430664, -16.406923294067383, -16... 25.409891 \n",
552
- "7 [-249.4179229736328, -250.7970733642578, -251.... 61.734158 "
553
- ]
554
- },
555
- "execution_count": 11,
556
- "metadata": {},
557
- "output_type": "execute_result"
558
- }
559
- ],
560
- "source": [
561
- "df"
562
- ]
563
- },
564
- {
565
- "cell_type": "code",
566
- "execution_count": null,
567
- "metadata": {},
568
- "outputs": [],
569
- "source": []
570
- }
571
- ],
572
- "metadata": {
573
- "kernelspec": {
574
- "display_name": "mlip-arena",
575
- "language": "python",
576
- "name": "mlip-arena"
577
- },
578
- "language_info": {
579
- "codemirror_mode": {
580
- "name": "ipython",
581
- "version": 3
582
- },
583
- "file_extension": ".py",
584
- "mimetype": "text/x-python",
585
- "name": "python",
586
- "nbconvert_exporter": "python",
587
- "pygments_lexer": "ipython3",
588
- "version": "3.11.10"
589
- },
590
- "widgets": {
591
- "application/vnd.jupyter.widget-state+json": {
592
- "state": {},
593
- "version_major": 2,
594
- "version_minor": 0
595
- }
596
- }
597
- },
598
- "nbformat": 4,
599
- "nbformat_minor": 4
600
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mlip_arena/tasks/md.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Define molecular dynamics tasks.
3
 
4
  This script has been adapted from Atomate2 MLFF MD workflow written by Aaron Kaplan and Yuan Chiang
5
  https://github.com/materialsproject/atomate2/blob/main/src/atomate2/forcefields/md.py
@@ -60,6 +60,14 @@ from pathlib import Path
60
  from typing import Literal
61
 
62
  import numpy as np
 
 
 
 
 
 
 
 
63
  from ase import Atoms, units
64
  from ase.calculators.calculator import Calculator
65
  from ase.calculators.mixing import SumCalculator
@@ -77,13 +85,6 @@ from ase.md.velocitydistribution import (
77
  ZeroRotation,
78
  )
79
  from ase.md.verlet import VelocityVerlet
80
- from prefect import task
81
- from prefect.tasks import task_input_hash
82
- from scipy.interpolate import interp1d
83
- from scipy.linalg import schur
84
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
85
- from tqdm.auto import tqdm
86
-
87
  from mlip_arena.models import MLIPEnum
88
  from mlip_arena.models.utils import get_freer_device
89
 
@@ -186,7 +187,22 @@ def _get_ensemble_defaults(
186
  return ase_md_kwargs
187
 
188
 
189
- @task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  def run(
191
  atoms: Atoms,
192
  calculator_name: str | MLIPEnum,
@@ -196,10 +212,10 @@ def run(
196
  device: str | None = None,
197
  ensemble: Literal["nve", "nvt", "npt"] = "nvt",
198
  dynamics: str | MolecularDynamics = "langevin",
199
- time_step: float | None = None, # fs
200
  total_time: float = 1000, # fs
201
- temperature: float | Sequence | np.ndarray | None = 300.0, # K
202
- pressure: float | Sequence | np.ndarray | None = None, # eV/A^3
203
  ase_md_kwargs: dict | None = None,
204
  md_velocity_seed: int | None = None,
205
  zero_linear_momentum: bool = True,
@@ -282,7 +298,7 @@ def run(
282
  raise ValueError(f"Invalid dynamics: {dynamics}")
283
 
284
  if md_class is NPT:
285
- # Note that until md_func is instantiated, isinstance(md_func,NPT) is False
286
  # ASE NPT implementation requires upper triangular cell
287
  u, _ = schur(atoms.get_cell(complete=True), output="complex")
288
  atoms.set_cell(u.real, scale_atoms=True)
 
1
  """
2
+ Define molecular dynamics task.
3
 
4
  This script has been adapted from Atomate2 MLFF MD workflow written by Aaron Kaplan and Yuan Chiang
5
  https://github.com/materialsproject/atomate2/blob/main/src/atomate2/forcefields/md.py
 
60
  from typing import Literal
61
 
62
  import numpy as np
63
+ from prefect import task
64
+ from prefect.runtime import task_run
65
+ from prefect.tasks import task_input_hash
66
+ from scipy.interpolate import interp1d
67
+ from scipy.linalg import schur
68
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
69
+ from tqdm.auto import tqdm
70
+
71
  from ase import Atoms, units
72
  from ase.calculators.calculator import Calculator
73
  from ase.calculators.mixing import SumCalculator
 
85
  ZeroRotation,
86
  )
87
  from ase.md.verlet import VelocityVerlet
 
 
 
 
 
 
 
88
  from mlip_arena.models import MLIPEnum
89
  from mlip_arena.models.utils import get_freer_device
90
 
 
187
  return ase_md_kwargs
188
 
189
 
190
+ def _generate_task_run_name():
191
+ task_name = task_run.task_name
192
+ parameters = task_run.parameters
193
+
194
+ atoms = parameters["atoms"]
195
+ calculator_name = parameters["calculator_name"]
196
+
197
+ return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
198
+
199
+
200
+ @task(
201
+ name="MD",
202
+ task_run_name=_generate_task_run_name,
203
+ cache_key_fn=task_input_hash,
204
+ # cache_expiration=timedelta(days=1)
205
+ )
206
  def run(
207
  atoms: Atoms,
208
  calculator_name: str | MLIPEnum,
 
212
  device: str | None = None,
213
  ensemble: Literal["nve", "nvt", "npt"] = "nvt",
214
  dynamics: str | MolecularDynamics = "langevin",
215
+ time_step: float | None = None, # fs
216
  total_time: float = 1000, # fs
217
+ temperature: float | Sequence | np.ndarray | None = 300.0, # K
218
+ pressure: float | Sequence | np.ndarray | None = None, # eV/A^3
219
  ase_md_kwargs: dict | None = None,
220
  md_velocity_seed: int | None = None,
221
  zero_linear_momentum: bool = True,
 
298
  raise ValueError(f"Invalid dynamics: {dynamics}")
299
 
300
  if md_class is NPT:
301
+ # Note that until md_func is instantiated, isinstance(md_func,NPT) is False
302
  # ASE NPT implementation requires upper triangular cell
303
  u, _ = schur(atoms.get_cell(complete=True), output="complex")
304
  atoms.set_cell(u.real, scale_atoms=True)
mlip_arena/tasks/optimize.py CHANGED
@@ -1,4 +1,4 @@
1
- """
2
  Define structure optimization tasks.
3
  """
4
 
@@ -6,6 +6,11 @@ from __future__ import annotations
6
 
7
  from datetime import timedelta
8
 
 
 
 
 
 
9
  from ase import Atoms
10
  from ase.calculators.calculator import Calculator
11
  from ase.calculators.mixing import SumCalculator
@@ -13,10 +18,6 @@ from ase.filters import * # type: ignore
13
  from ase.filters import Filter
14
  from ase.optimize import * # type: ignore
15
  from ase.optimize.optimize import Optimizer
16
- from prefect import task
17
- from prefect.tasks import task_input_hash
18
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
19
-
20
  from mlip_arena.models import MLIPEnum
21
  from mlip_arena.models.utils import get_freer_device
22
 
@@ -26,7 +27,7 @@ _valid_filters: dict[str, Filter] = {
26
  "ExpCell": ExpCellFilter,
27
  "Strain": StrainFilter,
28
  "FrechetCell": FrechetCellFilter,
29
- } # type: ignore
30
 
31
  _valid_optimizers: dict[str, Optimizer] = {
32
  "MDMin": MDMin,
@@ -39,14 +40,25 @@ _valid_optimizers: dict[str, Optimizer] = {
39
  "GPMin": GPMin,
40
  "CellAwareBFGS": CellAwareBFGS,
41
  "ODE12r": ODE12r,
42
- } # type: ignore
 
 
 
 
 
 
 
 
43
 
 
44
 
45
- # @task(
46
- # cache_key_fn=task_input_hash,
47
- # cache_expiration=timedelta(days=1),
48
- # timeout_seconds=120)
49
- @task(timeout_seconds=120, result_storage=None)
 
 
50
  def run(
51
  atoms: Atoms,
52
  calculator_name: str | MLIPEnum,
@@ -103,7 +115,6 @@ def run(
103
  raise ValueError(f"Invalid optimizer: {optimizer}")
104
  optimizer = _valid_optimizers[optimizer]
105
 
106
-
107
  filter_kwargs = filter_kwargs or {}
108
  optimizer_kwargs = optimizer_kwargs or {}
109
  criterion = criterion or {}
@@ -126,4 +137,3 @@ def run(
126
  return {
127
  "atoms": atoms,
128
  }
129
-
 
1
+ """
2
  Define structure optimization tasks.
3
  """
4
 
 
6
 
7
  from datetime import timedelta
8
 
9
+ from prefect import task
10
+ from prefect.runtime import task_run
11
+ from prefect.tasks import task_input_hash
12
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
13
+
14
  from ase import Atoms
15
  from ase.calculators.calculator import Calculator
16
  from ase.calculators.mixing import SumCalculator
 
18
  from ase.filters import Filter
19
  from ase.optimize import * # type: ignore
20
  from ase.optimize.optimize import Optimizer
 
 
 
 
21
  from mlip_arena.models import MLIPEnum
22
  from mlip_arena.models.utils import get_freer_device
23
 
 
27
  "ExpCell": ExpCellFilter,
28
  "Strain": StrainFilter,
29
  "FrechetCell": FrechetCellFilter,
30
+ } # type: ignore
31
 
32
  _valid_optimizers: dict[str, Optimizer] = {
33
  "MDMin": MDMin,
 
40
  "GPMin": GPMin,
41
  "CellAwareBFGS": CellAwareBFGS,
42
  "ODE12r": ODE12r,
43
+ } # type: ignore
44
+
45
+
46
+ def _generate_task_run_name():
47
+ task_name = task_run.task_name
48
+ parameters = task_run.parameters
49
+
50
+ atoms = parameters["atoms"]
51
+ calculator_name = parameters["calculator_name"]
52
 
53
+ return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
54
 
55
+
56
+ @task(
57
+ name="MD",
58
+ task_run_name=_generate_task_run_name,
59
+ cache_key_fn=task_input_hash,
60
+ # cache_expiration=timedelta(days=1)
61
+ )
62
  def run(
63
  atoms: Atoms,
64
  calculator_name: str | MLIPEnum,
 
115
  raise ValueError(f"Invalid optimizer: {optimizer}")
116
  optimizer = _valid_optimizers[optimizer]
117
 
 
118
  filter_kwargs = filter_kwargs or {}
119
  optimizer_kwargs = optimizer_kwargs or {}
120
  criterion = criterion or {}
 
137
  return {
138
  "atoms": atoms,
139
  }
 
tests/test_eos.py CHANGED
@@ -2,14 +2,38 @@ import sys
2
 
3
  import pytest
4
  from ase.build import bulk
 
5
  from prefect.testing.utilities import prefect_test_harness
6
 
7
  from mlip_arena.models import MLIPEnum
8
- from mlip_arena.tasks.eos.run import fit as EOS
9
 
10
- atoms = bulk("Cu", "fcc", a=3.6)
11
 
12
- @pytest.mark.skipif(sys.version_info[:2] != (3,11), reason="avoid prefect race condition on concurrent tasks")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
14
  def test_eos(model: MLIPEnum):
15
  """
@@ -17,21 +41,8 @@ def test_eos(model: MLIPEnum):
17
  """
18
 
19
  with prefect_test_harness():
20
-
21
- result = EOS(
22
- atoms=atoms,
23
  calculator_name=model.name,
24
- calculator_kwargs={},
25
- device=None,
26
- optimizer="BFGSLineSearch",
27
- optimizer_kwargs=None,
28
- filter="FrechetCell",
29
- filter_kwargs=None,
30
- criterion=dict(
31
- fmax=0.1,
32
- ),
33
- max_abs_strain=0.1,
34
- npoints=6,
35
  )
36
 
37
  assert isinstance(result["K"], float)
 
2
 
3
  import pytest
4
  from ase.build import bulk
5
+ from prefect import flow
6
  from prefect.testing.utilities import prefect_test_harness
7
 
8
  from mlip_arena.models import MLIPEnum
9
+ from mlip_arena.tasks.eos import run as EOS
10
 
 
11
 
12
+ @flow
13
+ def single_eos_flow(calculator_name):
14
+ atoms = bulk("Cu", "fcc", a=3.6)
15
+
16
+ return EOS(
17
+ atoms=atoms,
18
+ calculator_name=calculator_name,
19
+ calculator_kwargs={},
20
+ device=None,
21
+ optimizer="BFGSLineSearch",
22
+ optimizer_kwargs=None,
23
+ filter="FrechetCell",
24
+ filter_kwargs=None,
25
+ criterion=dict(
26
+ fmax=0.1,
27
+ ),
28
+ max_abs_strain=0.1,
29
+ npoints=6,
30
+ )
31
+
32
+
33
+ @pytest.mark.skipif(
34
+ sys.version_info[:2] != (3, 11),
35
+ reason="avoid prefect race condition on concurrent tasks",
36
+ )
37
  @pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
38
  def test_eos(model: MLIPEnum):
39
  """
 
41
  """
42
 
43
  with prefect_test_harness():
44
+ result = single_eos_flow(
 
 
45
  calculator_name=model.name,
 
 
 
 
 
 
 
 
 
 
 
46
  )
47
 
48
  assert isinstance(result["K"], float)