cyrusyc commited on
Commit
b881f30
1 Parent(s): 034838c
serve/tasks/homonuclear-diatomics.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import numpy.linalg as LA
5
+ import plotly.express as px
6
+ import streamlit as st
7
+ from ase.data import chemical_symbols
8
+ from ase.io import read
9
+ from scipy.interpolate import CubicSpline
10
+
11
+ st.markdown("# Homonuclear diatomics")
12
+
13
+ DATA_DIR = Path("mlip_arena/tasks/diatomics")
14
+
15
+
16
+ for i, symbol in enumerate(chemical_symbols[1:10]):
17
+
18
+ if i % 3 == 0:
19
+ cols = st.columns(3)
20
+
21
+ fpath = DATA_DIR / "gpaw" / f"{symbol+symbol}_AFM" / "traj.extxyz"
22
+
23
+ if not fpath.exists():
24
+ continue
25
+
26
+ trj = read(fpath, index=":")
27
+
28
+ rs, es, s2s = [], [], []
29
+
30
+ for atoms in trj:
31
+ rs.append(LA.norm(atoms.positions[1] - atoms.positions[0]))
32
+ es.append(atoms.get_potential_energy())
33
+ s2s.append(np.power(atoms.get_magnetic_moments(), 2).mean())
34
+
35
+ rs = np.array(rs)
36
+ ind = np.argsort(rs)
37
+ es = np.array(es)
38
+ s2s = np.array(s2s)
39
+
40
+ rs = rs[ind]
41
+ es = es[ind]
42
+ s2s = s2s[ind]
43
+
44
+ es = es - es[-1]
45
+
46
+ xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2))
47
+
48
+ cs = CubicSpline(rs, es)
49
+ ys = cs(xs)
50
+
51
+ cs = CubicSpline(rs, s2s)
52
+ s2s = cs(xs)
53
+
54
+ ylo = min(ys.min()*1.5, -1)
55
+
56
+ fig = px.scatter(
57
+ x=xs, y=ys,
58
+ render_mode="webgl",
59
+ color=s2s,
60
+ range_color=[0, s2s.max()],
61
+ width=500,
62
+ range_y=[ylo, 1.2*(abs(ylo))],
63
+ # title=f"{atoms.get_chemical_formula()}",
64
+ labels={"x": "Bond length (Å)", "y": "Energy", "color": "Magnetic moment"},
65
+ )
66
+
67
+ cols[i % 3].title(f"{symbol+symbol}")
68
+ cols[i % 3].plotly_chart(fig, use_container_width=False)
69
+
70
+ # st.latex(r"\frac{d^2E}{dr^2} = \frac{d^2E}{dr^2}")
71
+
72
+ # st.components.v1.html(fig.to_html(include_mathjax='cdn'),height=500)
73
+