# TRIADS / model_code / phonons_dataset_builder.py
# Uploaded by Rtx09: "TRIADS — 6-benchmark weights + model code + Gradio app" (commit 8a82d34)
"""
+=============================================================+
| V6 Physics-Featurized Phonon Dataset Builder |
| Architecture-Agnostic | Rich Physics | 3-Order Graphs |
| |
| Features per atom: 18d (element physics + coords + local) |
| Features per bond: 8d physics + 40d RBF + 3d direction |
| Order 2 (angles): 8d angle RBF |
| Order 3 (dihedrals): 8d dihedral RBF |
| Composition: MAGPIE + mat2vec + matminer extras |
| Global physics: Debye temp, force constants, etc. |
| |
| ⚠ NO SCALING — raw features. Scale at training time only. |
+=============================================================+
DEPENDENCIES:
pip install matminer pymatgen gensim tqdm scikit-learn torch numpy
USAGE:
python phonons_dataset_builder.py
-> Outputs: phonons_v6_dataset.pt
"""
import os, time, math, warnings, urllib.request, logging
from collections import defaultdict
warnings.filterwarnings('ignore')
import numpy as np
import torch
from tqdm import tqdm
from sklearn.model_selection import KFold
logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
log = logging.getLogger("V6-BUILD")
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
# Neighbor search radius handed to structure.get_all_neighbors(); also the upper
# end of the distance RBF grid. main() reports it in Å.
CUTOFF = 8.0
# Only the MAX_NEIGHBORS closest neighbors of each atom become edges.
MAX_NEIGHBORS = 12
# Number of Gaussian basis functions for distances, bond angles and dihedrals.
N_RBF_DIST = 40
N_RBF_ANGLE = 8
N_RBF_DIHEDRAL = 8
MAX_QUADS = 50000 # cap dihedrals per crystal for memory
# Seed and fold count passed to sklearn KFold(shuffle=True) in main().
FOLD_SEED = 18012019 # matbench v0.1 protocol
N_FOLDS = 5
# Width of one row of the element lookup table built by build_element_table().
N_ELEM_FEAT = 12 # from lookup table
N_ATOM_COMPUTED = 6 # frac_coords(3) + coord_num(1) + avg_nn(1) + std_nn(1)
N_ATOM_FEAT = N_ELEM_FEAT + N_ATOM_COMPUTED # 18
# Width of the per-edge physics vector and of the per-crystal physics vector.
N_BOND_PHYSICS = 8
N_GLOBAL_PHYS = 15
# ═══════════════════════════════════════════════════════════════
# GAUSSIAN RADIAL BASIS FUNCTIONS
# ═══════════════════════════════════════════════════════════════
def gaussian_rbf(values, n_bins, vmin, vmax):
    """
    Expand scalars onto a fixed grid of Gaussians (no learnable parameters).

    Args:
        values: tensor of scalars to expand (e.g. distances or angles).
        n_bins: number of Gaussian centers.
        vmin: position of the first center.
        vmax: position of the last center.

    Returns:
        Tensor with one extra trailing axis of size ``n_bins``; entry ``k`` is
        ``exp(-gamma * (value - center_k)^2)``.
    """
    # The Gaussian width is derived from (range / n_bins), not from the actual
    # spacing of the linspace centers, which is range / (n_bins - 1).
    width = (vmax - vmin) / n_bins
    gamma = 1.0 / width ** 2
    centers = torch.linspace(vmin, vmax, n_bins)
    offsets = values.unsqueeze(-1) - centers.unsqueeze(0)
    return torch.exp(-gamma * offsets ** 2)
