adi2075 commited on
Commit
4e4ff14
·
verified ·
1 Parent(s): 5906fbf

Upload 16 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ result_plot.png filter=lfs diff=lfs merge=lfs -text
analysis.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from src.model import CrystalDiffusionModel
5
+
6
+ # Load your model and generate a crystal
7
+ def generate_crystal():
8
+ model = CrystalDiffusionModel()
9
+ model.load_state_dict(torch.load("model_weights.pth", map_location='cpu'))
10
+ model.eval()
11
+
12
+ # Generate 5 atoms
13
+ num_atoms = 5
14
+ z = torch.tensor([56, 22, 8, 8, 8]) # BaTiO3
15
+
16
+ # Graph Setup
17
+ row = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
18
+ col = torch.arange(num_atoms).repeat(num_atoms)
19
+ mask = row != col
20
+ edge_index = torch.stack([row[mask], col[mask]], dim=0)
21
+
22
+ # Diffusion
23
+ x = torch.randn(num_atoms, 3) # Start with noise
24
+ steps = 50
25
+ dt = 1.0 / steps
26
+
27
+ for i in range(steps):
28
+ t = torch.tensor([[1.0 - i*dt]])
29
+ with torch.no_grad():
30
+ pred = model(x, z, t, edge_index)
31
+ x = x + (pred - x) * 0.1
32
+
33
+ return x.numpy()
34
+
35
+ def compute_rdf(coords, box_size=5.0, bins=50):
36
+ """
37
+ Calculates the Radial Distribution Function (RDF).
38
+ """
39
+ distances = []
40
+ num_atoms = len(coords)
41
+
42
+ for i in range(num_atoms):
43
+ for j in range(i + 1, num_atoms):
44
+ dist = np.linalg.norm(coords[i] - coords[j])
45
+ distances.append(dist)
46
+
47
+ # Histogram
48
+ hist, bin_edges = np.histogram(distances, bins=bins, range=(0, box_size))
49
+ r = (bin_edges[:-1] + bin_edges[1:]) / 2
50
+
51
+ # Normalize (Volume correction)
52
+ dr = bin_edges[1] - bin_edges[0]
53
+ volume = 4 * np.pi * r**2 * dr
54
+ rdf = hist / (volume * num_atoms) # Density normalization
55
+
56
+ return r, rdf
57
+
58
+ def plot_comparison():
59
+ print("Generating Analysis Plot...")
60
+
61
+ # 1. Get RDF for Random Noise
62
+ noise = np.random.randn(5, 3)
63
+ r_noise, rdf_noise = compute_rdf(noise)
64
+
65
+ # 2. Get RDF for Generated Crystal
66
+ crystal = generate_crystal()
67
+ r_crys, rdf_crys = compute_rdf(crystal)
68
+
69
+ # 3. Plot
70
+ plt.figure(figsize=(10, 6))
71
+ plt.plot(r_noise, rdf_noise, label='Random Noise', linestyle='--', color='gray')
72
+ plt.plot(r_crys, rdf_crys, label='Generated Crystal (AI)', linewidth=3, color='blue')
73
+
74
+ plt.title("Radial Distribution Function (RDF) Analysis")
75
+ plt.xlabel("Distance (Angstroms)")
76
+ plt.ylabel("Probability Density")
77
+ plt.legend()
78
+ plt.grid(True, alpha=0.3)
79
+
80
+ plt.savefig("rdf_analysis.png")
81
+ print("✅ Saved 'rdf_analysis.png'. Put this in your README!")
82
+
83
+ if __name__ == "__main__":
84
+ plot_comparison()
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import py3Dmol
5
+ from stmol import showmol
6
+ from src.model import CrystalDiffusionModel
7
+
8
+ # --- PAGE CONFIGURATION ---
9
+ st.set_page_config(
10
+ page_title="CrystalDiff: AI Material Designer",
11
+ layout="wide",
12
+ page_icon="💎",
13
+ initial_sidebar_state="expanded"
14
+ )
15
+
16
+ # --- SIDEBAR: CONTROLS & INFO ---
17
+ with st.sidebar:
18
+ st.title("💎 CrystalDiff Controls")
19
+
20
+ st.markdown("### 1. Select Chemistry")
21
+ target_atom = st.selectbox(
22
+ "Choose A-Site Cation",
23
+ ["Ca (Calcium)", "Sr (Strontium)", "Ba (Barium)", "Pb (Lead)"],
24
+ index=1,
25
+ help="The large atom in the center of the cage."
26
+ )
27
+
28
+ st.markdown("### 2. Diffusion Settings")
29
+ steps = st.slider("Denoising Steps", 10, 100, 50, help="More steps = higher quality, but slower.")
30
+ noise_scale = st.slider("Initial Chaos (Noise)", 0.5, 2.0, 1.0, help="Higher noise means the AI has to be more creative.")
31
+
32
+ st.divider()
33
+
34
+ st.markdown("### 🧠 How it Works")
35
+ st.info("""
36
+ **Generative Diffusion:**
37
+ The model starts with random noise (chaos) and iteratively subtracts noise to find a stable crystal structure.
38
+
39
+ **E(n)-Equivariance:**
40
+ The AI uses a custom Graph Neural Network that respects the laws of physics (rotational symmetry).
41
+ """)
42
+
43
+ st.markdown("---")
44
+ st.caption("Built with PyTorch & Streamlit by Aditya Mangal. Inspired by DeepMind's work on generative models for materials science.")
45
+
46
+ # --- MAIN PAGE ---
47
+ st.title("💎 CrystalDiff: Generative Material Design")
48
+ st.markdown("""
49
+ This application uses **Geometric Deep Learning** to hallucinate new stable crystals.
50
+ It was trained on the **Materials Project** database to understand the chemical rules of **Perovskite Oxides ($ABO_3$)**.
51
+ """)
52
+
53
+ # Map selection to Atomic Number
54
+ atom_map = {
55
+ "Ca (Calcium)": 20, "Sr (Strontium)": 38,
56
+ "Ba (Barium)": 56, "Pb (Lead)": 82
57
+ }
58
+ selected_z = atom_map[target_atom]
59
+ formula_display = f"{target_atom.split()[0]}TiO₃"
60
+
61
+ # --- HELPER FUNCTIONS ---
62
+ @st.cache_resource
63
+ def load_model():
64
+ device = torch.device("cpu")
65
+ model = CrystalDiffusionModel()
66
+ try:
67
+ model.load_state_dict(torch.load("model_weights.pth", map_location=device))
68
+ model.eval()
69
+ return model, device
70
+ except FileNotFoundError:
71
+ return None, None
72
+
73
+ def calculate_metrics(pos, z):
74
+ """Calculates bond lengths to validate physics."""
75
+ # Find Ti (22) and O (8)
76
+ ti_idx = [i for i, atom in enumerate(z) if atom == 22]
77
+ o_idx = [i for i, atom in enumerate(z) if atom == 8]
78
+
79
+ if not ti_idx or not o_idx: return 0.0
80
+
81
+ ti_pos = pos[ti_idx[0]]
82
+ dists = []
83
+ for o in o_idx:
84
+ d = np.linalg.norm(ti_pos - pos[o])
85
+ dists.append(d)
86
+
87
+ return np.mean(dists)
88
+
89
+ def make_view(pos, z):
90
+ """Creates a 3D molecule view"""
91
+ view = py3Dmol.view(width=800, height=500)
92
+ xyz_str = f"{len(pos)}\nGenerated\n"
93
+ for i in range(len(pos)):
94
+ elem = "O" if z[i] == 8 else "Ti" if z[i] == 22 else target_atom.split()[0]
95
+ xyz_str += f"{elem} {pos[i,0]:.4f} {pos[i,1]:.4f} {pos[i,2]:.4f}\n"
96
+ view.addModel(xyz_str, "xyz")
97
+ # Style: spheres for atoms, sticks for bonds
98
+ view.setStyle({'sphere': {'scale': 0.25}, 'stick': {'radius': 0.1}})
99
+ view.zoomTo()
100
+ return view
101
+
102
+ # --- APP LOGIC ---
103
+ model, device = load_model()
104
+
105
+ if model is None:
106
+ st.error(" Model weights not found! Please run 'train.py' first.")
107
+ st.stop()
108
+
109
+ # Layout: Two columns
110
+ col1, col2 = st.columns([1, 2])
111
+
112
+ with col1:
113
+ st.subheader("🧪 Experiment Setup")
114
+ st.write(f"**Target Material:** {formula_display}")
115
+ st.write(f"**Structure Family:** Cubic Perovskite")
116
+
117
+ if st.button("✨ Generate Crystal", type="primary", use_container_width=True):
118
+
119
+ # 1. Setup Data
120
+ z = torch.tensor([selected_z, 22, 8, 8, 8], device=device) # A-Site, Ti, O, O, O
121
+ num_atoms = 5
122
+
123
+ # Graph connections
124
+ row = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
125
+ col = torch.arange(num_atoms).repeat(num_atoms)
126
+ mask = row != col
127
+ edge_index = torch.stack([row[mask], col[mask]], dim=0).to(device)
128
+
129
+ # 2. Diffusion Loop
130
+ x = torch.randn(num_atoms, 3, device=device) * noise_scale
131
+
132
+ progress_bar = st.progress(0)
133
+ status = st.empty()
134
+
135
+ dt = 1.0 / steps
136
+ for i in range(steps):
137
+ t_val = 1.0 - (i * dt)
138
+ t_tensor = torch.tensor([[t_val]], device=device)
139
+
140
+ with torch.no_grad():
141
+ x_pred = model(x, z, t_tensor, edge_index)
142
+
143
+ # Euler update
144
+ x = x + (x_pred - x) * 0.1
145
+
146
+ if i % 5 == 0:
147
+ progress_bar.progress(i / steps)
148
+ status.text(f"Denoising... Step {i}/{steps}")
149
+
150
+ progress_bar.progress(1.0)
151
+ status.success("Done!")
152
+
153
+ # 3. Store result in session state to keep it on screen
154
+ st.session_state['generated_pos'] = x.numpy()
155
+ st.session_state['generated_z'] = z.numpy()
156
+
157
+ with col2:
158
+ st.subheader("⚛️ 3D Visualization")
159
+
160
+ if 'generated_pos' in st.session_state:
161
+ pos = st.session_state['generated_pos']
162
+ z = st.session_state['generated_z']
163
+
164
+ # Calculate Physics
165
+ avg_bond = calculate_metrics(pos, z)
166
+
167
+ # Display Metrics
168
+ m1, m2 = st.columns(2)
169
+ m1.metric("Avg Ti-O Bond Length", f"{avg_bond:.3f} Å")
170
+
171
+ # Validation Logic
172
+ if 1.8 < avg_bond < 2.2:
173
+ m2.success("✅ Physically Valid")
174
+ else:
175
+ m2.warning("⚠️ Unstable Structure")
176
+
177
+ # Render 3D
178
+ view = make_view(pos, z)
179
+ showmol(view, height=500, width=800)
180
+
181
+ else:
182
+ st.info("👈 Select your chemistry on the left and click 'Generate Crystal' to start the AI.")
183
+ st.markdown("""
184
+ <div style="text-align: center; padding: 50px; border: 2px dashed #444; border-radius: 10px; margin-top: 20px;">
185
+ <h1 style="color: #666;">🧊</h1>
186
+ <p style="color: #888;">Waiting for generation...</p>
187
+ </div>
188
+ """, unsafe_allow_html=True)
data/perovskite_dataset.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a29272d987115656b20c250170a62efa7d76142ad58a8d85da6ceccedd797db
3
+ size 31765
generate.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from src.model import CrystalDiffusionModel
4
+
5
+ # --- CONFIGURATION ---
6
+ # Select the atoms you want to generate!
7
+ # Example: BaTiO3 (Barium Titanate) -> Ba=56, Ti=22, O=8
8
+ # Example: SrTiO3 (Strontium Titanate) -> Sr=38, Ti=22, O=8
9
+ # Example: CaTiO3 (Calcium Titanate) -> Ca=20, Ti=22, O=8
10
+
11
+ # Let's try generating Strontium Titanate (SrTiO3) this time
12
+ TARGET_ATOMS = [38, 22, 8, 8, 8] # Sr, Ti, O, O, O
13
+ MODEL_PATH = "model_weights.pth"
14
+ STEPS = 50 # Number of diffusion steps
15
+
16
+ def save_xyz(pos, z, filename):
17
+ """
18
+ Saves the crystal in XYZ format for visualization.
19
+ """
20
+ with open(filename, "w") as f:
21
+ f.write(f"{len(pos)}\n")
22
+ f.write("Generated by CrystalDiff\n")
23
+ for i in range(len(pos)):
24
+ # Simple periodic table lookup for common perovskite elements
25
+ # You can add more if you generate other materials
26
+ elem_map = {
27
+ 8: "O", 22: "Ti", 20: "Ca",
28
+ 56: "Ba", 38: "Sr", 82: "Pb",
29
+ 26: "Fe", 40: "Zr"
30
+ }
31
+ atom_symbol = elem_map.get(int(z[i]), "X") # Default to X if unknown
32
+ f.write(f"{atom_symbol} {pos[i,0]:.4f} {pos[i,1]:.4f} {pos[i,2]:.4f}\n")
33
+
34
+ def generate():
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ print(f"--- 💎 Generating Crystal on {device} ---")
37
+
38
+ # 1. Load Model
39
+ model = CrystalDiffusionModel().to(device)
40
+ try:
41
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
42
+ except FileNotFoundError:
43
+ print(f"❌ Error: Could not find '{MODEL_PATH}'. Did you run train.py?")
44
+ return
45
+
46
+ model.eval()
47
+
48
+ # 2. Setup Target Chemistry
49
+ z = torch.tensor(TARGET_ATOMS, device=device)
50
+ num_atoms = len(z)
51
+
52
+ print(f"Target Atoms: {z.tolist()}")
53
+
54
+ # Create fully connected graph
55
+ row = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
56
+ col = torch.arange(num_atoms).repeat(num_atoms)
57
+ mask = row != col
58
+ edge_index = torch.stack([row[mask], col[mask]], dim=0).to(device)
59
+
60
+ # 3. Start with Pure Noise (The "Chaos")
61
+ # We use a noise scale of 1.0 to match training
62
+ x = torch.randn(num_atoms, 3, device=device)
63
+
64
+ print(f"Initial State: Random Gas Cloud")
65
+ save_xyz(x, z, "gen_step_00.xyz")
66
+
67
+ # 4. The Reverse Diffusion Loop
68
+ dt = 1.0 / STEPS
69
+
70
+ for i in range(STEPS):
71
+ # Time goes from 1.0 -> 0.0
72
+ t_val = 1.0 - (i * dt)
73
+ t_tensor = torch.tensor([[t_val]], device=device)
74
+
75
+ with torch.no_grad():
76
+ # Predict where the atoms SHOULD be
77
+ x_pred = model(x, z, t_tensor, edge_index)
78
+
79
+ # Update Position (Euler Integration)
80
+ # We move 10% towards the prediction at each step for stability
81
+ x = x + (x_pred - x) * 0.1
82
+
83
+ if i % 10 == 0:
84
+ print(f"Step {i}/{STEPS}: Denoising...")
85
+ save_xyz(x, z, f"gen_step_{i:02d}.xyz")
86
+
87
+ # 5. Final Save
88
+ print(f"✅ Final Structure Generated!")
89
+ save_xyz(x, z, "gen_final.xyz")
90
+ print("Check 'gen_final.xyz' to see your crystal.")
91
+
92
+ if __name__ == "__main__":
93
+ generate()
model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1507e932ff678dd90a144ff2b5e5f280870f528ee9769e8a9303d8a56f226099
3
+ size 355888
rdf_analysis.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ mp-api
2
+ python-dotenv
3
+ streamlit
4
+ py3Dmol
result_plot.png ADDED

