# import functools
from pathlib import Path
import pandas as pd
from ase import Atoms
from ase.db import connect
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from prefect import flow, task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run
from prefect_dask import DaskTaskRunner
from mlip_arena.models import REGISTRY, MLIPEnum
from mlip_arena.tasks.utils import get_calculator
@task
def load_wbm_structures():
    """Stream WBM structures from the local ASE database.

    Opens ``../wbm_structures.db`` and yields each row as an
    :class:`ase.Atoms` object, preserving the row's extra metadata in
    the ``.info`` dictionary.

    Yields:
        ase.Atoms: One structure per database row.
    """
    with connect("../wbm_structures.db") as database:
        yield from (
            row.toatoms(add_additional_information=True)
            for row in database.select()
        )
# def save_result(
# tsk: Task,
# run: TaskRun,
# state: State,
# model_name: str,
# id: str,
# ):
# result = run.state.result()
# assert isinstance(result, dict)
# result["method"] = model_name
# result["id"] = id
# result.pop("atoms", None)
# fpath = Path(f"{model_name}")
# fpath.mkdir(exist_ok=True)
# fpath = fpath / f"{result['id']}.pkl"
# df = pd.DataFrame([result])
# df.to_pickle(fpath)
@task(
    name="EOS bulk - WBM",
    task_run_name=lambda: f"{task_run.task_name}: {task_run.parameters['atoms'].get_chemical_formula()} - {task_run.parameters['model'].name}",
    cache_policy=TASK_SOURCE + INPUTS,
)
def eos_bulk(atoms: Atoms, model: MLIPEnum):
    """Relax a structure and compute its equation of state with one MLIP.

    Runs a FIRE relaxation (fmax = 0.1) followed by a 21-point EOS scan
    up to |strain| = 0.2, then pickles the single-row result to
    ``<model name>/<wbm_id>.pkl``.

    Args:
        atoms: Input structure; must carry ``info["key_value_pairs"]["wbm_id"]``.
        model: MLIP model enum entry used to build the calculator.

    Returns:
        pd.DataFrame: One-row frame with the EOS results plus the model
        name under ``"method"`` and the WBM id under ``"id"``.
    """
    from mlip_arena.tasks.eos import run as EOS
    from mlip_arena.tasks.optimize import run as OPT

    # Build the calculator here (and select a freer GPU) instead of
    # shipping the entire model object through Prefect.
    calculator = get_calculator(model)

    relaxed = OPT.with_options(
        refresh_cache=True,
    )(
        atoms,
        calculator,
        optimizer="FIRE",
        criterion=dict(
            fmax=0.1,
        ),
    )

    result = EOS.with_options(
        refresh_cache=True,
    )(
        atoms=relaxed["atoms"],
        calculator=calculator,
        optimizer="FIRE",
        npoints=21,
        max_abs_strain=0.2,
        concurrent=False,
    )

    result["method"] = model.name
    result["id"] = atoms.info["key_value_pairs"]["wbm_id"]
    result.pop("atoms", None)  # drop the heavy Atoms object before pickling

    out_dir = Path(f"{model.name}")
    # parents=True: a nested model name (e.g. "org/model") must not crash the run.
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame([result])
    df.to_pickle(out_dir / f"{result['id']}.pkl")
    return df
@flow
def submit_tasks():
    """Submit one ``eos_bulk`` task per WBM structure and gather results.

    Currently pinned to the "eSEN" model; structures are skipped when
    that model does not declare "eos_bulk" among its GPU tasks.
    Submission failures are logged and skipped so a single bad structure
    cannot abort the whole flow.

    Returns:
        list: Per-task results; failures are returned, not raised.
    """
    model = MLIPEnum["eSEN"]  # pinned to one model; loop-invariant, so hoisted
    # for model in MLIPEnum:
    futures = []
    for atoms in load_wbm_structures():
        if "eos_bulk" not in REGISTRY[model.name].get("gpu-tasks", []):
            continue
        try:
            future = eos_bulk.with_options(
                refresh_cache=True,
            ).submit(atoms, model)
            futures.append(future)
        except Exception as exc:
            # Best-effort: surface the failure (flow runs with log_prints=True)
            # and keep submitting the remaining structures.
            print(f"Failed to submit task for {model.name}: {exc}")
            continue
    return [future.result(raise_on_failure=False) for future in futures]
if __name__ == "__main__":
    # One single-GPU node per SLURM allocation.
    nodes_per_alloc = 1
    gpus_per_alloc = 1

    cluster_kwargs = dict(
        cores=1,
        memory="64 GB",
        shebang="#!/bin/bash",
        account="m3828",
        walltime="00:30:00",
        job_mem="0",
        job_script_prologue=[
            "source ~/.bashrc",
            "module load python",
            "module load cudatoolkit/12.4",
            "source activate /pscratch/sd/c/cyrusyc/.conda/dev",
        ],
        # Skip the auto-generated directives that conflict with the
        # explicit ones supplied below.
        job_directives_skip=["-n", "--cpus-per-task", "-J"],
        job_extra_directives=[
            "-J eos_bulk",
            "-q regular",
            f"-N {nodes_per_alloc}",
            "-C gpu",
            f"-G {gpus_per_alloc}",
            # "--exclusive",
        ],
    )

    cluster = SLURMCluster(**cluster_kwargs)
    print(cluster.job_script())  # sanity-check the generated batch script
    # min == max: effectively a fixed pool of 50 SLURM jobs.
    cluster.adapt(minimum_jobs=50, maximum_jobs=50)
    client = Client(cluster)

    # Run the flow with tasks dispatched to the Dask cluster;
    # log_prints surfaces task-level print() output in Prefect logs.
    submit_tasks.with_options(
        task_runner=DaskTaskRunner(address=client.scheduler.address),
        log_prints=True,
    )()