Upload 16 files
Browse files- .gitattributes +1 -0
- analysis.py +84 -0
- app.py +188 -0
- data/perovskite_dataset.pt +3 -0
- generate.py +93 -0
- model_weights.pth +3 -0
- rdf_analysis.png +0 -0
- requirements.txt +4 -0
- result_plot.png +3 -0
- src/__pycache__/layers.cpython-312.pyc +0 -0
- src/__pycache__/model.cpython-312.pyc +0 -0
- src/data_loader.py +89 -0
- src/layers.py +88 -0
- src/model.py +89 -0
- train.py +98 -0
- validate.py +66 -0
- visualize.py +80 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
result_plot.png filter=lfs diff=lfs merge=lfs -text
|
analysis.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from src.model import CrystalDiffusionModel
|
| 5 |
+
|
| 6 |
+
# Load your model and generate a crystal
def generate_crystal(atomic_numbers=(56, 22, 8, 8, 8), steps=50,
                     weights_path="model_weights.pth"):
    """
    Run the reverse-diffusion loop and return generated atom coordinates.

    Generalized from the original hard-coded BaTiO3 setup: chemistry, step
    count, and weights path are now parameters with backward-compatible
    defaults (calling with no arguments reproduces the old behavior).

    Args:
        atomic_numbers: sequence of atomic numbers to place
            (default (56, 22, 8, 8, 8) = Ba, Ti, O, O, O -> BaTiO3).
        steps: number of denoising iterations.
        weights_path: path to the trained model state dict.

    Returns:
        (N, 3) numpy array of Cartesian coordinates.
    """
    model = CrystalDiffusionModel()
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.eval()

    z = torch.tensor(list(atomic_numbers))
    num_atoms = len(z)

    # Fully connected graph without self-loops: every atom exchanges
    # messages with every other atom.
    row = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
    col = torch.arange(num_atoms).repeat(num_atoms)
    mask = row != col
    edge_index = torch.stack([row[mask], col[mask]], dim=0)

    # Reverse diffusion: start from pure Gaussian noise and iteratively denoise.
    x = torch.randn(num_atoms, 3)
    dt = 1.0 / steps

    for i in range(steps):
        t = torch.tensor([[1.0 - i * dt]])  # time runs 1.0 -> 0.0
        with torch.no_grad():
            pred = model(x, z, t, edge_index)
        # Euler-style relaxation: move 10% toward the prediction for stability.
        x = x + (pred - x) * 0.1

    return x.numpy()
| 34 |
+
|
| 35 |
+
def compute_rdf(coords, box_size=5.0, bins=50):
    """
    Calculates the Radial Distribution Function (RDF).

    Args:
        coords: (N, 3) array-like of atomic positions.
        box_size: maximum distance covered by the histogram.
        bins: number of histogram bins.

    Returns:
        (r, rdf): bin-center distances and the shell-volume-normalized
        pair density. For fewer than two atoms the RDF is all zeros
        (the original divided by zero when N == 0).
    """
    coords = np.asarray(coords)
    num_atoms = len(coords)

    # Vectorized unique pair distances (strict upper triangle), replacing
    # the original O(N^2) Python double loop with one numpy pass.
    if num_atoms >= 2:
        diffs = coords[:, None, :] - coords[None, :, :]
        pair_matrix = np.linalg.norm(diffs, axis=-1)
        distances = pair_matrix[np.triu_indices(num_atoms, k=1)]
    else:
        distances = np.empty(0)

    # Histogram of pair separations.
    hist, bin_edges = np.histogram(distances, bins=bins, range=(0, box_size))
    r = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Normalize by the spherical shell volume at each bin (volume correction).
    dr = bin_edges[1] - bin_edges[0]
    volume = 4 * np.pi * r**2 * dr

    # Guard: avoid division by zero for empty input.
    if num_atoms == 0:
        return r, np.zeros_like(r)

    rdf = hist / (volume * num_atoms)  # Density normalization

    return r, rdf