# ═══════════════════════════════════════════════════════════════
# ELEMENT PHYSICS LOOKUP TABLE
# ═══════════════════════════════════════════════════════════════
def build_element_table():
    """
    Build the [103, 12] lookup table of per-element physical properties.

    Row index is the atomic number Z; row 0 is left as zeros (padding).
    Values come from pymatgen's ``Element`` data.

    Columns: mass, 1/sqrt(mass), electronegativity, atomic_radius,
    covalent_radius (proxy: average ionic radius, falling back to the
    atomic radius), ionization_energy, electron_affinity,
    valence_electrons, group, period, block, is_metal

    Missing, NaN or infinite source values are replaced by the per-column
    defaults used below, so the table never contains NaN. If an element
    cannot be processed at all, its row is filled with a neutral fallback.

    NOTE: the ionization-energy column holds the FIRST ionization energy
    (index 0 of pymatgen's zero-based list). Earlier revisions read index 1
    (the second ionization energy); datasets rebuilt with this version differ
    in that column.

    Returns:
        torch.Tensor of shape [103, N_ELEM_FEAT].
    """
    from pymatgen.core.periodic_table import Element

    def _finite(value, default):
        """Return float(value), or `default` if value is None / NaN / inf / non-numeric."""
        if value is None:
            return default
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        return number if math.isfinite(number) else default

    block_map = {'s': 0., 'p': 1., 'd': 2., 'f': 3.}
    table = torch.zeros(103, N_ELEM_FEAT)
    for z in range(1, 103):
        try:
            el = Element.from_Z(z)
            # A missing or zero mass falls back to 1.0, as before.
            mass = _finite(el.atomic_mass, 1.0) or 1.0
            # pymatgen may report NaN electronegativity (e.g. noble gases).
            chi = _finite(el.X, 0.0)
            ar = _finite(el.atomic_radius, 1.5)

            # Covalent radius proxy: average ionic radius when positive.
            try:
                cr = _finite(el.average_ionic_radius, 0.0)
            except Exception:
                cr = 0.0
            if cr <= 0:
                cr = ar

            # First ionization energy.
            ie = 0.0
            try:
                ies = el.ionization_energies
                if isinstance(ies, dict):
                    first_ie = ies.get(1)
                elif isinstance(ies, (list, tuple)) and len(ies) > 0:
                    first_ie = ies[0]  # zero-based: [0] is the first IE
                else:
                    first_ie = None
                ie = _finite(first_ie, 0.0)
            except Exception:
                pass  # best effort: keep 0.0 when the data is unavailable

            # Electron affinity.
            ea = 0.0
            try:
                ea = _finite(el.electron_affinity, 0.0)
            except Exception:
                pass  # best effort: keep 0.0 when the data is unavailable

            # Group, period, valence electrons.
            g = int(el.group) if el.group is not None else 0
            p = int(el.row) if el.row is not None else 0
            # Groups 1-2 -> g, groups 13-18 -> g - 10, everything else -> 2.
            ve = g if g <= 2 else (g - 10 if g >= 13 else 2)
            bl = block_map.get(el.block, 0.) if hasattr(el, 'block') and el.block else 0.
            im = 1.0 if el.is_metal else 0.0

            table[z] = torch.tensor([
                mass, 1.0 / math.sqrt(max(mass, 0.01)), chi, ar, cr,
                ie, ea, float(ve), float(g), float(p), bl, im
            ])
        except Exception:
            # Element could not be processed: neutral fallback row.
            table[z] = torch.tensor([1., 1., 0., 1.5, 1.5, 0., 0., 0., 0., 0., 0., 0.])
    return table
# ═══════════════════════════════════════════════════════════════
# CRYSTAL GRAPH BUILDER (Orders 1, 2, 3)
# ═══════════════════════════════════════════════════════════════
def _empty_graph(atom_z, atom_features, n_atoms):
    """
    Placeholder graph for a crystal in which no neighbors were found.

    The atom-level entries are passed through unchanged. The edge level holds
    a single all-zero dummy edge (``n_edges`` is 1, index pair 0 -> 0), while
    the triplet and quad levels are empty.
    """
    graph = dict(atom_z=atom_z, atom_features=atom_features, n_atoms=n_atoms)
    # Order 1: one zero-filled dummy edge.
    graph.update(
        edge_index=torch.zeros(2, 1, dtype=torch.long),
        edge_dist=torch.zeros(1),
        edge_rbf=torch.zeros(1, N_RBF_DIST),
        edge_vec=torch.zeros(1, 3),
        edge_physics=torch.zeros(1, N_BOND_PHYSICS),
        n_edges=1,
    )
    # Order 2: no angles.
    graph.update(
        triplet_index=torch.zeros(2, 0, dtype=torch.long),
        angle_rbf=torch.zeros(0, N_RBF_ANGLE),
        n_triplets=0,
    )
    # Order 3: no dihedrals.
    graph.update(
        quad_index=torch.zeros(2, 0, dtype=torch.long),
        dihedral_rbf=torch.zeros(0, N_RBF_DIHEDRAL),
        n_quads=0,
    )
    return graph
