# Crystal-Diffusion / src/data_loader.py
# Uploaded by adi2075 ("Upload 16 files", commit 4e4ff14, verified)
import torch
from mp_api.client import MPRester
import os
from pathlib import Path
from dotenv import load_dotenv
from tqdm import tqdm
# --- CONFIGURATION ---
load_dotenv()  # Load a .env file (if any) so the API-key lookup below can see it.

# Materials Project API key. This project reads MPI_API_KEY; the variable name
# used in the Materials Project documentation is MP_API_KEY, so fall back to it
# when MPI_API_KEY is unset (e.g. a .env written from the official instructions).
# If neither is set, API_KEY is None and MPRester presumably applies its own
# default key lookup — TODO confirm against the installed mp_api version.
API_KEY = os.getenv("MPI_API_KEY") or os.getenv("MP_API_KEY")

# Save path: project-root `data/` folder. parents[1] assumes this file lives one
# directory below the repository root (e.g. src/data_loader.py).
repo_root = Path(__file__).resolve().parents[1]
SAVE_PATH = repo_root / "data" / "perovskite_dataset.pt"
def _doc_to_datapoint(doc):
    """Convert one Materials Project summary document into a dataset entry.

    Args:
        doc: A search result with ``structure``, ``material_id`` and
            ``formula_pretty`` attributes (the fields requested in
            ``fetch_data``).

    Returns:
        A dict with keys:
            "id": the material id as a string.
            "formula": the pretty formula string.
            "z": ``torch.long`` tensor of atomic numbers, one per site.
            "pos": ``torch.float32`` tensor of Cartesian site coordinates,
                one row per site, shifted so their unweighted mean
                (centroid) sits at the origin.
    """
    structure = doc.structure

    # A. Atomic numbers (integers) -> the "identity" of each site.
    z_tensor = torch.tensor(
        [site.specie.number for site in structure], dtype=torch.long
    )

    # B. Cartesian coordinates (floats) -> the "geometry".
    # ``cart_coords`` is a single (n_sites, 3) NumPy array; building a tensor
    # from a Python list of per-site arrays (``site.coords``) instead takes
    # PyTorch's slow element-wise path and emits a UserWarning.
    r_tensor = torch.as_tensor(structure.cart_coords, dtype=torch.float32)

    # C. Centering (important for diffusion): shift the crystal so the centroid
    # of its sites is at (0, 0, 0), so the model does not spend capacity
    # learning absolute positions. Note this is the unweighted mean of the
    # positions, not a mass-weighted center of mass.
    r_tensor = r_tensor - r_tensor.mean(dim=0, keepdim=True)

    return {
        "id": str(doc.material_id),
        "formula": doc.formula_pretty,
        "z": z_tensor,  # Features
        "pos": r_tensor,  # Positions (centered)
    }


def fetch_data(limit=2000, save_path=None):
    """
    Fetch stable 5-site ABO3-like crystals from the Materials Project and save them.

    Queries the Materials Project summary endpoint for stable materials with
    exactly 5 sites in the unit cell, keeps those whose pretty formula
    contains "O3", converts each to tensors, and writes the resulting list
    with ``torch.save``.

    Args:
        limit: Maximum number of crystals to keep (default 2000).
        save_path: Output file. Defaults to ``SAVE_PATH``
            (``<repo root>/data/perovskite_dataset.pt``).

    Returns:
        The list of saved entries. Each entry is a dict with "id", "formula",
        "z" (atomic numbers) and "pos" (centered Cartesian coordinates).
    """
    save_path = SAVE_PATH if save_path is None else Path(save_path)

    print("Connecting to Materials Project...")
    with MPRester(API_KEY) as mpr:
        # 1. Broad search: every stable material with exactly 5 sites (atoms)
        # in the unit cell. This implicitly targets ABO3 stoichiometry
        # (1 + 1 + 3 = 5).
        docs = mpr.materials.summary.search(
            is_stable=True,
            nsites=5,
            fields=["structure", "material_id", "formula_pretty"],
        )
    print(f"Found {len(docs)} stable 5-atom crystals. Processing...")

    # 2. Filter and convert. The 'nsites=5' query does most of the filtering.
    # Heuristic: in a 5-site cell, "O3" in the formula means exactly three
    # oxygen sites (e.g. SrTiO3). This is a stoichiometry check, not a
    # structural one: A2O3 or nitrate/carbonate-like 5-site cells also pass.
    dataset = []
    for doc in tqdm(docs):
        if len(dataset) >= limit:
            break
        if "O3" not in doc.formula_pretty:
            continue
        dataset.append(_doc_to_datapoint(doc))

    # 3. Save to disk, creating the output directory if needed.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(dataset, save_path)
    print(f"✅ Successfully saved {len(dataset)} crystals to {save_path}")
    print(" (Filtered for 5-atom unit cells containing 'O3')")
    return dataset
if __name__ == "__main__":
    # Script entry point: download, filter and save up to 2000 crystals.
    fetch_data(2000)