annabossler commited on
Commit
c2be249
·
verified ·
1 Parent(s): b786476

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -165
app.py CHANGED
@@ -1,14 +1,25 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- from ase import Atoms
5
- from ase.io import read
6
  import tempfile
7
  import os
 
 
 
 
 
 
 
 
 
 
 
8
  from orb_models.forcefield import pretrained
9
  from orb_models.forcefield.calculator import ORBCalculator
10
 
11
- # Global variable for the model
 
 
12
  model_calc = None
13
 
14
  def load_orbmol_model():
@@ -19,7 +30,7 @@ def load_orbmol_model():
19
  print("Loading OrbMol model...")
20
  orbff = pretrained.orb_v3_conservative_inf_omat(
21
  device="cpu",
22
- precision="float32" # más seguro que "float32-high" según la versión
23
  )
24
  model_calc = ORBCalculator(orbff, device="cpu")
25
  print("✅ OrbMol model loaded successfully")
@@ -28,12 +39,11 @@ def load_orbmol_model():
28
  model_calc = None
29
  return model_calc
30
 
 
 
 
31
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
32
- """
33
- Main function: XYZ → OrbMol → Results
34
- """
35
  try:
36
- # Load model
37
  calc = load_orbmol_model()
38
  if calc is None:
39
  return "❌ Error: Could not load OrbMol model", ""
@@ -41,183 +51,127 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
41
  if not xyz_content.strip():
42
  return "❌ Error: Please enter XYZ coordinates", ""
43
 
44
- # Create temporary file with XYZ
45
  with tempfile.NamedTemporaryFile(mode='w', suffix='.xyz', delete=False) as f:
46
  f.write(xyz_content)
47
  xyz_file = f.name
48
 
49
- # Read molecular structure
50
  atoms = read(xyz_file)
51
-
52
- # Configure charge and spin (IMPORTANT for OrbMol!)
53
- atoms.info = {
54
- "charge": int(charge),
55
- "spin": int(spin_multiplicity)
56
- }
57
-
58
- # Assign OrbMol calculator
59
  atoms.calc = calc
60
 
61
- # Make the prediction!
62
- energy = atoms.get_potential_energy() # In eV
63
- forces = atoms.get_forces() # In eV/Å
64
-
65
- # Format results nicely
66
- result = f"""
67
- 🔋 **Total Energy**: {energy:.6f} eV
68
-
69
- ⚡ **Atomic Forces**:
70
- """
71
 
72
- for i, force in enumerate(forces):
73
- result += f"Atom {i+1}: [{force[0]:.4f}, {force[1]:.4f}, {force[2]:.4f}] eV/Å\n"
 
74
 
75
- # Additional statistics
76
  max_force = np.max(np.linalg.norm(forces, axis=1))
77
  result += f"\n📊 **Max Force**: {max_force:.4f} eV/Å"
78
 
79
- # Clean up temporary file
80
  os.unlink(xyz_file)
81
-
82
  return result, "✅ Calculation completed with OrbMol"
83
-
84
  except Exception as e:
85
  return f"❌ Error during calculation: {str(e)}", "Error"
86
 
87
- # Predefined examples
88
- examples = [
89
- ["""2
90
- Hydrogen molecule
91
- H 0.0 0.0 0.0
92
- H 0.0 0.0 0.74""", 0, 1],
93
-
94
- ["""3
95
- Water molecule
96
- O 0.0000 0.0000 0.0000
97
- H 0.7571 0.0000 0.5864
98
- H -0.7571 0.0000 0.5864""", 0, 1],
99
-
100
- ["""4
101
- Methane
102
- C 0.0000 0.0000 0.0000
103
- H 1.0890 0.0000 0.0000
104
- H -0.3630 1.0267 0.0000
105
- H -0.3630 -0.5133 0.8887
106
- H -0.3630 -0.5133 -0.8887""", 0, 1]
107
- ]
108
-
109
- # Gradio interface - using FAIR Chem UMA style
110
- with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
111
-
112
- with gr.Row():
113
- with gr.Column(scale=2):
114
- with gr.Column(variant="panel"):
115
- gr.Markdown("# OrbMol Demo - Quantum-Accurate Molecular Predictions")
116
-
117
- gr.Markdown("""
118
- **OrbMol** is a neural network potential trained on the **OMol25** dataset (100M+ high-accuracy DFT calculations).
119
-
120
- Predicts **energies** and **forces** with quantum accuracy, optimized for:
121
- * 🧬 Biomolecules
122
- * ⚗️ Metal complexes
123
- * 🔋 Electrolytes
124
- """)
125
-
126
- gr.Markdown("## Simulation inputs")
127
-
128
- with gr.Column(variant="panel"):
129
- gr.Markdown("### Input molecular structure")
130
-
131
- xyz_input = gr.Textbox(
132
- label="XYZ Coordinates",
133
- placeholder="""3
134
- Water molecule
135
- O 0.0000 0.0000 0.0000
136
- H 0.7571 0.0000 0.5864
137
- H -0.7571 0.0000 0.5864""",
138
- lines=12,
139
- info="Paste XYZ coordinates of your molecule here"
140
- )
141
-
142
- gr.Markdown("OMol-specific settings for total charge and spin multiplicity")
143
- with gr.Row():
144
- charge_input = gr.Slider(
145
- value=0, label="Total Charge", minimum=-10, maximum=10, step=1
146
- )
147
- spin_input = gr.Slider(
148
- value=1, maximum=11, minimum=1, step=1, label="Spin Multiplicity"
149
- )
150
-
151
- predict_btn = gr.Button("Run OrbMol Prediction", variant="primary", size="lg")
152
-
153
- with gr.Column(variant="panel", elem_id="results", min_width=500):
154
- gr.Markdown("## OrbMol Prediction Results")
155
-
156
- results_output = gr.Textbox(
157
- label="Energy & Forces",
158
- lines=15,
159
- interactive=False,
160
- info="OrbMol energy and force predictions"
161
- )
162
 