def build_crystal_graph(structure, elem_table):
    """
    Assemble the 3-order crystal graph (bonds, angles, dihedrals) of one structure.

    Args:
        structure: crystal structure; must provide iteration over sites,
            ``get_all_neighbors`` and indexing (presumably a pymatgen Structure).
        elem_table: per-element lookup table indexed by atomic number.

    Returns:
        dict with atom features, edge features + bond physics, triplets
        (angle RBF) and quads (dihedral RBF). If no neighbor is found the
        result of ``_empty_graph`` is returned instead.

    Only this structure and the element table are read; nothing from other
    samples enters the graph.
    """
    n_atoms = len(structure)
    atom_z = torch.tensor([site.specie.Z for site in structure], dtype=torch.long)

    # Element lookup rows [N, 12] and fractional coordinates [N, 3].
    elem_rows = elem_table[atom_z.clamp(0, 102)]
    frac_xyz = torch.tensor(
        [site.frac_coords for site in structure], dtype=torch.float32
    )

    # ── NEIGHBOR FINDING ──────────────────────────────────────
    # Each atom keeps its MAX_NEIGHBORS closest neighbors within CUTOFF.
    senders, receivers, lengths, displacements = [], [], [], []
    dists_by_atom = defaultdict(list)
    try:
        for center, shell in enumerate(structure.get_all_neighbors(CUTOFF)):
            closest = sorted(shell, key=lambda nb: nb.nn_distance)[:MAX_NEIGHBORS]
            for nb in closest:
                senders.append(center)
                receivers.append(nb.index)
                lengths.append(nb.nn_distance)
                displacements.append(nb.coords - structure[center].coords)
                dists_by_atom[center].append(nb.nn_distance)
    except Exception as e:
        # Best effort: whatever was collected before the failure is kept.
        log.warning(f" Neighbor finding failed: {e}")

    # ── PER-ATOM COORDINATION STATS ───────────────────────────
    coord_nums = torch.zeros(n_atoms)
    avg_nn = torch.zeros(n_atoms)
    std_nn = torch.zeros(n_atoms)
    for atom_idx in range(n_atoms):
        own_dists = dists_by_atom.get(atom_idx, [])
        coord_nums[atom_idx] = len(own_dists)
        if not own_dists:
            continue
        avg_nn[atom_idx] = np.mean(own_dists)
        if len(own_dists) > 1:
            std_nn[atom_idx] = np.std(own_dists)

    # Combined atom features [N, 18]: 12 element + 3 fractional + 3 local.
    local_stats = torch.stack([coord_nums, avg_nn, std_nn], dim=-1)
    atom_features = torch.cat([elem_rows, frac_xyz, local_stats], dim=-1)

    if not senders:
        return _empty_graph(atom_z, atom_features, n_atoms)

    # ── EDGE FEATURES (Order 1) ───────────────────────────────
    edge_index = torch.tensor([senders, receivers], dtype=torch.long)
    edge_dist = torch.tensor(lengths, dtype=torch.float32)
    raw_vecs = torch.tensor(np.array(displacements), dtype=torch.float32)
    n_edges = edge_index.shape[1]
    edge_rbf = gaussian_rbf(edge_dist, N_RBF_DIST, 0.0, CUTOFF)
    # Unit direction of every edge; the clamp avoids division by zero.
    edge_vec = raw_vecs / raw_vecs.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    # ── BOND PHYSICS FEATURES [E, 8] ─────────────────────────
    z_src = atom_z[edge_index[0]]
    z_dst = atom_z[edge_index[1]]
    props_src = elem_table[z_src.clamp(0, 102)]
    props_dst = elem_table[z_dst.clamp(0, 102)]
    # Table columns: 0 = mass, 2 = electronegativity, 3 = atomic radius.
    m_src, chi_src, r_src = props_src[:, 0], props_src[:, 2], props_src[:, 3]
    m_dst, chi_dst, r_dst = props_dst[:, 0], props_dst[:, 2], props_dst[:, 3]
    d = edge_dist.clamp(min=0.01)

    chi_prod = (chi_src * chi_dst).clamp(min=0.01)
    k_est = torch.sqrt(chi_prod) / (d * d)                     # force constant
    mu = (m_src * m_dst) / (m_src + m_dst).clamp(min=0.01)     # reduced mass
    omega = torch.sqrt(k_est / mu.clamp(min=0.01))             # Einstein freq
    delta_chi = (chi_src - chi_dst).abs()                      # EN difference
    ionicity = delta_chi * delta_chi                           # bond ionicity
    r_ratio = (r_src + r_dst) / d                              # radius sum ratio
    m_ratio = torch.min(m_src, m_dst) / torch.max(m_src, m_dst).clamp(min=0.01)
    inv_d = 1.0 / d                                            # inverse distance
    edge_physics = torch.stack(
        [k_est, mu, omega, delta_chi, ionicity, r_ratio, m_ratio, inv_d],
        dim=-1,
    )

    # ── TRIPLETS / ANGLES (Order 2) ───────────────────────────
    # Group edges by destination atom; every ordered pair of distinct edges
    # ending on the same atom forms one triplet.
    dest_to_edges = defaultdict(list)
    for e_idx, dest in enumerate(edge_index[1].numpy()):
        dest_to_edges[int(dest)].append(e_idx)
    edge_pairs = [
        (first, second)
        for incoming in dest_to_edges.values()
        for first in incoming
        for second in incoming
        if first != second
    ]
    trip_ij = [first for first, _ in edge_pairs]
    trip_kj = [second for _, second in edge_pairs]

    if trip_ij:
        triplet_index = torch.tensor([trip_ij, trip_kj], dtype=torch.long)
        cos_theta = (edge_vec[triplet_index[0]] * edge_vec[triplet_index[1]]).sum(-1)
        # Clamp strictly inside [-1, 1] before acos.
        angles = torch.acos(cos_theta.clamp(-1 + 1e-7, 1 - 1e-7))
        angle_rbf_t = gaussian_rbf(angles, N_RBF_ANGLE, 0.0, math.pi)
        n_triplets = triplet_index.shape[1]
    else:
        triplet_index = torch.zeros(2, 0, dtype=torch.long)
        angle_rbf_t = torch.zeros(0, N_RBF_ANGLE)
        n_triplets = 0

    # ── QUADS / DIHEDRALS (Order 3) ───────────────────────────
    quad_index, dihedral_rbf_t, n_quads = _compute_quads(
        triplet_index, n_triplets, edge_vec, trip_ij, trip_kj
    )

    return dict(
        atom_z=atom_z,
        atom_features=atom_features,
        n_atoms=n_atoms,
        edge_index=edge_index,
        edge_dist=edge_dist,
        edge_rbf=edge_rbf,
        edge_vec=edge_vec,
        edge_physics=edge_physics,
        n_edges=n_edges,
        triplet_index=triplet_index,
        angle_rbf=angle_rbf_t,
        n_triplets=n_triplets,
        quad_index=quad_index,
        dihedral_rbf=dihedral_rbf_t,
        n_quads=n_quads,
    )
