more logger; fix relaxation filter
- mlip_arena/models/utils.py +10 -3
- mlip_arena/tasks/elasticity.py +4 -2
- mlip_arena/tasks/neb.py +10 -17
- mlip_arena/tasks/optimize.py +11 -6
- mlip_arena/tasks/utils.py +11 -9
mlip_arena/models/utils.py
CHANGED
@@ -2,6 +2,13 @@
 
 import torch
 
+try:
+    from prefect.logging import get_run_logger
+
+    logger = get_run_logger()
+except (ImportError, RuntimeError):
+    from loguru import logger
+
 
 def get_freer_device() -> torch.device:
     """Get the GPU with the most free memory, or use MPS if available.
@@ -22,16 +29,16 @@ def get_freer_device() -> torch.device:
         ]
         free_gpu_index = mem_free.index(max(mem_free))
         device = torch.device(f"cuda:{free_gpu_index}")
-
+        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
    elif torch.backends.mps.is_available():
        # If no CUDA GPUs are available but MPS is, use MPS
-
+        logger.info("No GPU available. Using MPS.")
        device = torch.device("mps")
    else:
        # Fallback to CPU if neither CUDA GPUs nor MPS are available
-
+        logger.info("No GPU or MPS available. Using CPU.")
        device = torch.device("cpu")

    return device
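The try/except block above lets the module log through Prefect's run logger when the code executes inside a flow or task run, and fall back to loguru otherwise. A minimal standalone sketch of the same fallback pattern (not part of this commit; the report_device helper is illustrative only):

import torch

try:
    # Inside a Prefect flow/task run this returns the run-scoped logger;
    # outside a run context it raises, so we fall back to loguru.
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    from loguru import logger


def report_device() -> torch.device:
    # Illustrative helper: pick CUDA if present, otherwise CPU, and log the choice.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Selected device: {device}")
    return device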
mlip_arena/tasks/elasticity.py
CHANGED
@@ -48,8 +48,6 @@ from prefect.runtime import task_run
 from prefect.states import State
 
 from ase import Atoms
-from ase.filters import * # type: ignore
-from ase.optimize import * # type: ignore
 from ase.optimize.optimize import Optimizer
 from mlip_arena.models import MLIPEnum
 from mlip_arena.tasks.optimize import run as OPT
@@ -81,6 +79,8 @@ def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None = None,
+    dispersion: bool = False,
+    dispersion_kwargs: dict | None = None,
    device: str | None = None,
    optimizer: Optimizer | str = "BFGSLineSearch", # type: ignore
    optimizer_kwargs: dict | None = None,
@@ -124,6 +124,8 @@ def run(
        atoms=atoms,
        calculator_name=calculator_name,
        calculator_kwargs=calculator_kwargs,
+        dispersion=dispersion,
+        dispersion_kwargs=dispersion_kwargs,
        device=device,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
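The two new arguments let callers request a D3 dispersion correction at the elasticity level and have it forwarded untouched to the relaxation task (OPT). A hedged usage sketch; the model key and kwargs below are placeholders, not values prescribed by this commit, and the task is typically invoked from within a Prefect flow:

from ase.build import bulk
from mlip_arena.tasks.elasticity import run as ELASTICITY

# Illustrative call only: "MACE-MP(M)" stands for any registered model key and
# dispersion_kwargs are passed through to the dispersion calculator.
result = ELASTICITY(
    atoms=bulk("Cu", "fcc", a=3.6),
    calculator_name="MACE-MP(M)",
    calculator_kwargs={},
    dispersion=True,
    dispersion_kwargs={"damping": "bj"},
)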
mlip_arena/tasks/neb.py
CHANGED
@@ -39,7 +39,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 from __future__ import annotations
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal
+from typing import Any, Literal
 
 from prefect import task
 from prefect.cache_policies import INPUTS, TASK_SOURCE
@@ -54,13 +54,9 @@ from ase.optimize.optimize import Optimizer
 from ase.utils.forcecurve import fit_images
 from mlip_arena.models import MLIPEnum
 from mlip_arena.tasks.optimize import run as OPT
-from mlip_arena.tasks.utils import get_calculator
+from mlip_arena.tasks.utils import get_calculator, logger, pformat
 from pymatgen.io.ase import AseAtomsAdaptor
 
-
-if TYPE_CHECKING:
-    pass
-
 _valid_optimizers: dict[str, Optimizer] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
@@ -86,7 +82,7 @@ def _generate_task_run_name():
        atoms = parameters["start"]
    else:
        raise ValueError("No images or start atoms found in parameters")
-
+
    calculator_name = parameters["calculator_name"]

    return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
@@ -156,20 +152,17 @@ def run(
    criterion = criterion or {}

    optimizer_instance = optimizer(neb, trajectory=traj_file, **optimizer_kwargs) # type: ignore
-
+    logger.info(f"Using optimizer: {optimizer_instance}")
+    logger.info(pformat(optimizer_kwargs))
+    logger.info(f"Criterion: {pformat(criterion)}")
    optimizer_instance.run(**criterion)

    neb_tool = NEBTools(neb.images)
-    barrier = neb_tool.get_barrier()
-
-    forcefit = fit_images(neb.images)
-
-    images = neb.images

    return {
-        "barrier": barrier,
-        "images": images,
-        "forcefit": forcefit,
+        "barrier": neb_tool.get_barrier(),
+        "images": neb.images,
+        "forcefit": fit_images(neb.images),
    }


@@ -261,7 +254,7 @@ def run_from_end_points(
        )
    )

-    images = [s.to_ase_atoms() for s in path]
+    images = [s.to_ase_atoms(msonable=False) for s in path]

    return run.with_options(
        refresh_cache=not cache_subtasks,
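The run task's return dict is now assembled inline from NEBTools and fit_images. A hedged sketch of how a caller might unpack it; the helper below is illustrative and assumes ASE's NEBTools.get_barrier() default of returning a (barrier, dE) pair:

from typing import Any


def summarize_neb_result(result: dict[str, Any]) -> str:
    # Illustrative helper, not part of the commit: unpack the dict returned by
    # the NEB task above ("barrier", "images", "forcefit").
    barrier, delta_e = result["barrier"]  # assumes the (barrier, dE) tuple convention
    n_images = len(result["images"])
    return f"{n_images} images, barrier {barrier:.3f} eV, reaction energy {delta_e:.3f} eV"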
mlip_arena/tasks/optimize.py
CHANGED
@@ -15,7 +15,8 @@ from ase.filters import Filter
 from ase.optimize import * # type: ignore
 from ase.optimize.optimize import Optimizer
 from mlip_arena.models import MLIPEnum
-from mlip_arena.tasks.utils import get_calculator
+from mlip_arena.tasks.utils import get_calculator, logger, pformat
+
 
 _valid_filters: dict[str, Filter] = {
    "Filter": Filter,
@@ -94,16 +95,20 @@ def run(

    if isinstance(filter, type) and issubclass(filter, Filter):
        filter_instance = filter(atoms, **filter_kwargs)
-
+        logger.info(f"Using filter: {filter_instance}")
+        logger.info(pformat(filter_kwargs))

-        optimizer_instance = optimizer(
-
+        optimizer_instance = optimizer(filter_instance, **optimizer_kwargs)
+        logger.info(f"Using optimizer: {optimizer_instance}")
+        logger.info(pformat(optimizer_kwargs))
+        logger.info(f"Criterion: {pformat(criterion)}")

        optimizer_instance.run(**criterion)
-
    elif filter is None:
        optimizer_instance = optimizer(atoms, **optimizer_kwargs)
-
+        logger.info(f"Using optimizer: {optimizer_instance}")
+        logger.info(pformat(optimizer_kwargs))
+        logger.info(f"Criterion: {pformat(criterion)}")
        optimizer_instance.run(**criterion)

    return {
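The substantive fix here is that, when a cell filter is requested, the optimizer is now driven by the filter instance rather than by the bare Atoms object; only the filter exposes the cell degrees of freedom to the optimizer. A standalone ASE illustration of that behaviour (a generic sketch, not this repository's code):

from ase.build import bulk
from ase.calculators.emt import EMT
from ase.filters import FrechetCellFilter
from ase.optimize import FIRE

atoms = bulk("Cu", "fcc", a=3.7)  # deliberately strained lattice constant
atoms.calc = EMT()

filter_instance = FrechetCellFilter(atoms)  # exposes positions + cell as one set of DOFs
optimizer_instance = FIRE(filter_instance)  # drive the filter, not the raw atoms
optimizer_instance.run(fmax=0.05)

print(atoms.cell.cellpar())  # the cell relaxes along with the positions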
mlip_arena/tasks/utils.py
CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
 from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
 
 from ase import units
-from ase.calculators.calculator import Calculator
+from ase.calculators.calculator import Calculator, BaseCalculator
 from ase.calculators.mixing import SumCalculator
 from mlip_arena.models import MLIPEnum
 from mlip_arena.models.utils import get_freer_device
@@ -21,7 +21,7 @@ from pprint import pformat
 
 
 def get_calculator(
-    calculator_name: str | MLIPEnum | Calculator,
+    calculator_name: str | MLIPEnum | Calculator | SumCalculator,
    calculator_kwargs: dict | None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
@@ -30,22 +30,24 @@
    """Get a calculator with optional dispersion correction."""
    device = device or str(get_freer_device())

-    logger.info("Using device: {device}")
+    logger.info(f"Using device: {device}")

    calculator_kwargs = calculator_kwargs or {}

    if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
-        assert issubclass(calculator_name.value, Calculator)
        calc = calculator_name.value(**calculator_kwargs)
    elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
-    elif isinstance(calculator_name, Calculator):
-        logger.warning("Using custom calculator: {calculator_name}")
+    elif isinstance(calculator_name, type) and issubclass(calculator_name, BaseCalculator):
+        logger.warning(f"Using custom calculator class: {calculator_name}")
+        calc = calculator_name(**calculator_kwargs)
+    elif isinstance(calculator_name, Calculator | SumCalculator):
+        logger.warning(f"Using custom calculator object (kwargs are ignored): {calculator_name}")
        calc = calculator_name
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

-    logger.info("Using calculator: {calc}")
+    logger.info(f"Using calculator: {calc}")
    if calculator_kwargs:
        logger.info(pformat(calculator_kwargs))

@@ -61,9 +63,9 @@ def get_calculator(
        )
        calc = SumCalculator([calc, disp_calc])

-        logger.info("Using dispersion: {disp_calc}")
+        logger.info(f"Using dispersion: {disp_calc}")
        if dispersion_kwargs:
            logger.info(pformat(dispersion_kwargs))

-    assert isinstance(calc, Calculator)
+    assert isinstance(calc, Calculator | SumCalculator)
    return calc
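With the extra branches, get_calculator now accepts a registry key, an MLIPEnum member, a calculator class (instantiated with calculator_kwargs), or an already-constructed calculator, including a SumCalculator. A hedged usage sketch; EMT stands in for any custom calculator, and the string key shown is illustrative rather than guaranteed by this commit:

from ase.calculators.emt import EMT
from mlip_arena.tasks.utils import get_calculator

# 1) registry key resolved through MLIPEnum (illustrative key)
calc_named = get_calculator("MACE-MP(M)", calculator_kwargs=None)

# 2) calculator class: instantiated with calculator_kwargs
calc_from_class = get_calculator(EMT, calculator_kwargs={})

# 3) pre-built instance: returned as-is, calculator_kwargs are ignored
calc_from_instance = get_calculator(EMT(), calculator_kwargs=None)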