163
- status_output = gr.Textbox(
164
- label="Status",
165
- interactive=False,
166
- max_lines=1
167
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- # Examples section
170
- gr.Markdown("### 🧪 Try These Examples")
171
- gr.Examples(
172
- examples=examples,
173
- inputs=[
174
- xyz_input,
175
- gr.Slider(visible=False, minimum=-10, maximum=10), # charge
176
- gr.Slider(visible=False, minimum=1, maximum=11) # spin
177
- ],
178
- label="Click any example to load it"
179
- )
180
-
181
- # Connect button to function
182
- predict_btn.click(
183
- predict_molecule,
184
- inputs=[xyz_input, charge_input, spin_input],
185
- outputs=[results_output, status_output]
186
- )
187
-
188
- # Footer info - matching FAIR Chem UMA style
189
- with gr.Sidebar(open=True):
190
- gr.Markdown("## Learn more about OrbMol")
191
- with gr.Accordion("What is OrbMol?", open=False):
192
- gr.Markdown("""
193
- * OrbMol is a neural network potential for molecular property prediction with quantum-level accuracy
194
- * Built on the Orb-v3 architecture and trained on OMol25 dataset (100M+ DFT calculations)
195
- * Optimized for biomolecules, metal complexes, and electrolytes
196
- * Supports configurable charge and spin multiplicity
197
-
198
- [Read more about OrbMol](https://orbitalmaterials.com/posts/orbmol-extending-orb-to-molecular-systems)
199
- """)
200
-
201
- with gr.Accordion("Model Disclaimers", open=False):
202
- gr.Markdown("""
203
- * While OrbMol represents significant progress in molecular ML potentials, the model has limitations
204
- * Always validate results for your specific use case
205
- * Consider the limitations of the ωB97M-V/def2-TZVPD level of theory used in training
206
- """)
207
-
208
- with gr.Accordion("Open source packages", open=False):
209
- gr.Markdown("""
210
- * Model code available at [orbital-materials/orb-models](https://github.com/orbital-materials/orb-models)
211
- * This demo uses ASE, Gradio, and other open source packages
212
- """)
213
-
214
- # Load model on startup
215
  print("🚀 Starting OrbMol model loading...")
216
  load_orbmol_model()
217
 
218
  if __name__ == "__main__":
219
- demo.launch(
220
- server_name="0.0.0.0",
221
- server_port=7860,
222
- show_error=True
223
- )
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
 
 
4
  import tempfile
5
  import os
6
+
7
+ from ase.io import read
8
+ from ase import units
9
+ from ase.optimize import LBFGS
10
+ from ase.md.verlet import VelocityVerlet
11
+ from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
12
+ from ase.md import MDLogger
13
+ from ase.io.trajectory import Trajectory
14
+
15
+ import py3Dmol
16
+
17
  from orb_models.forcefield import pretrained
18
  from orb_models.forcefield.calculator import ORBCalculator
19
 
20
+ # -----------------------------
21
+ # Global model
22
+ # -----------------------------
23
  model_calc = None
24
 
25
  def load_orbmol_model():
 
30
  print("Loading OrbMol model...")
31
  orbff = pretrained.orb_v3_conservative_inf_omat(
32
  device="cpu",
33
+ precision="float32-high"
34
  )
35
  model_calc = ORBCalculator(orbff, device="cpu")
36
  print("✅ OrbMol model loaded successfully")
 
39
  model_calc = None
40
  return model_calc
41
 
42
+ # -----------------------------
43
+ # Single-point calculation
44
+ # -----------------------------
45
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
 
 
 
46
  try:
 
47
  calc = load_orbmol_model()
48
  if calc is None:
49
  return "❌ Error: Could not load OrbMol model", ""
 
51
  if not xyz_content.strip():
52
  return "❌ Error: Please enter XYZ coordinates", ""
53
 
 
54
  with tempfile.NamedTemporaryFile(mode='w', suffix='.xyz', delete=False) as f:
55
  f.write(xyz_content)
56
  xyz_file = f.name
57
 
 
58
  atoms = read(xyz_file)
59
+ atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
 
 
 
 
 
 
 
60
  atoms.calc = calc
61
 
62
+ energy = atoms.get_potential_energy()
63
+ forces = atoms.get_forces()
 
 
 
 
 
 
 
 
64
 
65
+ result = f"🔋 **Total Energy**: {energy:.6f} eV\n\n⚡ **Atomic Forces**:\n"
66
+ for i, f in enumerate(forces):
67
+ result += f"Atom {i+1}: [{f[0]:.4f}, {f[1]:.4f}, {f[2]:.4f}] eV/Å\n"
68
 
 
69
  max_force = np.max(np.linalg.norm(forces, axis=1))
70
  result += f"\n📊 **Max Force**: {max_force:.4f} eV/Å"
71
 
 
72
  os.unlink(xyz_file)
 
73
  return result, "✅ Calculation completed with OrbMol"
 
74
  except Exception as e:
75
  return f"❌ Error during calculation: {str(e)}", "Error"
76
 
77
+ # -----------------------------
78
+ # Helper: convert trajectory → HTML animation
79
+ # -----------------------------
80
+ def traj_to_html(traj_file):
81
+ traj = Trajectory(traj_file)
82
+ view = py3Dmol.view(width=400, height=400)
83
+ for atoms in traj:
84
+ symbols = atoms.get_chemical_symbols()
85
+ xyz = atoms.get_positions()
86
+ mol = ""
87
+ for s, (x, y, z) in zip(symbols, xyz):
88
+ mol += f"{s} {x} {y} {z}\n"
89
+ view.addModel(mol, "xyz")
90
+ view.setStyle({"stick": {}})
91
+ view.zoomTo()
92
+ view.animate({"loop": "forward"})
93
+ return view._make_html()
94
+
95
+ # -----------------------------
96
+ # Molecular dynamics simulation
97
+ # -----------------------------
98
+ def run_md(xyz_content, charge=0, spin_multiplicity=1, steps=100, temperature=300, timestep=1.0):
99
+ try:
100
+ calc = load_orbmol_model()
101
+ if calc is None:
102
+ return "❌ Error: Could not load OrbMol model", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ if not xyz_content.strip():
105
+ return "❌ Error: Please enter XYZ coordinates", ""
106
+
107
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.xyz', delete=False) as f:
108
+ f.write(xyz_content)
109
+ xyz_file = f.name
110
+
111
+ atoms = read(xyz_file)
112
+ atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
113
+ atoms.calc = calc
114
+
115
+ # Pre-relaxation
116
+ opt = LBFGS(atoms)
117
+ opt.run(fmax=0.05, steps=20)
118
+
119
+ # Velocities
120
+ MaxwellBoltzmannDistribution(atoms, temperature_K=2 * temperature)
121
+
122
+ # MD setup
123
+ dyn = VelocityVerlet(atoms, timestep=timestep * units.fs)
124
+
125
+ traj_file = tempfile.NamedTemporaryFile(suffix=".traj", delete=False)
126
+ traj = Trajectory(traj_file.name, "w", atoms)
127
+ dyn.attach(traj.write, interval=1)
128
+
129
+ dyn.run(steps)
130
+
131
+ html = traj_to_html(traj_file.name)
132
+
133
+ os.unlink(xyz_file)
134
+ return f"✅ MD completed: {steps} steps at {temperature} K", html
135
+ except Exception as e:
136
+ return f"❌ Error during MD simulation: {str(e)}", ""
137
+
138
+ # -----------------------------
139
+ # Gradio UI
140
+ # -----------------------------
141
+ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol + MD Demo") as demo:
142
+ gr.Markdown("# OrbMol Demo with Molecular Dynamics")
143
+
144
+ with gr.Tab("Single Point Energy"):
145
+ xyz_input = gr.Textbox(label="XYZ Coordinates", lines=12)
146
+ charge_input = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
147
+ spin_input = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
148
+ run_btn = gr.Button("Run OrbMol Calculation")
149
+ results_output = gr.Textbox(label="Results", lines=15)
150
+ status_output = gr.Textbox(label="Status")
151
+ run_btn.click(
152
+ predict_molecule,
153
+ inputs=[xyz_input, charge_input, spin_input],
154
+ outputs=[results_output, status_output],
155
+ )
156
+
157
+ with gr.Tab("Molecular Dynamics"):
158
+ xyz_input_md = gr.Textbox(label="XYZ Coordinates", lines=12)
159
+ charge_input_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
160
+ spin_input_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
161
+ steps_input = gr.Slider(value=100, minimum=10, maximum=1000, step=10, label="Steps")
162
+ temp_input = gr.Slider(value=300, minimum=10, maximum=1000, step=10, label="Temperature (K)")
163
+ timestep_input = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
164
+ run_md_btn = gr.Button("Run MD Simulation")
165
+ md_status = gr.Textbox(label="MD Status", lines=2)
166
+ md_view = gr.HTML()
167
+ run_md_btn.click(
168
+ run_md,
169
+ inputs=[xyz_input_md, charge_input_md, spin_input_md, steps_input, temp_input, timestep_input],
170
+ outputs=[md_status, md_view],
171
+ )
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  print("🚀 Starting OrbMol model loading...")
174
  load_orbmol_model()
175
 
176
  if __name__ == "__main__":
177
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)