def _compute_quads(triplet_index, n_triplets, edge_vec, trip_ij, trip_kj):
    """
    Compute Order 3: ordered pairs of triplets that share a bond (dihedrals).

    Args:
        triplet_index: [2, T] long tensor with the two edge ids of each triplet.
        n_triplets: number of triplets T.
        edge_vec: [E, 3] unit edge directions.
        trip_ij, trip_kj: the two rows of ``triplet_index`` as Python lists.

    Returns:
        (quad_index [2, Q], dihedral_rbf [Q, N_RBF_DIHEDRAL], Q), where Q is
        capped at MAX_QUADS.
    """
    no_quads = (torch.zeros(2, 0, dtype=torch.long),
                torch.zeros(0, N_RBF_DIHEDRAL), 0)
    if n_triplets == 0:
        return no_quads

    # For each edge, the triplets that reference it (through either slot).
    edge_to_trips = defaultdict(list)
    for t_idx in range(n_triplets):
        edge_to_trips[trip_ij[t_idx]].append(t_idx)
        edge_to_trips[trip_kj[t_idx]].append(t_idx)

    def _ordered_pairs():
        """Yield every ordered pair of distinct triplets sharing an edge."""
        for members in edge_to_trips.values():
            for first in members:
                for second in members:
                    if first != second:
                        yield first, second

    quad_src, quad_dst = [], []
    for first, second in _ordered_pairs():
        quad_src.append(first)
        quad_dst.append(second)
        if len(quad_src) >= MAX_QUADS:
            break  # memory cap: later pairs are dropped

    if not quad_src:
        return no_quads

    quad_index = torch.tensor([quad_src, quad_dst], dtype=torch.long)

    def _plane_normal(trip_ids):
        """Unit normal of the plane spanned by the two edges of each triplet."""
        normal = torch.cross(
            edge_vec[triplet_index[0, trip_ids]],
            edge_vec[triplet_index[1, trip_ids]],
            dim=-1,
        )
        return normal / normal.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    # Dihedral angle = angle between the planes of the two triplets.
    n_a = _plane_normal(quad_index[0])
    n_b = _plane_normal(quad_index[1])
    cos_dih = (n_a * n_b).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7)
    dihedrals = torch.acos(cos_dih)
    dihedral_rbf_t = gaussian_rbf(dihedrals, N_RBF_DIHEDRAL, 0.0, math.pi)
    return quad_index, dihedral_rbf_t, quad_index.shape[1]