Git LFS Details

  • SHA256: 0ab1f0d27d26aebee590e099bf021c08bd86887432cd450d2ff8bf9de3f64463
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
src/__pycache__/layers.cpython-312.pyc ADDED
Binary file (3.25 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (3.81 kB). View file
 
src/data_loader.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from mp_api.client import MPRester
3
+ import os
4
+ from pathlib import Path
5
+ from dotenv import load_dotenv
6
+ from tqdm import tqdm
7
+
8
+ # --- CONFIGURATION ---
9
+ load_dotenv() # Load .env file for API keys
10
+ API_KEY = os.getenv("MPI_API_KEY")
11
+
12
+ # Save path: project root `data/` folder
13
+ repo_root = Path(__file__).resolve().parents[1]
14
+ SAVE_PATH = repo_root / "data" / "perovskite_dataset.pt"
15
+
16
+ def fetch_data(limit=2000):
17
+ """
18
+ Fetches a large dataset of ABO3 Perovskites (5 atoms) for the Foundation Model.
19
+ """
20
+ print(f"Connecting to Materials Project...")
21
+
22
+ with MPRester(API_KEY) as mpr:
23
+ # 1. Broad Search: Get all stable materials with 5 atoms
24
+ # We search for materials with exactly 5 sites (atoms) in the unit cell.
25
+ # This implicitly targets ABO3 structures (1+1+3 = 5).
26
+ docs = mpr.materials.summary.search(
27
+ is_stable=True,
28
+ nsites=5,
29
+ fields=["structure", "material_id", "formula_pretty"]
30
+ )
31
+
32
+ print(f"Found {len(docs)} stable 5-atom crystals. Processing...")
33
+
34
+ dataset = []
35
+
36
+ # 2. Filter and Process
37
+ # We want oxygen-containing perovskites generally, but let's keep it broad for now.
38
+ # The 'nsites=5' filter does most of the heavy lifting.
39
+
40
+ count = 0
41
+ for doc in tqdm(docs):
42
+ if count >= limit:
43
+ break
44
+
45
+ structure = doc.structure
46
+ formula = doc.formula_pretty
47
+
48
+ # Heuristic check: Perovskites usually have 3 Oxygens.
49
+ # This filters out random 5-atom things that aren't Perovskites.
50
+ # (Optional but recommended for cleaner data)
51
+ if "O3" not in formula:
52
+ continue
53
+
54
+ # --- TENSOR CREATION ---
55
+
56
+ # A. Atomic Numbers (Integers) -> The "Identity"
57
+ atomic_numbers = [site.specie.number for site in structure]
58
+ z_tensor = torch.tensor(atomic_numbers, dtype=torch.long)
59
+
60
+ # B. Coordinates (Floats) -> The "Geometry"
61
+ coords = [site.coords for site in structure]
62
+ r_tensor = torch.tensor(coords, dtype=torch.float32)
63
+
64
+ # C. Center of Mass Correction (CRITICAL for Diffusion)
65
+ # We shift the crystal so its center is at (0,0,0).
66
+ # If we don't do this, the model wastes time learning absolute positions.
67
+ r_tensor = r_tensor - torch.mean(r_tensor, dim=0, keepdim=True)
68
+
69
+ # Create Data Object
70
+ data_point = {
71
+ "id": str(doc.material_id),
72
+ "formula": formula,
73
+ "z": z_tensor, # Features
74
+ "pos": r_tensor # Positions (Centered)
75
+ }
76
+
77
+ dataset.append(data_point)
78
+ count += 1
79
+
80
+ # 3. Save to Disk
81
+ # Ensure directory exists
82
+ SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
83
+
84
+ torch.save(dataset, SAVE_PATH)
85
+ print(f"✅ Successfully saved {len(dataset)} crystals to {SAVE_PATH}")
86
+ print(f" (Filtered for 5-atom unit cells containing 'O3')")
87
+
88
+ if __name__ == "__main__":
89
+ fetch_data(limit=2000)
src/layers.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class EGNNLayer(nn.Module):
5
+ """
6
+ Equivariant GNN.
7
+ Update node features (h) and coordinates (x) while respecting rotation.
8
+ """
9
+
10
+ def __init__(self, c_in, c_out):
11
+ super().__init__()
12
+
13
+ # Edge MLP: Compute message based on features and distance
14
+ # input: h_i + h_j + distance(1)
15
+ self.edge_mlp = nn.Sequential(
16
+ nn.Linear(c_in * 2 + 1, c_out), # c_in * 2 + 1 means we concatenate h_i, h_j and the distance scalar
17
+ nn.SiLU(),
18
+ nn.Linear(c_out, c_out),
19
+ nn.SiLU()
20
+ )
21
+ # Node MLP: Update atom features
22
+ # Input: h_i + aggregated_message
23
+ self.node_mlp = nn.Sequential(
24
+ nn.Linear(c_in + c_out, c_out),
25
+ nn.SiLU(),
26
+ nn.Linear(c_out, c_out)
27
+ )
28
+
29
+ # Coord MLP: Update position (x)
30
+ # Input : message (c_out)
31
+ self.coord_mlp = nn.Sequential(
32
+ nn.Linear(c_out, 1), # output a single scalar 'weight' for the coordinate update
33
+ nn.Tanh() # keeps updates stable (-1 to 1)
34
+ )
35
+
36
+ def forward(self, h, x, edge_index):
37
+ """
38
+ h: Node features (N, c_in)
39
+ x: Coordinates (N, 3)
40
+ edge_index: Adjacency list (2, E) where E is number of edges-> who connects to whom
41
+ """
42
+
43
+ row, col = edge_index # row = source, col = target
44
+
45
+ # setp 1 : calculate distance
46
+ # get coordinates of source and target nodes
47
+ x_i = x[row] # (E, 3)
48
+ x_j = x[col] # (E, 3),
49
+ # example: if edge_index has [0, 1] in row and [2, 3] in col,
50
+ # then x_i will have coordinates of nodes 0 and 1, while x_j will have coordinates of nodes 2 and 3
51
+
52
+ # calculate squared distance (rotation invariant)
53
+ dist_sq = torch.sum((x_i - x_j)**2, dim=-1, keepdim=True)
54
+ # sum(-1) means we sum over the coordinate dimension, resulting in a scalar distance for each edge. keepdim=True keeps the output shape as (E, 1)
55
+
56
+ # step 2 : calculate edge messages
57
+ # Concatenate: Feature_i, Feature_j, Distance
58
+ # h[row] for source node features, h[col] for target node features
59
+ edge_input = torch.cat([h[row], h[col], dist_sq], dim=-1) # (E, c_in*2 + 1)
60
+
61
+ # pass through edge MLP to get messages
62
+ m_ij = self.edge_mlp(edge_input)
63
+
64
+ # step 3 : Update coordinates ( Equivariant part)
65
+ # predict a weight for vector (x_i - x_j) based on the message
66
+ coord_weight = self.coord_mlp(m_ij)
67
+
68
+ # Update x_new = x + sum((x_i - x_j) * weight), transform
69
+ trans = (x_i - x_j) * coord_weight
70
+
71
+ # Aggregate coordinate updates using scatter_add_ (preserves autograd)
72
+ idx_exp = row.unsqueeze(-1).expand(-1, x.size(-1)) # (E, 3)
73
+ x_agg = torch.zeros_like(x)
74
+ x_agg = x_agg.scatter_add_(0, idx_exp, trans)
75
+
76
+ x_new = x + x_agg
77
+
78
+ # step 4 : Update node features
79
+ m_idx_exp = row.unsqueeze(-1).expand(-1, m_ij.size(-1))
80
+ m_agg = torch.zeros(h.shape[0], m_ij.shape[1], device=h.device)
81
+ m_agg = m_agg.scatter_add_(0, m_idx_exp, m_ij)
82
+
83
+ # Combine old features with new message
84
+
85
+ h_input = torch.cat([h, m_agg], dim=-1)
86
+ h_new = self.node_mlp(h_input)
87
+
88
+ return h_new, x_new
src/model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.layers import EGNNLayer
4
+
5
+ class TimeEmbedding(nn.Module):
6
+ """
7
+ Converts a time scalar 't' into a vector embedding.
8
+ This allows the neural network to understand the noise level (time step).
9
+ """
10
+
11
+ def __init__(self, dim: int):
12
+ super().__init__()
13
+ self.dim = dim
14
+ self.linear_1 = nn.Linear(1, dim)
15
+ self.linear_2 = nn.Linear(dim, dim)
16
+ self.act = nn.SiLU() # SiLU is standard for diffusion models
17
+
18
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
19
+ """
20
+ Args:
21
+ t (Tensor): Time scalars of shape (Batch_Size, 1).
22
+
23
+ Returns:
24
+ Tensor: Time embeddings of shape (Batch_Size, dim).
25
+ """
26
+ # t is shape (Batch_Size, 1) -> we want to output (Batch_Size, dim)
27
+ x = self.act(self.linear_1(t))
28
+ x = self.linear_2(x)
29
+ return x
30
+
31
+ class CrystalDiffusionModel(nn.Module):
32
+ """
33
+ E(n)-Equivariant Diffusion Model for Crystal Generation.
34
+ Predicts the denoised coordinates given a noisy input.
35
+ """
36
+
37
+ def __init__(self, hidden_dim: int = 64, num_layers: int = 3, max_atom_type: int = 100):
38
+ super().__init__()
39
+
40
+ # 1. Atom Embedding: Integer -> Vector
41
+ # Maps atomic numbers (e.g., 8 for Oxygen) to a dense vector
42
+ self.atom_embed = nn.Embedding(max_atom_type, hidden_dim)
43
+
44
+ # 2. Time Embedding: Scalar -> Vector
45
+ # Helps the model know if it's looking at pure noise (t=1) or a crystal (t=0)
46
+ self.time_embed = TimeEmbedding(hidden_dim)
47
+
48
+ # 3. Backbone: Stack of Equivariant GNN layers
49
+ # These update both features (h) and positions (x)
50
+ self.layers = nn.ModuleList([
51
+ EGNNLayer(c_in=hidden_dim, c_out=hidden_dim)
52
+ for _ in range(num_layers)
53
+ ])
54
+
55
+ # Note: We don't need a final linear layer for positions because
56
+ # the EGNN layers update the coordinates 'x' directly at every step.
57
+
58
+ def forward(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
59
+ """
60
+ Forward pass of the diffusion model.
61
+
62
+ Args:
63
+ x (Tensor): Noisy atom positions. Shape (N, 3).
64
+ z (Tensor): Atomic numbers. Shape (N,).
65
+ t (Tensor): Time step/Noise level. Shape (Batch_Size, 1).
66
+ edge_index (Tensor): Graph connectivity (Adjacency list). Shape (2, E).
67
+
68
+ Returns:
69
+ Tensor: Denoised atom positions. Shape (N, 3).
70
+ """
71
+
72
+ # 1. Embed Inputs
73
+ h = self.atom_embed(z) # (N, hidden_dim)
74
+ t_emb = self.time_embed(t) # (Batch, hidden_dim)
75
+
76
+ # 2. Condition on Time
77
+ # Broadcast time embedding to all atoms in the batch
78
+ # (Assuming single batch or handled externally for simplicity)
79
+ h = h + t_emb.mean(dim=0, keepdim=True)
80
+
81
+ # 3. Message Passing (The "Brain")
82
+ for layer in self.layers:
83
+ # Update features (h) and positions (x) respecting symmetry
84
+ h, x = layer(h, x, edge_index)
85
+
86
+ # Return the updated (denoised) positions
87
+ return x
88
+
89
+
train.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import random
4
+ import os
5
+ from src.model import CrystalDiffusionModel
6
+
7
+ # --- CONFIGURATION ---
8
+ # Check your data folder to ensure filename matches exactly
9
+ DATA_PATH = "data/perovskite_dataset.pt"
10
+ EPOCHS = 3000
11
+ LEARNING_RATE = 1e-3
12
+
13
+ def load_dataset():
14
+ if not os.path.exists(DATA_PATH):
15
+ raise FileNotFoundError(f"❌ Could not find dataset at {DATA_PATH}. Check spelling!")
16
+
17
+ data = torch.load(DATA_PATH)
18
+ print(f"✅ Loaded {len(data)} crystals for training.")
19
+ return data
20
+
21
+ def get_random_batch(dataset, device):
22
+ """
23
+ Picks a RANDOM crystal from the dataset.
24
+ This is crucial for generalization (learning rules vs memorizing one shape).
25
+ """
26
+ # 1. Pick random sample
27
+ sample = random.choice(dataset)
28
+
29
+ # 2. Extract Data
30
+ z = sample["z"].to(device).long()
31
+ x_real = sample["pos"].to(device).float()
32
+
33
+ # 3. Build Graph (Fully Connected)
34
+ # We build this dynamically in case crystals have different sizes
35
+ num_atoms = z.shape[0]
36
+
37
+ # Create all pairs (0,0), (0,1)... (N,N)
38
+ row = torch.repeat_interleave(torch.arange(num_atoms), num_atoms)
39
+ col = torch.arange(num_atoms).repeat(num_atoms)
40
+
41
+ # Remove self-loops (atoms don't connect to themselves)
42
+ mask = row != col
43
+ edge_index = torch.stack([row[mask], col[mask]], dim=0).to(device)
44
+
45
+ return x_real, z, edge_index
46
+
47
+ def train():
48
+ # 1. Setup
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ print(f"--- 🚀 Training on {device} ---")
51
+
52
+ # Load Data
53
+ dataset = load_dataset()
54
+
55
+ # Initialize Model
56
+ model = CrystalDiffusionModel().to(device)
57
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
58
+ model.train()
59
+
60
+ print(f"--- Starting Training Loop ({EPOCHS} Epochs) ---")
61
+
62
+ for epoch in range(1, EPOCHS + 1):
63
+ optimizer.zero_grad()
64
+
65
+ # 2. Get Random Batch
66
+ x_real, z, edge_index = get_random_batch(dataset, device)
67
+
68
+ # 3. Diffusion Step (Forward)
69
+ # Sample random time 't' (how much noise to add)
70
+ t = torch.rand(1, 1, device=device)
71
+
72
+ # Create Noise
73
+ noise = torch.randn_like(x_real)
74
+
75
+ # Add noise: x_noisy = Real + (Noise * t)
76
+ x_noisy = x_real + (noise * t)
77
+
78
+ # 4. Model Prediction (Reverse)
79
+ # Predict the denoised structure
80
+ x_pred = model(x_noisy, z, t, edge_index)
81
+
82
+ # 5. Calculate Loss
83
+ # We want the predicted position to match the real position
84
+ loss = torch.mean((x_pred - x_real)**2)
85
+
86
+ loss.backward()
87
+ optimizer.step()
88
+
89
+ # Log progress
90
+ if epoch % 200 == 0:
91
+ print(f"Epoch {epoch} | Loss: {loss.item():.6f}")
92
+
93
+ # Save the smarter model
94
+ torch.save(model.state_dict(), "model_weights.pth")
95
+ print("✅ Training Complete. Model saved to model_weights.pth!")
96
+
97
+ if __name__ == "__main__":
98
+ train()
validate.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def read_xyz(filename):
4
+ coords = []
5
+ atoms = []
6
+ with open(filename, 'r') as f:
7
+ lines = f.readlines()
8
+ for line in lines[2:]:
9
+ parts = line.split()
10
+ atoms.append(parts[0])
11
+ coords.append(np.array([float(parts[1]), float(parts[2]), float(parts[3])]))
12
+ return atoms, coords
13
+
14
+ def check_physics():
15
+ print("--- 🧪 Scientific Validation ---")
16
+
17
+ # 1. Load the Generated Crystal
18
+ atoms, coords = read_xyz("gen_final.xyz")
19
+
20
+ # 2. Find the Titanium (Ti) and Oxygens (O)
21
+ # Note: In our code, we mapped:
22
+ # 22 -> Ti (Titanium)
23
+ # 8 -> O (Oxygen)
24
+ # 20 -> Ca (Calcium)
25
+
26
+ ti_indices = [i for i, atom in enumerate(atoms) if atom == "Ti"]
27
+ o_indices = [i for i, atom in enumerate(atoms) if atom == "O"]
28
+
29
+ if not ti_indices or not o_indices:
30
+ print("❌ Could not find Ti or O atoms to measure bonds.")
31
+ return
32
+
33
+ print(f"Found {len(ti_indices)} Titanium and {len(o_indices)} Oxygen atoms.")
34
+
35
+ # 3. Measure Distances
36
+ bond_lengths = []
37
+
38
+ for ti_idx in ti_indices:
39
+ ti_pos = coords[ti_idx]
40
+ for o_idx in o_indices:
41
+ o_pos = coords[o_idx]
42
+
43
+ # Calculate Euclidean Distance
44
+ dist = np.linalg.norm(ti_pos - o_pos)
45
+ bond_lengths.append(dist)
46
+
47
+ # 4. Analyze Results
48
+ min_bond = min(bond_lengths)
49
+ avg_bond = sum(bond_lengths) / len(bond_lengths)
50
+
51
+ print(f"\nMeasured Bond Lengths (Ti - O):")
52
+ print(f" Minimum: {min_bond:.4f} Å")
53
+ print(f" Average: {avg_bond:.4f} Å")
54
+
55
+ # 5. The "DeepMind" Pass/Fail
56
+ # Real Physics: Ti-O bond is typically 1.90 - 2.05 Å
57
+ # We allow some error since this is a tiny model trained for 5 minutes
58
+ if 1.5 < min_bond < 2.5:
59
+ print("\n✅ SUCCESS: The model learned valid chemical bonds!")
60
+ print(" (Target range: ~1.9 Å)")
61
+ else:
62
+ print("\n⚠️ WARNING: Bonds are physically unrealistic.")
63
+ print(" (Try training for more epochs or checking the dataset)")
64
+
65
+ if __name__ == "__main__":
66
+ check_physics()
visualize.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from mpl_toolkits.mplot3d import Axes3D
3
+ import numpy as np
4
+
5
+ def read_xyz(filename):
6
+ """
7
+ Reads an XYZ file and returns coordinates and atom types.
8
+ """
9
+ coords = []
10
+ atoms = []
11
+ with open(filename, 'r') as f:
12
+ lines = f.readlines()
13
+ # Skip header lines (first 2)
14
+ for line in lines[2:]:
15
+ parts = line.split()
16
+ atoms.append(parts[0])
17
+ coords.append([float(parts[1]), float(parts[2]), float(parts[3])])
18
+ return np.array(coords), atoms
19
+
20
+ def plot_crystal(ax, coords, atoms, title):
21
+ """
22
+ Plots a single crystal in a 3D subplot.
23
+ """
24
+ # Define colors for atoms (Titanium=Silver, Oxygen=Red, Ca=Green)
25
+ colors = {'Ti': 'gray', 'O': 'red', 'Ca': 'green', 'Pb': 'black', 'I': 'purple'}
26
+
27
+ # Scatter plot
28
+ # s=size of atom, alpha=transparency
29
+ for i, atom in enumerate(atoms):
30
+ color = colors.get(atom, 'blue') # Default to blue if unknown
31
+ ax.scatter(coords[i,0], coords[i,1], coords[i,2],
32
+ c=color, s=200, edgecolors='k', alpha=0.8)
33
+
34
+ # Draw "bonds" (lines between atoms close to each other)
35
+ # This helps visualize the structure
36
+ num_atoms = len(coords)
37
+ for i in range(num_atoms):
38
+ for j in range(i + 1, num_atoms):
39
+ dist = np.linalg.norm(coords[i] - coords[j])
40
+ # If atoms are closer than 2.8 Angstroms, draw a line
41
+ if dist < 2.8:
42
+ ax.plot([coords[i,0], coords[j,0]],
43
+ [coords[i,1], coords[j,1]],
44
+ [coords[i,2], coords[j,2]],
45
+ c='black', linewidth=1, alpha=0.5)
46
+
47
+ ax.set_title(title)
48
+ ax.set_xlabel('X')
49
+ ax.set_ylabel('Y')
50
+ ax.set_zlabel('Z')
51
+
52
+ # Set consistent limits so we can compare
53
+ ax.set_xlim(-2, 5)
54
+ ax.set_ylim(-2, 5)
55
+ ax.set_zlim(-2, 5)
56
+
57
+ def create_comparison_figure():
58
+ # 1. Read Data
59
+ # Make sure you ran generate.py first to get these files!
60
+ noise_pos, atoms = read_xyz("gen_step_00.xyz")
61
+ final_pos, _ = read_xyz("gen_final.xyz")
62
+
63
+ # 2. Setup Plot
64
+ fig = plt.figure(figsize=(12, 6))
65
+
66
+ # Plot 1: The Noise
67
+ ax1 = fig.add_subplot(121, projection='3d')
68
+ plot_crystal(ax1, noise_pos, atoms, "Step 0: Random Noise")
69
+
70
+ # Plot 2: The Generated Crystal
71
+ ax2 = fig.add_subplot(122, projection='3d')
72
+ plot_crystal(ax2, final_pos, atoms, "Step 50: Generated Crystal")
73
+
74
+ plt.tight_layout()
75
+ plt.savefig("result_plot.png", dpi=300)
76
+ print("Saved comparison figure to 'result_plot.png'")
77
+ plt.show()
78
+
79
+ if __name__ == "__main__":
80
+ create_comparison_figure()