|
| 57 |
+
|
| 58 |
+
def plot_comparison():
    """Plot the RDF of random noise vs. the generated crystal and save a PNG."""
    print("Generating Analysis Plot...")

    # 1. Baseline: a structureless random point cloud.
    r_noise, rdf_noise = compute_rdf(np.random.randn(5, 3))

    # 2. The structure produced by the diffusion model.
    r_crys, rdf_crys = compute_rdf(generate_crystal())

    # 3. Overlay both curves on a single set of axes.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(r_noise, rdf_noise, label='Random Noise', linestyle='--', color='gray')
    ax.plot(r_crys, rdf_crys, label='Generated Crystal (AI)', linewidth=3, color='blue')

    ax.set_title("Radial Distribution Function (RDF) Analysis")
    ax.set_xlabel("Distance (Angstroms)")
    ax.set_ylabel("Probability Density")
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.savefig("rdf_analysis.png")
    print("✅ Saved 'rdf_analysis.png'. Put this in your README!")
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
plot_comparison()
|
app.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import py3Dmol
|
| 5 |
+
from stmol import showmol
|
| 6 |
+
from src.model import CrystalDiffusionModel
|
| 7 |
+
|
| 8 |
+
# --- PAGE CONFIGURATION ---
|
| 9 |
+
st.set_page_config(
|
| 10 |
+
page_title="CrystalDiff: AI Material Designer",
|
| 11 |
+
layout="wide",
|
| 12 |
+
page_icon="💎",
|
| 13 |
+
initial_sidebar_state="expanded"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# --- SIDEBAR: CONTROLS & INFO ---
|
| 17 |
+
with st.sidebar:
|
| 18 |
+
st.title("💎 CrystalDiff Controls")
|
| 19 |
+
|
| 20 |
+
st.markdown("### 1. Select Chemistry")
|
| 21 |
+
target_atom = st.selectbox(
|
| 22 |
+
"Choose A-Site Cation",
|
| 23 |
+
["Ca (Calcium)", "Sr (Strontium)", "Ba (Barium)", "Pb (Lead)"],
|
| 24 |
+
index=1,
|
| 25 |
+
help="The large atom in the center of the cage."
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
st.markdown("### 2. Diffusion Settings")
|
| 29 |
+
steps = st.slider("Denoising Steps", 10, 100, 50, help="More steps = higher quality, but slower.")
|
| 30 |
+
noise_scale = st.slider("Initial Chaos (Noise)", 0.5, 2.0, 1.0, help="Higher noise means the AI has to be more creative.")
|
| 31 |
+
|
| 32 |
+
st.divider()
|
| 33 |
+
|
| 34 |
+
st.markdown("### 🧠 How it Works")
|
| 35 |
+
st.info("""
|
| 36 |
+
**Generative Diffusion:**
|
| 37 |
+
The model starts with random noise (chaos) and iteratively subtracts noise to find a stable crystal structure.
|
| 38 |
+
|
| 39 |
+
**E(n)-Equivariance:**
|
| 40 |
+
The AI uses a custom Graph Neural Network that respects the laws of physics (rotational symmetry).
|
| 41 |
+
""")
|
| 42 |
+
|
| 43 |
+
st.markdown("---")
|
| 44 |
+
st.caption("Built with PyTorch & Streamlit by Aditya Mangal. Inspired by DeepMind's work on generative models for materials science.")
|
| 45 |
+
|
| 46 |
+
# --- MAIN PAGE ---
|
| 47 |
+
st.title("💎 CrystalDiff: Generative Material Design")
|
| 48 |
+
st.markdown("""
|
| 49 |
+
This application uses **Geometric Deep Learning** to hallucinate new stable crystals.
|
| 50 |
+
It was trained on the **Materials Project** database to understand the chemical rules of **Perovskite Oxides ($ABO_3$)**.
|
| 51 |
+
""")
|
| 52 |
+
|
| 53 |
+
# Map selection to Atomic Number
|
| 54 |
+
atom_map = {
|
| 55 |
+
"Ca (Calcium)": 20, "Sr (Strontium)": 38,
|
| 56 |
+
"Ba (Barium)": 56, "Pb (Lead)": 82
|
| 57 |
+
}
|
| 58 |
+
selected_z = atom_map[target_atom]
|
| 59 |
+
formula_display = f"{target_atom.split()[0]}TiO₃"
|
| 60 |
+
|
| 61 |
+
# --- HELPER FUNCTIONS ---
|
| 62 |
+
@st.cache_resource
def load_model():
    """Load the trained diffusion model onto CPU (cached across Streamlit reruns).

    Returns (model, device) on success, or (None, None) when the weights
    file is missing so the caller can surface a friendly error.
    """
    device = torch.device("cpu")
    model = CrystalDiffusionModel()
    try:
        state = torch.load("model_weights.pth", map_location=device)
    except FileNotFoundError:
        # No trained weights on disk yet.
        return None, None
    model.load_state_dict(state)
    model.eval()
    return model, device
|
| 72 |
+
|
| 73 |
+
def calculate_metrics(pos, z):
    """Mean Ti-O bond length (to validate physics).

    Looks up the first titanium atom (Z=22) and averages its distance to
    every oxygen atom (Z=8). Returns 0.0 when either species is absent.
    """
    titanium = [i for i, num in enumerate(z) if num == 22]
    oxygens = [i for i, num in enumerate(z) if num == 8]

    if not titanium or not oxygens:
        return 0.0

    center = pos[titanium[0]]
    bond_lengths = [np.linalg.norm(center - pos[j]) for j in oxygens]
    return np.mean(bond_lengths)
|
| 88 |
+
|
| 89 |
+
def make_view(pos, z):
    """Creates a 3D molecule view"""
    view = py3Dmol.view(width=800, height=500)

    # Build an XYZ-format string: atom count, comment line, one atom per row.
    # Anything that is not O or Ti is the user-selected A-site element
    # (read from the module-level `target_atom` widget value).
    symbol_for = {8: "O", 22: "Ti"}
    lines = [f"{len(pos)}", "Generated"]
    for idx in range(len(pos)):
        elem = symbol_for.get(z[idx], target_atom.split()[0])
        lines.append(f"{elem} {pos[idx,0]:.4f} {pos[idx,1]:.4f} {pos[idx,2]:.4f}")
    view.addModel("\n".join(lines) + "\n", "xyz")

    # Style: spheres for atoms, sticks for bonds
    view.setStyle({'sphere': {'scale': 0.25}, 'stick': {'radius': 0.1}})
    view.zoomTo()
    return view
|
| 101 |
+
|
| 102 |
+
# --- APP LOGIC ---
|
| 103 |
+
model, device = load_model()
|
| 104 |
+
|
| 105 |
+
if model is None:
|
| 106 |
+
st.error(" Model weights not found! Please run 'train.py' first.")
|
| 107 |
+
st.stop()
|
| 108 |
+
|
| 109 |
+
# Layout: Two columns
|
| 110 |
+
col1, col2 = st.columns([1, 2])
|
| 111 |
+
|
| 112 |
+
with col1:
|
| 113 |
+
st.subheader("🧪 Experiment Setup")
|
| 114 |
+
st.write(f"**Target Material:** {formula_display}")
|
| 115 |
+
st.write(f"**Structure Family:** Cubic Perovskite")
|
| 116 |
+
|
| 117 |
+
if st.button("✨ Generate Crystal", type="primary", use_container_width=True):
|
| 118 |
+
|
| 119 |
+
# 1. Setup Data
|
| 120 |
+
z = torch.tensor([selected_z, 22, 8, 8, 8], device=device) # A-Site, Ti, O, O, O
|
| 121 |
+
num_atoms = 5
|
| 122 |
+
|
| 123 |
+
# Graph connections
|
| 124 |
+
row = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
|
| 125 |
+
col = torch.arange(num_atoms).repeat(num_atoms)
|
| 126 |
+
mask = row != col
|
| 127 |
+
edge_index = torch.stack([row[mask], col[mask]], dim=0).to(device)
|
| 128 |
+
|
| 129 |
+
# 2. Diffusion Loop
|
| 130 |
+
x = torch.randn(num_atoms, 3, device=device) * noise_scale
|
| 131 |
+
|
| 132 |
+
progress_bar = st.progress(0)
|
| 133 |
+
status = st.empty()
|
| 134 |
+
|
| 135 |
+
dt = 1.0 / steps
|
| 136 |
+
for i in range(steps):
|
| 137 |
+
t_val = 1.0 - (i * dt)
|
| 138 |
+
t_tensor = torch.tensor([[t_val]], device=device)
|
| 139 |
+
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
x_pred = model(x, z, t_tensor, edge_index)
|
| 142 |
+
|
| 143 |
+
# Euler update
|
| 144 |
+
x = x + (x_pred - x) * 0.1
|
| 145 |
+
|
| 146 |
+
if i % 5 == 0:
|
| 147 |
+
progress_bar.progress(i / steps)
|
| 148 |
+
status.text(f"Denoising... Step {i}/{steps}")
|
| 149 |
+
|
| 150 |
+
progress_bar.progress(1.0)
|
| 151 |
+
status.success("Done!")
|
| 152 |
+
|
| 153 |
+
# 3. Store result in session state to keep it on screen
|
| 154 |
+
st.session_state['generated_pos'] = x.numpy()
|
| 155 |
+
st.session_state['generated_z'] = z.numpy()
|
| 156 |
+
|
| 157 |
+
with col2:
|
| 158 |
+
st.subheader("⚛️ 3D Visualization")
|
| 159 |
+
|
| 160 |
+
if 'generated_pos' in st.session_state:
|
| 161 |
+
pos = st.session_state['generated_pos']
|
| 162 |
+
z = st.session_state['generated_z']
|
| 163 |
+
|
| 164 |
+
# Calculate Physics
|
| 165 |
+
avg_bond = calculate_metrics(pos, z)
|
| 166 |
+
|
| 167 |
+
# Display Metrics
|
| 168 |
+
m1, m2 = st.columns(2)
|
| 169 |
+
m1.metric("Avg Ti-O Bond Length", f"{avg_bond:.3f} Å")
|
| 170 |
+
|
| 171 |
+
# Validation Logic
|
| 172 |
+
if 1.8 < avg_bond < 2.2:
|
| 173 |
+
m2.success("✅ Physically Valid")
|
| 174 |
+
else:
|
| 175 |
+
m2.warning("⚠️ Unstable Structure")
|
| 176 |
+
|
| 177 |
+
# Render 3D
|
| 178 |
+
view = make_view(pos, z)
|
| 179 |
+
showmol(view, height=500, width=800)
|
| 180 |
+
|
| 181 |
+
else:
|
| 182 |
+
st.info("👈 Select your chemistry on the left and click 'Generate Crystal' to start the AI.")
|
| 183 |
+
st.markdown("""
|
| 184 |
+
<div style="text-align: center; padding: 50px; border: 2px dashed #444; border-radius: 10px; margin-top: 20px;">
|
| 185 |
+
<h1 style="color: #666;">🧊</h1>
|
| 186 |
+
<p style="color: #888;">Waiting for generation...</p>
|
| 187 |
+
</div>
|
| 188 |
+
""", unsafe_allow_html=True)
|
data/perovskite_dataset.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a29272d987115656b20c250170a62efa7d76142ad58a8d85da6ceccedd797db
|
| 3 |
+
size 31765
|
generate.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from src.model import CrystalDiffusionModel
|
| 4 |
+
|
| 5 |
+
# --- CONFIGURATION ---
|
| 6 |
+
# Select the atoms you want to generate!
|
| 7 |
+
# Example: BaTiO3 (Barium Titanate) -> Ba=56, Ti=22, O=8
|
| 8 |
+
# Example: SrTiO3 (Strontium Titanate) -> Sr=38, Ti=22, O=8
|
| 9 |
+
# Example: CaTiO3 (Calcium Titanate) -> Ca=20, Ti=22, O=8
|
| 10 |
+
|
| 11 |
+
# Let's try generating Strontium Titanate (SrTiO3) this time
|
| 12 |
+
TARGET_ATOMS = [38, 22, 8, 8, 8] # Sr, Ti, O, O, O
|
| 13 |
+
MODEL_PATH = "model_weights.pth"
|
| 14 |
+
STEPS = 50 # Number of diffusion steps
|
| 15 |
+
|
| 16 |
+
def save_xyz(pos, z, filename):
    """
    Saves the crystal in XYZ format for visualization.

    Args:
        pos: (N, 3) array/tensor of Cartesian coordinates (indexable as pos[i, k]).
        z: length-N sequence of atomic numbers.
        filename: output path for the .xyz file.
    """
    # Simple periodic table lookup for common perovskite elements.
    # Hoisted out of the loop: the original rebuilt this dict every iteration.
    # Add more entries if you generate other materials.
    elem_map = {
        8: "O", 22: "Ti", 20: "Ca",
        56: "Ba", 38: "Sr", 82: "Pb",
        26: "Fe", 40: "Zr"
    }
    with open(filename, "w") as f:
        f.write(f"{len(pos)}\n")
        f.write("Generated by CrystalDiff\n")
        for i in range(len(pos)):
            atom_symbol = elem_map.get(int(z[i]), "X")  # Default to X if unknown
            f.write(f"{atom_symbol} {pos[i,0]:.4f} {pos[i,1]:.4f} {pos[i,2]:.4f}\n")
|
| 33 |
+
|
| 34 |
+
def generate():
    """Run reverse diffusion from pure noise and save crystal snapshots to .xyz files."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- 💎 Generating Crystal on {device} ---")

    # 1. Load the trained denoiser.
    model = CrystalDiffusionModel().to(device)
    try:
        state = torch.load(MODEL_PATH, map_location=device)
    except FileNotFoundError:
        print(f"❌ Error: Could not find '{MODEL_PATH}'. Did you run train.py?")
        return
    model.load_state_dict(state)
    model.eval()

    # 2. Target chemistry (module-level TARGET_ATOMS).
    z = torch.tensor(TARGET_ATOMS, device=device)
    num_atoms = len(z)
    print(f"Target Atoms: {z.tolist()}")

    # Fully connected graph: every atom exchanges messages with every other.
    src = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
    dst = torch.arange(num_atoms).repeat(num_atoms)
    off_diag = src != dst  # drop self-loops
    edge_index = torch.stack([src[off_diag], dst[off_diag]], dim=0).to(device)

    # 3. Start from pure Gaussian noise ("the chaos"); scale 1.0 matches training.
    x = torch.randn(num_atoms, 3, device=device)
    print(f"Initial State: Random Gas Cloud")
    save_xyz(x, z, "gen_step_00.xyz")

    # 4. Reverse diffusion loop: time runs 1.0 -> 0.0 over STEPS iterations.
    dt = 1.0 / STEPS
    for step in range(STEPS):
        t_tensor = torch.tensor([[1.0 - (step * dt)]], device=device)

        with torch.no_grad():
            # Predict where the atoms SHOULD be.
            x_pred = model(x, z, t_tensor, edge_index)

        # Euler integration: move 10% toward the prediction each step for stability.
        x = x + (x_pred - x) * 0.1

        if step % 10 == 0:
            print(f"Step {step}/{STEPS}: Denoising...")
            save_xyz(x, z, f"gen_step_{step:02d}.xyz")

    # 5. Final save.
    print(f"✅ Final Structure Generated!")
    save_xyz(x, z, "gen_final.xyz")
    print("Check 'gen_final.xyz' to see your crystal.")
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
generate()
|
model_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1507e932ff678dd90a144ff2b5e5f280870f528ee9769e8a9303d8a56f226099
|
| 3 |
+
size 355888
|
rdf_analysis.png
ADDED
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mp-api
|
| 2 |
+
python-dotenv
|
| 3 |
+
streamlit
|
| 4 |
+
py3Dmol
|
result_plot.png
ADDED
|
Git LFS Details
|
src/__pycache__/layers.cpython-312.pyc
ADDED
|
Binary file (3.25 kB). View file
|
|
|
src/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (3.81 kB). View file
|
|
|
src/data_loader.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from mp_api.client import MPRester
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# --- CONFIGURATION ---
|
| 9 |
+
load_dotenv() # Load .env file for API keys
|
| 10 |
+
API_KEY = os.getenv("MPI_API_KEY")
|
| 11 |
+
|
| 12 |
+
# Save path: project root `data/` folder
|
| 13 |
+
repo_root = Path(__file__).resolve().parents[1]
|
| 14 |
+
SAVE_PATH = repo_root / "data" / "perovskite_dataset.pt"
|
| 15 |
+
|
| 16 |
+
def fetch_data(limit=2000):
    """
    Fetches a large dataset of ABO3 Perovskites (5 atoms) for the Foundation Model.

    Downloads stable 5-site crystals from the Materials Project, keeps the
    ones whose formula contains 'O3', converts each to tensors, and saves
    the list to SAVE_PATH.
    """
    print(f"Connecting to Materials Project...")

    with MPRester(API_KEY) as mpr:
        # Query all stable materials with exactly 5 sites in the unit cell.
        # This implicitly targets ABO3 structures (1 + 1 + 3 = 5).
        docs = mpr.materials.summary.search(
            is_stable=True,
            nsites=5,
            fields=["structure", "material_id", "formula_pretty"]
        )

    print(f"Found {len(docs)} stable 5-atom crystals. Processing...")

    dataset = []

    for doc in tqdm(docs):
        if len(dataset) >= limit:
            break

        structure = doc.structure
        formula = doc.formula_pretty

        # Heuristic filter: perovskites usually carry 3 oxygens, so require
        # 'O3' in the formula to drop unrelated 5-atom compounds.
        if "O3" not in formula:
            continue

        # --- TENSOR CREATION ---

        # A. Atom identities: atomic numbers as a long tensor.
        z_tensor = torch.tensor(
            [site.specie.number for site in structure], dtype=torch.long
        )

        # B. Geometry: Cartesian coordinates as a float tensor.
        r_tensor = torch.tensor(
            [site.coords for site in structure], dtype=torch.float32
        )

        # C. Center-of-mass correction (CRITICAL for diffusion): shift the
        # crystal so its centroid sits at the origin, otherwise the model
        # wastes capacity learning absolute positions.
        r_tensor = r_tensor - torch.mean(r_tensor, dim=0, keepdim=True)

        dataset.append({
            "id": str(doc.material_id),
            "formula": formula,
            "z": z_tensor,   # features
            "pos": r_tensor  # positions (centered)
        })

    # Persist to disk, creating the data/ directory if needed.
    SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
    torch.save(dataset, SAVE_PATH)
    print(f"✅ Successfully saved {len(dataset)} crystals to {SAVE_PATH}")
    print(f" (Filtered for 5-atom unit cells containing 'O3')")
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
fetch_data(limit=2000)
|
src/layers.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class EGNNLayer(nn.Module):
    """
    E(n)-equivariant GNN layer.

    Updates invariant node features (h) and equivariant coordinates (x):
    feature updates depend only on rotation-invariant quantities (squared
    distances), while coordinate updates are built from difference vectors,
    so rotating the input rotates the output identically.
    """

    def __init__(self, c_in, c_out):
        super().__init__()

        # Edge network: message from [h_i, h_j, squared distance] -> c_out.
        self.edge_mlp = nn.Sequential(
            nn.Linear(c_in * 2 + 1, c_out),  # c_in*2 + 1: h_i, h_j, distance scalar
            nn.SiLU(),
            nn.Linear(c_out, c_out),
            nn.SiLU()
        )

        # Node network: new features from [h_i, aggregated message].
        self.node_mlp = nn.Sequential(
            nn.Linear(c_in + c_out, c_out),
            nn.SiLU(),
            nn.Linear(c_out, c_out)
        )

        # Coordinate network: one scalar weight per edge for the position
        # update; Tanh bounds each weight to (-1, 1) for stability.
        self.coord_mlp = nn.Sequential(
            nn.Linear(c_out, 1),
            nn.Tanh()
        )

    def forward(self, h, x, edge_index):
        """
        Args:
            h: node features, shape (N, c_in).
            x: coordinates, shape (N, 3).
            edge_index: adjacency list, shape (2, E); row 0 = source, row 1 = target.

        Returns:
            (h_new, x_new): updated features (N, c_out) and coordinates (N, 3).
        """
        src, dst = edge_index

        # Per-edge displacement vector and its rotation-invariant squared norm.
        delta = x[src] - x[dst]                              # (E, 3)
        dist_sq = (delta * delta).sum(dim=-1, keepdim=True)  # (E, 1)

        # Messages conditioned on both endpoint features and the distance.
        messages = self.edge_mlp(torch.cat([h[src], h[dst], dist_sq], dim=-1))

        # Equivariant coordinate update: weight each difference vector by a
        # learned scalar, then sum the contributions at the source node.
        weighted = delta * self.coord_mlp(messages)          # (E, 3)
        coord_index = src.unsqueeze(-1).expand(-1, x.size(-1))
        x_new = x + torch.zeros_like(x).scatter_add_(0, coord_index, weighted)

        # Aggregate messages at the source node (scatter_add_ keeps autograd).
        feat_index = src.unsqueeze(-1).expand(-1, messages.size(-1))
        agg = torch.zeros(h.shape[0], messages.shape[1], device=h.device)
        agg = agg.scatter_add_(0, feat_index, messages)

        # Invariant feature update from old features plus aggregated messages.
        h_new = self.node_mlp(torch.cat([h, agg], dim=-1))

        return h_new, x_new
|
src/model.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from src.layers import EGNNLayer
|
| 4 |
+
|
| 5 |
+
class TimeEmbedding(nn.Module):
    """
    Maps a scalar diffusion time 't' to a dense embedding vector.

    The embedding lets the network condition on the current noise level
    (time step) of the diffusion process.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.linear_1 = nn.Linear(1, dim)
        self.linear_2 = nn.Linear(dim, dim)
        self.act = nn.SiLU()  # SiLU is the conventional activation in diffusion models

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t: time scalars, shape (Batch_Size, 1).

        Returns:
            Time embeddings of shape (Batch_Size, dim).
        """
        # Two-layer MLP: (Batch_Size, 1) -> (Batch_Size, dim).
        hidden = self.act(self.linear_1(t))
        return self.linear_2(hidden)
|
| 30 |
+
|
| 31 |
+
class CrystalDiffusionModel(nn.Module):
    """
    E(n)-Equivariant Diffusion Model for Crystal Generation.

    Given noisy atomic positions, the atoms' identities, and the diffusion
    time, predicts the denoised coordinates.
    """

    def __init__(self, hidden_dim: int = 64, num_layers: int = 3, max_atom_type: int = 100):
        super().__init__()

        # Atomic number (e.g. 8 for Oxygen) -> dense feature vector.
        self.atom_embed = nn.Embedding(max_atom_type, hidden_dim)

        # Scalar noise level -> vector, so the model knows whether it is
        # looking at pure noise (t=1) or an almost-finished crystal (t=0).
        self.time_embed = TimeEmbedding(hidden_dim)

        # Equivariant message-passing backbone. Each layer refines both the
        # features (h) and the positions (x), so no separate output head is
        # needed for the coordinates.
        self.layers = nn.ModuleList(
            EGNNLayer(c_in=hidden_dim, c_out=hidden_dim) for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the diffusion model.

        Args:
            x (Tensor): Noisy atom positions. Shape (N, 3).
            z (Tensor): Atomic numbers. Shape (N,).
            t (Tensor): Time step/Noise level. Shape (Batch_Size, 1).
            edge_index (Tensor): Graph connectivity (adjacency list). Shape (2, E).

        Returns:
            Tensor: Denoised atom positions. Shape (N, 3).
        """
        # Embed atom identities and the noise level.
        node_feats = self.atom_embed(z)   # (N, hidden_dim)
        time_feats = self.time_embed(t)   # (Batch, hidden_dim)

        # Condition every atom on the (batch-averaged) time embedding.
        # NOTE(review): the mean over the batch dim assumes one structure per
        # forward pass -- confirm before adding true batching.
        node_feats = node_feats + time_feats.mean(dim=0, keepdim=True)

        # Message passing: each layer updates features and positions while
        # respecting rotational symmetry.
        for layer in self.layers:
            node_feats, x = layer(node_feats, x, edge_index)

        # The refined positions are the denoised prediction.
        return x
|
| 88 |
+
|
| 89 |
+
|
train.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
import random
|
| 4 |
+
import os
|
| 5 |
+
from src.model import CrystalDiffusionModel
|
| 6 |
+
|
| 7 |
+
# --- CONFIGURATION ---
|
| 8 |
+
# Check your data folder to ensure filename matches exactly
|
| 9 |
+
DATA_PATH = "data/perovskite_dataset.pt"
|
| 10 |
+
EPOCHS = 3000
|
| 11 |
+
LEARNING_RATE = 1e-3
|
| 12 |
+
|
| 13 |
+
def load_dataset():
    """Load the serialized crystal list from DATA_PATH, failing loudly if absent."""
    if os.path.exists(DATA_PATH):
        data = torch.load(DATA_PATH)
        print(f"✅ Loaded {len(data)} crystals for training.")
        return data

    # Fail with an explicit message rather than a cryptic torch error.
    raise FileNotFoundError(f"❌ Could not find dataset at {DATA_PATH}. Check spelling!")
|
| 20 |
+
|
| 21 |
+
def get_random_batch(dataset, device):
    """
    Draw one crystal uniformly at random from *dataset*.

    Sampling a different structure every step is what pushes the model to
    learn general rules instead of memorizing a single geometry.

    Returns (positions, atomic_numbers, edge_index), all moved to *device*.
    """
    sample = random.choice(dataset)

    atomic_numbers = sample["z"].to(device).long()
    positions = sample["pos"].to(device).float()

    # Fully-connected graph without self-loops, rebuilt per sample so that
    # crystals of different sizes are handled transparently. The complement
    # of the identity matrix is exactly the "all pairs except (i, i)" mask;
    # nonzero() enumerates those pairs in row-major order.
    n = atomic_numbers.shape[0]
    adjacency = ~torch.eye(n, dtype=torch.bool)
    edge_index = adjacency.nonzero().t().contiguous().to(device)

    return positions, atomic_numbers, edge_index
|
| 46 |
+
|
| 47 |
+
def train():
    """Fit the diffusion model on the perovskite dataset, then save weights.

    Each epoch: sample one crystal, corrupt its coordinates with Gaussian
    noise scaled by a random time t, ask the model to reconstruct the clean
    coordinates, and minimize the MSE between prediction and ground truth.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- 🚀 Training on {device} ---")

    dataset = load_dataset()

    model = CrystalDiffusionModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()

    print(f"--- Starting Training Loop ({EPOCHS} Epochs) ---")

    for epoch in range(1, EPOCHS + 1):
        optimizer.zero_grad()

        positions, species, edge_index = get_random_batch(dataset, device)

        # Forward diffusion: corrupt the clean structure with noise whose
        # magnitude is a random time t drawn from [0, 1).
        t = torch.rand(1, 1, device=device)
        corrupted = positions + (torch.randn_like(positions) * t)

        # Reverse step: the network predicts the denoised coordinates.
        denoised = model(corrupted, species, t, edge_index)

        # Mean squared error against the uncorrupted positions.
        loss = ((denoised - positions) ** 2).mean()

        loss.backward()
        optimizer.step()

        # Periodic progress log.
        if epoch % 200 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.6f}")

    # Persist the trained weights.
    torch.save(model.state_dict(), "model_weights.pth")
    print("✅ Training Complete. Model saved to model_weights.pth!")

if __name__ == "__main__":
    train()
|
validate.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def read_xyz(filename):
    """Parse an XYZ file into element symbols and coordinate vectors.

    Parameters
    ----------
    filename : str
        Path to an XYZ file (line 1: atom count, line 2: comment, then
        one ``Symbol x y z`` record per line).

    Returns
    -------
    tuple[list[str], list[np.ndarray]]
        Atom symbols and their 3D positions (one array per atom).
    """
    coords = []
    atoms = []
    with open(filename, 'r') as f:
        lines = f.readlines()
        # Skip the 2-line XYZ header. Skip blank or malformed lines too:
        # the previous version crashed with IndexError on a trailing blank
        # line because "".split() yields an empty list.
        for line in lines[2:]:
            parts = line.split()
            if len(parts) < 4:
                continue
            atoms.append(parts[0])
            coords.append(np.array([float(parts[1]), float(parts[2]), float(parts[3])]))
    return atoms, coords
|
| 13 |
+
|
| 14 |
+
def check_physics():
    """Sanity-check the generated crystal by measuring Ti-O bond lengths.

    Loads ``gen_final.xyz``, computes all Ti-O pair distances, and reports
    whether the shortest one falls in a physically plausible window.
    """
    print("--- 🧪 Scientific Validation ---")

    # Load the generated crystal. (Element mapping used elsewhere in the
    # project: 22 -> Ti, 8 -> O, 20 -> Ca.)
    atoms, coords = read_xyz("gen_final.xyz")

    ti_idx = [i for i, sym in enumerate(atoms) if sym == "Ti"]
    o_idx = [i for i, sym in enumerate(atoms) if sym == "O"]

    if not ti_idx or not o_idx:
        print("❌ Could not find Ti or O atoms to measure bonds.")
        return

    print(f"Found {len(ti_idx)} Titanium and {len(o_idx)} Oxygen atoms.")

    # Euclidean distance for every Ti-O pair.
    bond_lengths = [
        np.linalg.norm(coords[i] - coords[j])
        for i in ti_idx
        for j in o_idx
    ]

    min_bond = min(bond_lengths)
    avg_bond = sum(bond_lengths) / len(bond_lengths)

    print(f"\nMeasured Bond Lengths (Ti - O):")
    print(f"  Minimum: {min_bond:.4f} Å")
    print(f"  Average: {avg_bond:.4f} Å")

    # Real Ti-O bonds are typically 1.90-2.05 Å; a generous window is used
    # since this is a tiny model trained only briefly.
    if 1.5 < min_bond < 2.5:
        print("\n✅ SUCCESS: The model learned valid chemical bonds!")
        print("  (Target range: ~1.9 Å)")
    else:
        print("\n⚠️ WARNING: Bonds are physically unrealistic.")
        print("  (Try training for more epochs or checking the dataset)")

if __name__ == "__main__":
    check_physics()
|
visualize.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def read_xyz(filename):
    """
    Read an XYZ file and return coordinates and atom types.

    Parameters
    ----------
    filename : str
        Path to an XYZ file (line 1: atom count, line 2: comment, then
        one ``Symbol x y z`` record per line).

    Returns
    -------
    tuple[np.ndarray, list[str]]
        ``(coords, atoms)`` where ``coords`` has shape (n_atoms, 3).
    """
    coords = []
    atoms = []
    with open(filename, 'r') as f:
        lines = f.readlines()
        # Skip header lines (first 2). Also skip blank/malformed lines:
        # the previous version raised IndexError on a trailing blank line
        # because "".split() yields an empty list.
        for line in lines[2:]:
            parts = line.split()
            if len(parts) < 4:
                continue
            atoms.append(parts[0])
            coords.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return np.array(coords), atoms
|
| 19 |
+
|
| 20 |
+
def plot_crystal(ax, coords, atoms, title):
    """
    Render one crystal into the given 3D axes: atoms as colored spheres,
    with bond lines drawn between sufficiently close pairs.
    """
    # Element color palette; unrecognized symbols fall back to blue.
    palette = {'Ti': 'gray', 'O': 'red', 'Ca': 'green', 'Pb': 'black', 'I': 'purple'}

    # Atoms: large outlined, slightly transparent markers.
    for idx, symbol in enumerate(atoms):
        ax.scatter(coords[idx, 0], coords[idx, 1], coords[idx, 2],
                   c=palette.get(symbol, 'blue'), s=200, edgecolors='k', alpha=0.8)

    # "Bonds": connect every pair of atoms closer than 2.8 Å with a thin
    # line so the structure is easier to read.
    n = len(coords)
    for a in range(n):
        for b in range(a + 1, n):
            if np.linalg.norm(coords[a] - coords[b]) < 2.8:
                ax.plot([coords[a, 0], coords[b, 0]],
                        [coords[a, 1], coords[b, 1]],
                        [coords[a, 2], coords[b, 2]],
                        c='black', linewidth=1, alpha=0.5)

    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # Identical limits on every panel so side-by-side plots are comparable.
    ax.set_xlim(-2, 5)
    ax.set_ylim(-2, 5)
    ax.set_zlim(-2, 5)
|
| 56 |
+
|
| 57 |
+
def create_comparison_figure():
    """Plot initial noise next to the final generated crystal and save it."""
    # Both XYZ snapshots are produced by generate.py — run it first.
    noise_coords, atom_symbols = read_xyz("gen_step_00.xyz")
    final_coords, _ = read_xyz("gen_final.xyz")

    fig = plt.figure(figsize=(12, 6))

    # Left panel: the starting random noise.
    left = fig.add_subplot(121, projection='3d')
    plot_crystal(left, noise_coords, atom_symbols, "Step 0: Random Noise")

    # Right panel: the denoised (generated) structure.
    right = fig.add_subplot(122, projection='3d')
    plot_crystal(right, final_coords, atom_symbols, "Step 50: Generated Crystal")

    plt.tight_layout()
    plt.savefig("result_plot.png", dpi=300)
    print("Saved comparison figure to 'result_plot.png'")
    plt.show()

if __name__ == "__main__":
    create_comparison_figure()
|