from functools import partial
from pathlib import Path

import pandas as pd
from huggingface_hub import hf_hub_download
from prefect import Task, flow, task
from prefect.client.schemas.objects import TaskRun
from prefect.futures import wait
from prefect.states import State

from ase.db import connect
from mlip_arena.data.local import SafeHDFStore
from mlip_arena.models import REGISTRY, MLIPEnum
from mlip_arena.tasks.eos import run as EOS


@task
def get_atoms_from_db(db_path: Path | str):
    """Yield ASE Atoms from the database at `db_path`, downloading the file from
    the atomind/mlip-arena dataset on the Hugging Face Hub if it is missing locally."""
    db_path = Path(db_path)
    if not db_path.exists():
        db_path = hf_hub_download(
            repo_id="atomind/mlip-arena",
            repo_type="dataset",
            subfolder=f"{Path(__file__).parent.name}",
            filename=str(db_path),
        )
    with connect(db_path) as db:
        for row in db.select():
            yield row.toatoms()


def save_to_hdf(
    tsk: Task, run: TaskRun, state: State, fpath: Path | str, table_name: str
):
    """
    Define a hook on completion of EOS task to save results to HDF5 file.
    """

    # Only persist results from runs that finished successfully and returned
    # the expected result dictionary.
    if run.state.is_failed():
        return

    result = run.state.result(raise_on_failure=False)

    if not isinstance(result, dict):
        return

    try:
        atoms = result["atoms"]
        calculator_name = (
            run.task_inputs["calculator_name"] or result["calculator_name"]
        )

        energies = [float(e) for e in result["eos"]["energies"]]

        formula = atoms.get_chemical_formula()

        df = pd.DataFrame(
            {
                "method": calculator_name,
                "formula": formula,
                "total_run_time": run.total_run_time,
                "v0": result["v0"],
                "e0": result["e0"],
                "b0": result["b0"],
                "b1": result["b1"],
                "volume": result["eos"]["volumes"],
                "energy": energies,
            }
        )

        # Write one HDF5 file per calculator by suffixing the file stem with its name.
        fpath = Path(fpath)
        fpath = fpath.with_stem(fpath.stem + f"_{calculator_name}")

        # Also dump a per-formula JSON record under the model family's directory.
        family_path = Path(__file__).parent / REGISTRY[calculator_name]["family"]
        family_path.mkdir(parents=True, exist_ok=True)

        df.to_json(family_path / f"{calculator_name}_{formula}.json", indent=2)

        with SafeHDFStore(fpath, mode="a") as store:
            store.append(
                table_name,
                df,
                format="table",
                data_columns=True,
                min_itemsize={"formula": 50, "method": 20},
            )
    except Exception as e:
        print(e)


@flow(name="EOS Alloy")
def run(
    db_path: Path | str,
    out_path: Path | str,
    table_name: str,
    optimizer="FIRE",
    optimizer_kwargs=None,
    filter="FrechetCell",
    filter_kwargs=None,
    criterion=dict(fmax=0.1, steps=1000),
    max_abs_strain=0.20,
    concurrent=False,
    cache=True,
):
    """Fit an equation of state for every structure in `db_path` with every
    eligible MLIP, persisting completed results through the `save_to_hdf` hook."""
    EOS_ = EOS.with_options(
        on_completion=[partial(save_to_hdf, fpath=out_path, table_name=table_name)],
        refresh_cache=not cache,
    )

    # Submit one EOS task per (structure, eligible model) pair.
    futures = []
    for atoms in get_atoms_from_db(db_path):
        for mlip in MLIPEnum:
            # Skip models whose registry entry does not flag NPT support.
            if not REGISTRY[mlip.name]["npt"]:
                continue
            # Skip models that do not list this task directory under their
            # registered cpu-tasks or gpu-tasks.
            if Path(__file__).parent.name not in (
                REGISTRY[mlip.name].get("cpu-tasks", [])
                + REGISTRY[mlip.name].get("gpu-tasks", [])
            ):
                continue
            future = EOS_.submit(
                atoms=atoms,
                calculator_name=mlip.name,
                calculator_kwargs=dict(),
                optimizer=optimizer,
                optimizer_kwargs=optimizer_kwargs,
                filter=filter,
                filter_kwargs=filter_kwargs,
                criterion=criterion,
                max_abs_strain=max_abs_strain,
                concurrent=concurrent,
                persist_opt=cache,
                cache_opt=cache,
                # return_state=True
            )
            futures.append(future)

    # Block until all submitted EOS runs reach a final state.
    wait(futures)

    return [
        f.result(timeout=None, raise_on_failure=False)
        for f in futures
        if f.state.is_completed()
    ]
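

# Usage sketch (an assumption, not part of the original module): the flow could be
# invoked directly as a script. The database filename, output path, and table name
# below are placeholders for illustration only.
if __name__ == "__main__":
    results = run(
        db_path="alloys.db",      # placeholder ASE database filename
        out_path="eos_alloy.h5",  # placeholder HDF5 output path
        table_name="eos_alloy",   # placeholder HDF5 table key
    )
    print(f"Collected {len(results)} completed EOS results")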