# ═══════════════════════════════════════════════════════════════
# GLOBAL PHYSICS FEATURES (per crystal)
# ═══════════════════════════════════════════════════════════════
def compute_global_physics(graph, structure, elem_table):
    """
    Compute 15 global physics features from a crystal graph.

    Features:
        0: avg_force_constant       7: avg_coordination
        1: std_force_constant       8: density
        2: avg_reduced_mass         9: volume_per_atom
        3: mass_variance           10: packing_fraction
        4: avg_einstein_freq       11: avg_bond_length
        5: electronegativity_var   12: std_bond_length
        6: debye_temp_estimate     13: max_atomic_mass
                                   14: min_atomic_mass

    Args:
        graph: dict produced by ``build_crystal_graph``.
        structure: the structure the graph was built from; read for density,
            volume and per-site atomic radii.
        elem_table: per-element lookup table (column 0 = mass,
            column 2 = electronegativity).

    Returns:
        torch.Tensor of shape [N_GLOBAL_PHYS]. Features 8-10 are best effort:
        if reading them from ``structure`` fails, a warning is logged and the
        entries not yet written stay 0.
    """
    feats = torch.zeros(N_GLOBAL_PHYS)
    n_atoms = graph['n_atoms']
    atom_z = graph['atom_z']
    ep = graph['edge_physics']  # [E, 8]
    dists = graph['edge_dist']

    # Bond-level statistics (edge_physics columns: 0 = k, 1 = mu, 2 = omega).
    if graph['n_edges'] > 0 and dists.shape[0] > 0:
        k_vals = ep[:, 0]
        mu_vals = ep[:, 1]
        omega_vals = ep[:, 2]
        feats[0] = k_vals.mean()
        # std of a single value is undefined -> report 0.
        feats[1] = k_vals.std() if k_vals.shape[0] > 1 else 0.0
        feats[2] = mu_vals.mean()
        feats[4] = omega_vals.mean()
        feats[11] = dists.mean()
        feats[12] = dists.std() if dists.shape[0] > 1 else 0.0

    # Mass statistics.
    masses = elem_table[atom_z.clamp(0, 102), 0]
    feats[3] = masses.var() if n_atoms > 1 else 0.0
    feats[13] = masses.max()
    feats[14] = masses.min()

    # Electronegativity variance.
    chis = elem_table[atom_z.clamp(0, 102), 2]
    feats[5] = chis.var() if n_atoms > 1 else 0.0

    # Debye temperature estimate: Θ_D ∝ sqrt(k_avg / m_avg).
    m_avg = masses.mean()
    k_avg = feats[0]
    feats[6] = math.sqrt(float(k_avg / max(m_avg, 0.01)))

    # Average coordination number (first computed column after the 3 coords).
    feats[7] = graph['atom_features'][:, N_ELEM_FEAT + 3].mean()

    # Structural features (best effort).
    try:
        feats[8] = structure.density
        feats[9] = structure.volume / max(n_atoms, 1)
        # Packing fraction: summed atomic-sphere volume over cell volume;
        # sites without an atomic radius are skipped.
        total_vol = sum(
            (4 / 3) * math.pi * (float(site.specie.atomic_radius) ** 3)
            for site in structure
            if hasattr(site.specie, 'atomic_radius') and site.specie.atomic_radius is not None
        )
        feats[10] = total_vol / structure.volume if structure.volume > 0 else 0.0
    except Exception as e:
        # `except Exception` (not bare) so Ctrl-C still stops a long build.
        log.warning(f" Structural global features failed: {e}")
    return feats
# ═══════════════════════════════════════════════════════════════
# STRUCTURAL FEATURES (per crystal)
# ═══════════════════════════════════════════════════════════════
def compute_structural_features(structure):
    """
    Compute 11 structural features: lattice params + symmetry.

    Layout:
        0-5: lattice a, b, c, alpha, beta, gamma
        6:   volume per atom
        7:   density
        8:   number of sites
        9:   space group number (0 if symmetry analysis fails)
        10:  packing fraction (0 if it cannot be computed)

    The docstring of the previous revision states that this layout is kept
    for backward compatibility with earlier dataset versions.

    Args:
        structure: crystal structure (presumably a pymatgen Structure).

    Returns:
        np.ndarray of shape (11,), dtype float32. Every feature is best
        effort: on failure the affected entries stay 0 and a message is logged.
    """
    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    feats = np.zeros(11, dtype=np.float32)
    try:
        lat = structure.lattice
        feats[0:6] = [lat.a, lat.b, lat.c, lat.alpha, lat.beta, lat.gamma]
        feats[6] = structure.volume / max(len(structure), 1)
        feats[7] = structure.density
        feats[8] = float(len(structure))

        try:
            sga = SpacegroupAnalyzer(structure, symprec=0.1)
            feats[9] = float(sga.get_space_group_number())
        except Exception as e:
            log.debug(f" Space group detection failed: {e}")
            feats[9] = 0.0

        try:
            # float() mirrors compute_global_physics so the radius is a plain
            # number before it is cubed.
            total_vol = sum(
                (4 / 3) * np.pi * float(site.specie.atomic_radius) ** 3
                for site in structure
                if hasattr(site.specie, 'atomic_radius') and site.specie.atomic_radius is not None
            )
            feats[10] = total_vol / structure.volume if structure.volume > 0 else 0.0
        except Exception as e:
            log.debug(f" Packing fraction failed: {e}")
            feats[10] = 0.0
    except Exception as e:
        # `except Exception` (not bare) so Ctrl-C still stops a long build.
        log.warning(f" Structural features failed: {e}")
    return feats
