orbmol / app.py
annabossler's picture
Update app.py
1fedfc5 verified
raw
history blame
12.6 kB
import hashlib
import html
import json
import os
import tempfile

import gradio as gr
import numpy as np
from ase.io import read, write
from ase.io.trajectory import Trajectory
# ==== Force the HTML viewer built on 3Dmol.js ====
# Module-level flag that is always False; nothing in the code shown here reads it.
# NOTE(review): looks like a leftover from an earlier viewer option — remove it if
# no other module imports it.
HAVE_MOL3D = False
# ==== HTML trajectory viewer (3Dmol.js) ====
def _frame_to_xyz(symbols, positions):
    """Format one frame as an XYZ block: atom count, a comment line, one line per atom."""
    atom_lines = [f"{s} {x:.6f} {y:.6f} {z:.6f}" for s, (x, y, z) in zip(symbols, positions)]
    return "\n".join([str(len(symbols)), "frame", *atom_lines])


def _viewer_document(xyz_frames, interval_ms):
    """Build a standalone HTML page that animates XYZ-format frames with 3Dmol.js."""
    # json.dumps yields a valid JavaScript array literal; escaping "</" guarantees the
    # frame data can never close the inline <script> element early.
    frames_js = json.dumps(xyz_frames).replace("</", "<\\/")
    return f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
  html, body {{ margin: 0; height: 100%; overflow: hidden; background: #fafafa; }}
  #viewer {{ width: 100%; height: 100%; position: relative; }}
</style>
<script src="https://3dmol.org/build/3Dmol-min.js"></script>
</head>
<body>
<div id="viewer"></div>
<script>
(function () {{
  var el = document.getElementById("viewer");
  if (typeof $3Dmol === "undefined") {{
    el.innerHTML = "<div style='color:#b00; padding:20px;'>Could not load 3Dmol.js</div>";
    return;
  }}
  var viewer = $3Dmol.createViewer(el, {{backgroundColor: "white"}});
  var frames = {frames_js};
  var current = 0;
  function showFrame(i) {{
    viewer.clear();
    viewer.addModel(frames[i], "xyz");
    viewer.setStyle({{}}, {{stick: {{}}, sphere: {{}}}});
    viewer.zoomTo();
    viewer.render();
  }}
  showFrame(0);
  if (frames.length > 1) {{
    setInterval(function () {{
      current = (current + 1) % frames.length;
      showFrame(current);
    }}, {int(interval_ms)});
  }}
}})();
</script>
</body>
</html>
"""


def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
    """Render an ASE ``.traj`` file as an animated 3Dmol.js viewer.

    The viewer is a complete HTML document embedded via ``<iframe srcdoc=...>``.
    ``gr.HTML`` inserts its value as markup, and browsers do not execute
    ``<script>`` elements inserted that way. An iframe document runs its own
    scripts, and its animation timer is discarded together with the iframe when
    the component is updated.

    Args:
        traj_path: Path to the trajectory file. ``None``/empty or a missing path
            yields an error message instead of a viewer.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
        interval_ms: Delay between animation frames in milliseconds; only used
            when the trajectory has more than one frame.

    Returns:
        An HTML string. On success it holds a header with the frame count
        followed by the viewer. Otherwise it holds a short ``<div>`` explaining
        why no viewer was built.
    """
    if not traj_path or not os.path.exists(traj_path):
        return "<div style='color:#b00; padding:20px;'>No trajectory file found</div>"
    try:
        # The context manager closes the file handle that Trajectory opens.
        # Reading all frames here means a corrupt frame is reported, not raised.
        with Trajectory(traj_path) as traj:
            xyz_frames = [
                _frame_to_xyz(atoms.get_chemical_symbols(), atoms.get_positions())
                for atoms in traj
            ]
    except Exception as e:
        # Exception text may contain '<', '&' or quotes; escape it before embedding.
        return f"<div style='color:#b00; padding:20px;'>Error: {html.escape(str(e))}</div>"
    if not xyz_frames:
        return "<div style='color:#555; padding:20px;'>Empty trajectory</div>"

    document = _viewer_document(xyz_frames, interval_ms)
    return f"""
<div style="margin-bottom:10px; padding:10px; background:#f5f5f5; border-radius:5px;">
<strong>🧬 3D Molecular Viewer</strong> — {len(xyz_frames)} frames
</div>
<iframe srcdoc="{html.escape(document, quote=True)}"
        style="width:{width}px; height:{height}px; border:2px solid #ddd; border-radius:8px; background:#fafafa;"></iframe>
"""
# ==== OrbMol SPE ====
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
# Process-wide cache for the OrbMol calculator; filled on first use.
_MODEL_CALC = None


def _load_orbmol_calc():
    """Return the cached ``ORBCalculator``, building it on the first call.

    The pretrained model is created on CPU with ``precision="float32-high"``.
    The calculator wrapping it is stored in ``_MODEL_CALC``, so every later call
    returns that same object.

    NOTE(review): the checkpoint loaded is ``orb_v3_conservative_inf_omat``, while
    callers put charge/spin into ``atoms.info``. Confirm this checkpoint actually
    consumes those keys; presumably only the OMol-trained OrbMol checkpoints do.
    Verify against the orb-models docs.
    """
    global _MODEL_CALC
    if _MODEL_CALC is not None:
        return _MODEL_CALC
    orb_model = pretrained.orb_v3_conservative_inf_omat(
        device="cpu",
        precision="float32-high",
    )
    _MODEL_CALC = ORBCalculator(orb_model, device="cpu")
    return _MODEL_CALC
def predict_molecule(structure_file, charge=0, spin_multiplicity=1):
    """Compute the single-point energy and atomic forces of an uploaded structure with OrbMol.

    Args:
        structure_file: Filesystem path of the uploaded structure, as passed by
            ``gr.File``. ``None`` or ``""`` means nothing was uploaded.
        charge: Total charge, stored as ``int`` in ``atoms.info["charge"]``.
        spin_multiplicity: Spin multiplicity, stored as ``int`` in ``atoms.info["spin"]``.

    Returns:
        A ``(report, status)`` pair.

        On success, ``report`` lists the total energy (eV), the force on every
        atom (eV/Å) and the largest force norm. ``status`` is then
        ``"Calculation completed with OrbMol"``.

        On failure, ``report`` holds an error message and ``status`` is
        ``"Error"``. Exceptions are reported this way rather than raised.
    """
    try:
        # Validate the upload before touching the model, so input problems are
        # reported as such instead of being masked by model-loading failures.
        if not structure_file:
            return "Error: Please upload a structure file", "Error"
        # gr.File passes the uploaded file's path directly.
        file_path = structure_file
        if not os.path.exists(file_path):
            return f"Error: File not found: {file_path}", "Error"
        if os.path.getsize(file_path) == 0:
            return f"Error: Empty file: {file_path}", "Error"

        atoms = read(file_path)
        if len(atoms) == 0:
            # np.max below would otherwise fail with an opaque zero-size-array error.
            return f"Error: No atoms found in file: {file_path}", "Error"
        # Update rather than replace, so metadata parsed from the file is kept.
        atoms.info.update(charge=int(charge), spin=int(spin_multiplicity))
        atoms.calc = _load_orbmol_calc()

        energy = atoms.get_potential_energy()  # eV
        forces = atoms.get_forces()  # eV/Å

        lines = [f"Total Energy: {energy:.6f} eV", "", "Atomic Forces:"]
        lines.extend(
            f"Atom {i}: [{fx:.4f}, {fy:.4f}, {fz:.4f}] eV/Å"
            for i, (fx, fy, fz) in enumerate(forces, start=1)
        )
        max_force = float(np.max(np.linalg.norm(forces, axis=1)))
        lines += ["", f"Max Force: {max_force:.4f} eV/Å"]
        return "\n".join(lines), "Calculation completed with OrbMol"
    except Exception as e:
        return f"Error during calculation: {e}", "Error"
# ==== Simulaciones (helpers) ====
from simulation_scripts_orbmol import (
run_md_simulation,
run_relaxation_simulation,
)
# ==== Wrappers: run the simulations on uploaded files ====
def md_wrapper(structure_file, charge, spin, steps, tempK, timestep_fs, ensemble):
    """Gradio callback for the Molecular Dynamics tab.

    Runs ``run_md_simulation`` on the uploaded file. Its results are packed into
    the seven outputs wired to the tab:

    - status text
    - trajectory path
    - log
    - reproduction script
    - explanation
    - the hidden placeholder (always ``None``)
    - the viewer HTML

    If anything raises, the error goes into the status slot and the remaining
    slots are blanked.
    """
    blank_outputs = (None, "", "", "", None, "")
    try:
        if not structure_file:
            return ("Error: Please upload a structure file",) + blank_outputs
        # gr.File hands over the uploaded file's path directly.
        sim_args = (
            structure_file,
            int(steps),
            20,  # NOTE(review): fixed positional argument; its meaning is defined by run_md_simulation — confirm
            float(timestep_fs),
            float(tempK),
            "NVT" if ensemble == "NVT" else "NVE",  # anything but "NVT" falls back to NVE
            int(charge),
            int(spin),
        )
        traj_path, log_text, script_text, explanation = run_md_simulation(*sim_args)
        status = f"MD completed: {int(steps)} steps at {int(tempK)} K ({ensemble})"
        return (status, traj_path, log_text, script_text, explanation, None, traj_to_html(traj_path))
    except Exception as e:
        return (f"Error: {e}",) + blank_outputs
def relax_wrapper(structure_file, steps, fmax, charge, spin, relax_cell):
    """Gradio callback for the Relaxation / Optimization tab.

    Runs ``run_relaxation_simulation`` on the uploaded file and returns the same
    seven-slot tuple as the MD callback:

    - status text
    - trajectory path
    - log
    - reproduction script
    - explanation
    - the hidden placeholder (always ``None``)
    - the viewer HTML

    Errors are reported in the status slot with the other slots blanked.
    """
    blank_outputs = (None, "", "", "", None, "")
    try:
        if not structure_file:
            return ("Error: Please upload a structure file",) + blank_outputs
        # gr.File hands over the uploaded file's path directly.
        sim_args = (
            structure_file,
            int(steps),
            float(fmax),
            int(charge),
            int(spin),
            bool(relax_cell),
        )
        traj_path, log_text, script_text, explanation = run_relaxation_simulation(*sim_args)
        status = f"Relaxation finished (≤ {int(steps)} steps, fmax={float(fmax)} eV/Å)"
        return (status, traj_path, log_text, script_text, explanation, None, traj_to_html(traj_path))
    except Exception as e:
        return (f"Error: {e}",) + blank_outputs
# ==== UI ====
# File extensions accepted by every structure-upload widget.
_STRUCTURE_FILE_TYPES = [".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"]


def _structure_upload():
    """Create the single-file structure upload widget used by every tab."""
    return gr.File(
        # The label lists every accepted extension (.mol/.sdf were previously omitted).
        label="Upload Structure File (.xyz/.pdb/.cif/.traj/.mol/.sdf)",
        file_types=list(_STRUCTURE_FILE_TYPES),
        file_count="single",
    )


with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
    with gr.Tabs():
        # -------- SPE --------
        with gr.Tab("Single Point Energy"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("# OrbMol — Quantum-Accurate Molecular Predictions")
                    gr.Markdown(
                        "Upload molecular structure files (.xyz, .pdb, .cif, .traj, .mol, .sdf) "
                        "for energy and force calculations."
                    )
                    xyz_input = _structure_upload()
                    with gr.Row():
                        charge_input = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
                        spin_input = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
                    run_spe = gr.Button("Run OrbMol Prediction", variant="primary")
                with gr.Column(variant="panel", min_width=500):
                    spe_out = gr.Textbox(label="Energy & Forces", lines=15, interactive=False)
                    spe_status = gr.Textbox(label="Status", interactive=False, max_lines=1)
            run_spe.click(predict_molecule, [xyz_input, charge_input, spin_input], [spe_out, spe_status])

        # -------- MD --------
        with gr.Tab("Molecular Dynamics"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("## Molecular Dynamics Simulation")
                    gr.Markdown("Upload your molecular structure and configure MD parameters.")
                    xyz_md = _structure_upload()
                    with gr.Row():
                        charge_md = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
                        spin_md = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
                    with gr.Row():
                        steps_md = gr.Slider(minimum=10, maximum=2000, value=100, step=10, label="Steps")
                        temp_md = gr.Slider(minimum=10, maximum=1500, value=300, step=10, label="Temperature (K)")
                    with gr.Row():
                        timestep_md = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Timestep (fs)")
                        ensemble_md = gr.Radio(["NVE", "NVT"], value="NVE", label="Ensemble")
                    run_md_btn = gr.Button("Run MD Simulation", variant="primary")
                with gr.Column(variant="panel", min_width=520):
                    md_status = gr.Textbox(label="MD Status", interactive=False)
                    md_traj = gr.File(label="Trajectory (.traj)", interactive=False)
                    # Hidden sink for md_wrapper's sixth output, which is always None.
                    md_viewer_placeholder = gr.HTML(visible=False)
                    md_html = gr.HTML(label="Trajectory Viewer")
                    md_log = gr.Textbox(label="Log", interactive=False, lines=15, max_lines=25)
                    md_script = gr.Code(
                        label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30
                    )
                    md_explain = gr.Markdown()
            run_md_btn.click(
                md_wrapper,
                inputs=[xyz_md, charge_md, spin_md, steps_md, temp_md, timestep_md, ensemble_md],
                outputs=[md_status, md_traj, md_log, md_script, md_explain, md_viewer_placeholder, md_html],
            )

        # -------- Relax --------
        with gr.Tab("Relaxation / Optimization"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("## Structure Relaxation/Optimization")
                    gr.Markdown("Upload your molecular structure for geometry optimization.")
                    xyz_rlx = _structure_upload()
                    steps_rlx = gr.Slider(minimum=1, maximum=2000, value=300, step=1, label="Max Steps")
                    fmax_rlx = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Fmax (eV/Å)")
                    with gr.Row():
                        charge_rlx = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
                        # Same label as the other tabs (was just "Spin").
                        spin_rlx = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
                    relax_cell = gr.Checkbox(False, label="Relax Unit Cell")
                    run_rlx_btn = gr.Button("Run Optimization", variant="primary")
                with gr.Column(variant="panel", min_width=520):
                    rlx_status = gr.Textbox(label="Status", interactive=False)
                    rlx_traj = gr.File(label="Trajectory (.traj)", interactive=False)
                    # Hidden sink for relax_wrapper's sixth output, which is always None.
                    rlx_viewer_placeholder = gr.HTML(visible=False)
                    rlx_html = gr.HTML(label="Final Structure")
                    rlx_log = gr.Textbox(label="Log", interactive=False, lines=15, max_lines=25)
                    rlx_script = gr.Code(
                        label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30
                    )
                    rlx_explain = gr.Markdown()
            run_rlx_btn.click(
                relax_wrapper,
                inputs=[xyz_rlx, steps_rlx, fmax_rlx, charge_rlx, spin_rlx, relax_cell],
                outputs=[rlx_status, rlx_traj, rlx_log, rlx_script, rlx_explain, rlx_viewer_placeholder, rlx_html],
            )
# ==== Startup ====
# Build the OrbMol calculator once when this module runs, so the first request
# does not have to wait for the model to load.
print("Starting OrbMol model loading…")
_ = _load_orbmol_calc()

if __name__ == "__main__":
    # Bind on all interfaces at port 7860 and surface errors in the UI.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_error=True)
    demo.launch(**launch_options)