File size: 4,230 Bytes
51638da
1effaf5
 
 
 
 
51638da
1d1ee87
51638da
 
 
1effaf5
 
 
1d1ee87
1effaf5
 
 
 
52c1bfb
 
1effaf5
 
 
 
 
 
 
51638da
1effaf5
 
 
 
1d1ee87
1effaf5
 
 
 
 
 
 
 
51638da
 
 
 
 
 
 
 
 
1effaf5
51638da
1effaf5
51638da
 
1d1ee87
51638da
1d1ee87
 
51638da
 
1effaf5
 
 
1d1ee87
1effaf5
 
 
 
 
 
 
 
1d1ee87
1effaf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1ee87
 
 
1effaf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Define structure optimization tasks.
"""

from __future__ import annotations

from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.mixing import SumCalculator
from ase.constraints import FixSymmetry
from ase.filters import *  # type: ignore
from ase.filters import Filter
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from mlip_arena.models import MLIPEnum
from mlip_arena.models.utils import get_freer_device

# Registry of ASE filter classes that ``run`` accepts by name. The values are
# filter *classes* (``run`` instantiates them as ``cls(atoms, **filter_kwargs)``),
# so the value type is ``type[Filter]`` rather than ``Filter``.
_valid_filters: dict[str, type[Filter]] = {
    "Filter": Filter,
    "UnitCell": UnitCellFilter,
    "ExpCell": ExpCellFilter,
    "Strain": StrainFilter,
    "FrechetCell": FrechetCellFilter,
}  # type: ignore

# Registry of ASE optimizer classes that ``run`` accepts by name. The values are
# optimizer *classes* (``run`` instantiates them with the structure to relax and
# ``**optimizer_kwargs``), so the value type is ``type[Optimizer]``.
_valid_optimizers: dict[str, type[Optimizer]] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
    "FIRE2": FIRE2,
    "LBFGS": LBFGS,
    "LBFGSLineSearch": LBFGSLineSearch,
    "BFGS": BFGS,
    "BFGSLineSearch": BFGSLineSearch,
    "QuasiNewton": QuasiNewton,
    "GPMin": GPMin,
    "CellAwareBFGS": CellAwareBFGS,
    "ODE12r": ODE12r,
}  # type: ignore


def _generate_task_run_name():
    """Build the display name for the currently executing Prefect task run.

    Reads the running task's name plus its bound ``atoms`` and
    ``calculator_name`` parameters and returns a string of the form
    ``"<task name>: <chemical formula> - <calculator name>"``.
    """
    params = task_run.parameters
    structure = params["atoms"]
    label = params["calculator_name"]
    formula = structure.get_chemical_formula()
    return f"{task_run.task_name}: {formula} - {label}"


@task(
    name="OPT",
    task_run_name=_generate_task_run_name,
    cache_policy=TASK_SOURCE + INPUTS,
)
def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None = None,
    dispersion: str | None = None,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
    optimizer: type[Optimizer] | str = BFGSLineSearch,
    optimizer_kwargs: dict | None = None,
    filter: type[Filter] | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    symmetry: bool = False,
):
    """Relax ``atoms`` with an MLIP calculator and an ASE optimizer.

    Args:
        atoms: Structure to relax. It is modified in place: the calculator is
            attached to it, its constraints are replaced when ``symmetry`` is
            true, and the optimizer updates it.
        calculator_name: An ``MLIPEnum`` member, or the name of one.
        calculator_kwargs: Keyword arguments forwarded to the calculator
            constructor.
        dispersion: If not ``None``, a ``TorchDFTD3Calculator`` is summed with
            the MLIP calculator via ``SumCalculator``. The value itself is
            only used for logging.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``.
            ``device`` is always set to the resolved device. The caller's dict
            is not modified.
        device: Device string; defaults to ``str(get_freer_device())``.
        optimizer: Optimizer class, or a key of ``_valid_optimizers``.
        optimizer_kwargs: Keyword arguments for the optimizer constructor.
        filter: ``Filter`` subclass, a key of ``_valid_filters``, or ``None``
            to optimize ``atoms`` directly.
        filter_kwargs: Keyword arguments for the filter constructor.
        criterion: Keyword arguments forwarded to the optimizer's ``run``
            method (e.g. ``fmax``, ``steps`` for ASE optimizers).
        symmetry: If true, constrain the relaxation with ``FixSymmetry``.

    Returns:
        ``{"atoms": atoms}`` -- the same ``Atoms`` object, after optimization.

    Raises:
        ValueError: If the calculator, filter, or optimizer cannot be resolved.
    """
    device = device or str(get_freer_device())
    print(f"Using device: {device}")

    # Resolve the MLIP calculator from an enum member or its member name.
    calculator_kwargs = calculator_kwargs or {}
    if isinstance(calculator_name, MLIPEnum):
        assert issubclass(calculator_name.value, Calculator)
        calc = calculator_name.value(**calculator_kwargs)
    elif (
        isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
    ):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")
    print(f"Using calculator: {calc}")

    # Build a new dict so the caller's ``dispersion_kwargs`` is never mutated;
    # the resolved device always overrides any user-supplied one.
    dispersion_kwargs = {**(dispersion_kwargs or {}), "device": device}
    if dispersion is not None:
        disp_calc = TorchDFTD3Calculator(**dispersion_kwargs)
        calc = SumCalculator([calc, disp_calc])
        print(f"Using dispersion: {dispersion}")

    atoms.calc = calc

    if isinstance(filter, str):
        if filter not in _valid_filters:
            raise ValueError(f"Invalid filter: {filter}")
        filter = _valid_filters[filter]

    if isinstance(optimizer, str):
        if optimizer not in _valid_optimizers:
            raise ValueError(f"Invalid optimizer: {optimizer}")
        optimizer = _valid_optimizers[optimizer]

    # Anything other than None or a Filter subclass is rejected explicitly, so
    # the task never silently skips optimization and returns unrelaxed atoms.
    if filter is not None and not (
        isinstance(filter, type) and issubclass(filter, Filter)
    ):
        raise ValueError(f"Invalid filter: {filter}")

    filter_kwargs = filter_kwargs or {}
    optimizer_kwargs = optimizer_kwargs or {}
    criterion = criterion or {}

    if symmetry:
        # NOTE: ``set_constraint`` replaces any constraints already on ``atoms``.
        atoms.set_constraint(FixSymmetry(atoms))

    if filter is None:
        target = atoms
    else:
        # The optimizer must drive the filter object, not the bare atoms, so
        # that the degrees of freedom the filter exposes (e.g. the cell for
        # cell filters) are the ones actually relaxed.
        target = filter(atoms, **filter_kwargs)
        print(f"Using filter: {target}")

    optimizer_instance = optimizer(target, **optimizer_kwargs)
    print(f"Using optimizer: {optimizer_instance}")
    optimizer_instance.run(**criterion)

    return {
        "atoms": atoms,
    }