# ═══════════════════════════════════════════════════════════════
# COMPOSITION FEATURIZER (MAGPIE + mat2vec + matminer extras)
# ═══════════════════════════════════════════════════════════════
class CompositionFeaturizer:
    """
    Builds rich composition features per crystal:
    - MAGPIE elemental properties (nominally 132d: 22 props × 6 stats; the
      actual width is read from the featurizer's labels)
    - Extra matminer (Stoichiometry, ValenceOrbital, IonProperty, TMetalFraction)
    - Structural features (11d)
    - mat2vec embeddings (200d)

    Each row is computed from its own composition/structure only; no
    statistic is shared across samples.
    """
    M2V_URL = "https://storage.googleapis.com/mat2vec/"
    M2V_FILES = [
        "pretrained_embeddings",
        "pretrained_embeddings.wv.vectors.npy",
        "pretrained_embeddings.trainables.syn1neg.npy",
    ]

    def __init__(self, cache="mat2vec_cache"):
        """
        Set up the matminer featurizers and load the mat2vec embeddings.

        Args:
            cache: directory holding the mat2vec files; missing files are
                downloaded into it from ``M2V_URL``.
        """
        from matminer.featurizers.composition import (
            ElementProperty, Stoichiometry, ValenceOrbital, IonProperty
        )
        from matminer.featurizers.composition.element import TMetalFraction
        from gensim.models import Word2Vec

        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_magpie = len(self.ep_magpie.feature_labels())
        self.extra_ftzrs = [
            ("Stoichiometry", Stoichiometry()),
            ("ValenceOrbital", ValenceOrbital()),
            ("IonProperty", IonProperty()),
            ("TMetalFraction", TMetalFraction()),
        ]
        # Output width of every extra featurizer; None until it is known.
        self._extra_sizes = {}
        for name, ft in self.extra_ftzrs:
            try:
                self._extra_sizes[name] = len(ft.feature_labels())
            except Exception:
                self._extra_sizes[name] = None

        # Download mat2vec. Each file goes to a temporary ".part" path and is
        # renamed only once complete, so an interrupted download never leaves
        # a truncated file that the exists() check would later accept.
        os.makedirs(cache, exist_ok=True)
        for f in self.M2V_FILES:
            p = os.path.join(cache, f)
            if os.path.exists(p):
                continue
            log.info(f" Downloading mat2vec: {f}...")
            partial = p + ".part"
            try:
                urllib.request.urlretrieve(self.M2V_URL + f, partial)
                os.replace(partial, p)
            finally:
                if os.path.exists(partial):
                    os.remove(partial)

        # NOTE(security): Word2Vec.load unpickles the downloaded file; only
        # point M2V_URL / cache at sources you trust.
        m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}
        self.n_extra = None  # determined on first call

    def _pool_m2v(self, comp):
        """Amount-weighted mean of the mat2vec vectors of the known elements of `comp`."""
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in comp.get_el_amt_dict().items():
            if s in self.emb:
                v += f * self.emb[s]
                t += f
        # max() keeps the all-unknown case a zero vector instead of NaN.
        return v / max(t, 1e-8)

    def _featurize_extras(self, comp):
        """
        Run every extra featurizer on `comp` and concatenate the results.

        NaNs are replaced by 0. A featurizer that raises contributes a zero
        block of its known width (width 1 if the width was never learned).
        """
        parts = []
        for name, ft in self.extra_ftzrs:
            try:
                vals = np.array(ft.featurize(comp), np.float32)
                parts.append(np.nan_to_num(vals, nan=0.0))
                if self._extra_sizes.get(name) is None:
                    self._extra_sizes[name] = len(vals)
            except Exception:
                sz = self._extra_sizes.get(name, 0) or 1
                parts.append(np.zeros(sz, np.float32))
        return np.concatenate(parts)

    def featurize_all(self, compositions, structures):
        """
        Return [N, D_comp] array of all composition features.

        Row layout: MAGPIE | extras | 11 structural | 200 mat2vec.
        `compositions` and `structures` must be index-aligned.
        """
        # Determine dimensions from first sample
        test_extras = self._featurize_extras(compositions[0])
        self.n_extra = len(test_extras)
        struct_feats_dim = 11
        total_dim = self.n_magpie + self.n_extra + struct_feats_dim + 200
        log.info(f" Composition features: {self.n_magpie} MAGPIE + "
                 f"{self.n_extra} Extras + 11 Structural + 200 mat2vec = {total_dim}d")

        out = []
        for i, comp in enumerate(tqdm(compositions, desc=" Featurizing compositions", leave=False)):
            # MAGPIE (zeros if the featurizer rejects this composition)
            try:
                mg = np.array(self.ep_magpie.featurize(comp), np.float32)
            except Exception:
                mg = np.zeros(self.n_magpie, np.float32)
            mg = np.nan_to_num(mg, nan=0.0)
            # Extra matminer
            ex = self._featurize_extras(comp)
            # Structural
            sf = compute_structural_features(structures[i])
            # mat2vec
            m2v = self._pool_m2v(comp)
            out.append(np.concatenate([mg, ex, sf, m2v]))
        return np.array(out, dtype=np.float32)
