File size: 5,763 Bytes
1e50f35
51638da
1e50f35
 
 
 
 
 
1d1ee87
1e50f35
 
51638da
1d1ee87
51638da
00b56e2
51638da
1d1ee87
51638da
1e50f35
 
 
 
52c1bfb
1e50f35
51638da
1e50f35
 
 
 
 
51638da
c7922c2
 
 
 
 
 
 
 
 
51638da
00b56e2
51638da
 
1e50f35
 
1d1ee87
1e50f35
c7922c2
1e50f35
1d1ee87
1e50f35
 
 
 
1d1ee87
ef47233
b7a7786
1d1ee87
1e50f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1ee87
00b56e2
b7a7786
1e50f35
 
1d1ee87
1e50f35
b7a7786
 
ef47233
00b56e2
b7a7786
 
 
1e50f35
 
 
 
 
 
 
 
 
1d1ee87
1e50f35
 
1d1ee87
 
00b56e2
 
 
 
 
 
 
1e50f35
 
 
 
 
 
1d1ee87
 
 
 
 
 
b7a7786
1d1ee87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7a7786
1d1ee87
 
 
 
 
 
 
 
 
 
 
 
 
00b56e2
 
 
 
 
 
c7922c2
 
 
 
 
 
 
 
1e50f35
 
 
 
 
1d1ee87
 
1e50f35
 
08a88d8
 
 
 
1e50f35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Define equation of state task.

https://github.com/materialsvirtuallab/matcalc/blob/main/matcalc/eos.py
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.futures import wait
from prefect.results import ResultRecord
from prefect.runtime import task_run
from prefect.states import State

from ase import Atoms
from ase.filters import *  # type: ignore
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.optimize import run as OPT
from pymatgen.analysis.eos import BirchMurnaghan

if TYPE_CHECKING:
    from ase.filters import Filter


def _generate_task_run_name():
    """Build the display name for the current task run.

    Reads the task name and call parameters from the Prefect runtime
    context (``prefect.runtime.task_run``).

    Returns:
        A string of the form
        ``"<task name>: <chemical formula> - <calculator name>"``.
    """
    run_params = task_run.parameters
    structure, calc = run_params["atoms"], run_params["calculator_name"]

    return f"{task_run.task_name}: {structure.get_chemical_formula()} - {calc}"


@task(
    name="EOS", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
)
def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None = None,
    device: str | None = None,
    optimizer: Optimizer | str = "BFGSLineSearch",  # type: ignore
    optimizer_kwargs: dict | None = None,
    filter: Filter | str | None = "FrechetCell",  # type: ignore
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    max_abs_strain: float = 0.1,
    npoints: int = 11,
    concurrent: bool = True,
    persist_opt: bool = True,
    cache_opt: bool = True,
) -> dict[str, Any] | State:
    """
    Compute the equation of state (EOS) for the given atoms and calculator.

    The structure is first relaxed with the requested ``filter``. The relaxed
    cell is then scaled by ``npoints`` factors so that the volume spans
    ``[1 - max_abs_strain, 1 + max_abs_strain]`` times the relaxed volume, and
    each scaled structure is optimized again without a cell filter. The
    resulting (volume, energy) pairs are fitted with a Birch-Murnaghan EOS.

    Args:
        atoms: The input atoms.
        calculator_name: The name of the calculator to use.
        calculator_kwargs: Additional kwargs to pass to the calculator.
        device: The device to use.
        optimizer: The optimizer to use.
        optimizer_kwargs: Additional kwargs to pass to the optimizer.
        filter: The filter to use for the initial relaxation only; the
            strained structures are optimized with ``filter=None``.
        filter_kwargs: Additional kwargs to pass to the filter (initial
            relaxation only).
        criterion: The criterion to use.
        max_abs_strain: The maximum absolute volumetric strain to use.
        npoints: The number of points to sample.
        concurrent: Whether to relax multiple structures concurrently.
        persist_opt: Whether to persist the optimization results.
        cache_opt: Whether to cache the intermediate optimization results.

    Returns:
        If the initial relaxation fails, its prefect state object. Otherwise a
        dictionary with keys ``"atoms"`` (the relaxed structure),
        ``"calculator_name"``, ``"eos"`` (sorted ``"volumes"`` and
        ``"energies"``), ``"K"`` (bulk modulus in GPa), ``"b0"``, ``"b1"``,
        ``"e0"`` and ``"v0"`` from the Birch-Murnaghan fit.
    """

    OPT_ = OPT.with_options(
        refresh_cache=not cache_opt,
        persist_result=persist_opt,
    )

    state = OPT_(
        atoms=atoms,
        calculator_name=calculator_name,
        calculator_kwargs=calculator_kwargs,
        device=device,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        filter=filter,
        filter_kwargs=filter_kwargs,
        criterion=criterion,
        return_state=True,
    )

    if state.is_failed():
        return state

    first_relax = state.result(raise_on_failure=False)

    # The result may come back wrapped in a ResultRecord; unwrap it if so.
    if isinstance(first_relax, ResultRecord):
        relaxed = first_relax.result["atoms"]
    else:
        relaxed = first_relax["atoms"]

    c0 = relaxed.get_cell()

    # Cube root turns evenly spaced volume ratios into linear cell scale factors.
    factors = np.linspace(1 - max_abs_strain, 1 + max_abs_strain, npoints) ** (1 / 3)

    # Keyword arguments shared by every fixed-cell optimization below.
    fixed_cell_kwargs = {
        "calculator_name": calculator_name,
        "calculator_kwargs": calculator_kwargs,
        "device": device,
        "optimizer": optimizer,
        "optimizer_kwargs": optimizer_kwargs,
        "filter": None,
        "filter_kwargs": None,
        "criterion": criterion,
    }

    if concurrent:
        futures = []
        for factor in factors:
            strained = relaxed.copy()
            strained.set_cell(c0 * factor, scale_atoms=True)
            futures.append(OPT_.submit(atoms=strained, **fixed_cell_kwargs))

        wait(futures)

        # Check each future's own state so that only completed runs are kept.
        results = [
            fut.result(raise_on_failure=False)
            for fut in futures
            if fut.state.is_completed()
        ]
    else:
        states = []
        for factor in factors:
            strained = relaxed.copy()
            strained.set_cell(c0 * factor, scale_atoms=True)
            states.append(
                OPT_(atoms=strained, return_state=True, **fixed_cell_kwargs)
            )

        results = [s.result(raise_on_failure=False) for s in states if s.is_completed()]

    results = [r.result if isinstance(r, ResultRecord) else r for r in results]

    volumes = [r["atoms"].get_volume() for r in results]
    energies = [r["atoms"].get_potential_energy() for r in results]

    # Sort the (volume, energy) pairs by volume before fitting.
    volumes, energies = map(
        list,
        zip(
            *sorted(zip(volumes, energies, strict=True), key=lambda i: i[0]),
            strict=True,
        ),
    )

    bm = BirchMurnaghan(volumes=volumes, energies=energies)
    bm.fit()

    return {
        "atoms": relaxed,
        "calculator_name": calculator_name,
        "eos": {"volumes": volumes, "energies": energies},
        "K": bm.b0_GPa,
        "b0": bm.b0,
        "b1": bm.b1,
        "e0": bm.e0,
        "v0": bm.v0,
    }