# ═══════════════════════════════════════════════════════════════
# MAIN — BUILD AND SAVE
# ═══════════════════════════════════════════════════════════════
def _print_count_stats(label, counts):
    """Print min / max / mean of a per-crystal count list under `label`."""
    print(f" {label} min={min(counts)}, max={max(counts)}, "
          f"mean={np.mean(counts):.1f}")


def main():
    """
    Build the full phonon dataset and save it to ``phonons_v6_dataset.pt``.

    Steps: load matbench_phonons, build the element table, build one crystal
    graph plus global-physics vector per structure, compute composition
    features, compute the K-fold split indices, and save everything (with
    layout metadata and the build configuration) via ``torch.save``.
    No feature scaling is applied here.
    """
    t0 = time.time()
    print("""
+==========================================================+
| V6 Physics-Featurized Phonon Dataset Builder |
| 3-Order Graphs | Bond Physics | Architecture-Agnostic |
| ⚠ NO SCALING — raw features only |
+==========================================================+
""")

    # ── LOAD MATBENCH DATA ────────────────────────────────────
    print(" Loading matbench_phonons...")
    from matminer.datasets import load_dataset
    df = load_dataset("matbench_phonons")
    targets = np.array(df['last phdos peak'].tolist(), np.float32)
    structures = df['structure'].tolist()
    compositions = [s.composition for s in structures]
    N = len(structures)
    print(f" Loaded: {N} samples")
    # Fixed: min and max used to be printed back-to-back with no separator.
    print(f" Target range: {targets.min():.1f} – {targets.max():.1f} cm⁻¹")

    # ── BUILD ELEMENT TABLE ───────────────────────────────────
    print("\n Building element physics table...")
    elem_table = build_element_table()
    print(f" Element table: {elem_table.shape} (Z=0..102, {N_ELEM_FEAT} features)")

    # ── BUILD CRYSTAL GRAPHS ─────────────────────────────────
    print(f"\n Building 3-order crystal graphs ({MAX_NEIGHBORS}-NN, cutoff={CUTOFF}Å)...")
    graphs = []
    global_physics_list = []
    for struct in tqdm(structures, desc=" Building graphs"):
        g = build_crystal_graph(struct, elem_table)
        graphs.append(g)
        global_physics_list.append(compute_global_physics(g, struct, elem_table))

    # Stats
    print(f" Graphs built:")
    for label, key in (
        ("Atoms/crystal:", 'n_atoms'),
        ("Edges/crystal:", 'n_edges'),
        ("Triplets/crystal:", 'n_triplets'),
        ("Quads/crystal:", 'n_quads'),
    ):
        _print_count_stats(label, [g[key] for g in graphs])
    global_physics = torch.stack(global_physics_list)
    print(f" Global physics: {global_physics.shape}")

    # ── COMPOSITION FEATURES ─────────────────────────────────
    print("\n Computing composition features...")
    feat = CompositionFeaturizer()
    comp_features = feat.featurize_all(compositions, structures)
    print(f" Composition features shape: {comp_features.shape}")

    # ── FOLD INDICES (strict matbench protocol) ──────────────
    print(f"\n Computing {N_FOLDS}-fold split indices (seed={FOLD_SEED})...")
    kf = KFold(N_FOLDS, shuffle=True, random_state=FOLD_SEED)
    fold_indices = [(train_idx.tolist(), test_idx.tolist())
                    for train_idx, test_idx in kf.split(range(N))]
    # Sanity check: train and test of every fold are disjoint and cover all samples.
    for fi, (tr, te) in enumerate(fold_indices):
        overlap = set(tr) & set(te)
        assert len(overlap) == 0, f"DATA LEAK in fold {fi}: {len(overlap)} shared indices!"
        assert len(tr) + len(te) == N, f"Fold {fi}: missing samples!"
    print(" ✅ All folds verified: ZERO data leakage")

    # ── FEATURE DIMENSION INFO ───────────────────────────────
    n_magpie = feat.n_magpie
    n_extra = feat.n_extra
    # Column offsets inside a composition-feature row:
    # MAGPIE | extras | 11 structural | 200 mat2vec.
    extras_end = n_magpie + n_extra
    struct_end = extras_end + 11
    feature_info = {
        'atom_features_dim': N_ATOM_FEAT,
        'atom_features_layout': [
            'mass', '1/sqrt_mass', 'electronegativity', 'atomic_radius',
            'covalent_radius', 'ionization_energy', 'electron_affinity',
            'valence_electrons', 'group', 'period', 'block', 'is_metal',
            'frac_x', 'frac_y', 'frac_z',
            'coordination_num', 'avg_nn_dist', 'std_nn_dist',
        ],
        'edge_physics_dim': N_BOND_PHYSICS,
        'edge_physics_layout': [
            'force_constant', 'reduced_mass', 'einstein_freq',
            'en_difference', 'ionicity', 'radius_sum_ratio',
            'mass_ratio', 'inverse_distance',
        ],
        'edge_rbf_dim': N_RBF_DIST,
        'angle_rbf_dim': N_RBF_ANGLE,
        'dihedral_rbf_dim': N_RBF_DIHEDRAL,
        'global_physics_dim': N_GLOBAL_PHYS,
        'global_physics_layout': [
            'avg_force_constant', 'std_force_constant', 'avg_reduced_mass',
            'mass_variance', 'avg_einstein_freq', 'en_variance',
            'debye_temp_estimate', 'avg_coordination', 'density',
            'volume_per_atom', 'packing_fraction', 'avg_bond_length',
            'std_bond_length', 'max_atomic_mass', 'min_atomic_mass',
        ],
        'comp_magpie_range': (0, n_magpie),
        'comp_extras_range': (n_magpie, extras_end),
        'comp_structural_range': (extras_end, struct_end),
        'comp_mat2vec_range': (struct_end, struct_end + 200),
        'comp_total_dim': comp_features.shape[1],
    }

    # ── SAVE ─────────────────────────────────────────────────
    save_path = "phonons_v6_dataset.pt"
    save_data = {
        # Per-crystal data
        'graphs': graphs,
        'comp_features': torch.tensor(comp_features, dtype=torch.float32),
        'global_physics': global_physics,
        'targets': torch.tensor(targets, dtype=torch.float32),
        # Fold indices
        'fold_indices': fold_indices,
        'fold_seed': FOLD_SEED,
        # Metadata
        'n_samples': N,
        'feature_info': feature_info,
        'element_table': elem_table,
        'config': {
            'cutoff': CUTOFF,
            'max_neighbors': MAX_NEIGHBORS,
            'n_rbf_dist': N_RBF_DIST,
            'n_rbf_angle': N_RBF_ANGLE,
            'n_rbf_dihedral': N_RBF_DIHEDRAL,
            'max_quads': MAX_QUADS,
            'fold_seed': FOLD_SEED,
            'n_folds': N_FOLDS,
        },
    }
    torch.save(save_data, save_path)
    size_mb = os.path.getsize(save_path) / 1e6
    dt = time.time() - t0
    print(f"\n ✅ Saved: {save_path} ({size_mb:.1f} MB)")
    print(f" Total time: {dt:.1f}s")

    # ── SUMMARY ──────────────────────────────────────────────
    print(f"""
╔══════════════════════════════════════════════════════════╗
║ Dataset Summary ║
╠══════════════════════════════════════════════════════════╣
║ Samples: {N:>6}
║ Atom features: {N_ATOM_FEAT:>6}d (12 elem + 3 coord + 3 local) ║
║ Bond RBF: {N_RBF_DIST:>6}d ║
║ Bond physics: {N_BOND_PHYSICS:>6}d (k, μ, ω, Δχ, ...) ║
║ Angle RBF: {N_RBF_ANGLE:>6}d ║
║ Dihedral RBF: {N_RBF_DIHEDRAL:>6}d ║
║ Composition: {comp_features.shape[1]:>6}d (MAGPIE+extras+struct+m2v)║
║ Global physics: {N_GLOBAL_PHYS:>6}d ║
║ Folds: {N_FOLDS:>6} (seed={FOLD_SEED}) ║
║ File size: {size_mb:>5.1f} MB ║
╚══════════════════════════════════════════════════════════╝
⚠ Remember: NO scaling applied. Apply StandardScaler at
training time using ONLY train-fold indices!
Architecture-agnostic: plug ANY model on top of this dataset.
""")


if __name__ == '__main__':
    main()