Crystal-Diffusion / src /model.py
adi2075's picture
Upload 16 files
4e4ff14 verified
import torch
import torch.nn as nn
from src.layers import EGNNLayer
class TimeEmbedding(nn.Module):
"""
Converts a time scalar 't' into a vector embedding.
This allows the neural network to understand the noise level (time step).
"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
self.linear_1 = nn.Linear(1, dim)
self.linear_2 = nn.Linear(dim, dim)
self.act = nn.SiLU() # SiLU is standard for diffusion models
def forward(self, t: torch.Tensor) -> torch.Tensor:
"""
Args:
t (Tensor): Time scalars of shape (Batch_Size, 1).
Returns:
Tensor: Time embeddings of shape (Batch_Size, dim).
"""
# t is shape (Batch_Size, 1) -> we want to output (Batch_Size, dim)
x = self.act(self.linear_1(t))
x = self.linear_2(x)
return x
class CrystalDiffusionModel(nn.Module):
"""
E(n)-Equivariant Diffusion Model for Crystal Generation.
Predicts the denoised coordinates given a noisy input.
"""
def __init__(self, hidden_dim: int = 64, num_layers: int = 3, max_atom_type: int = 100):
super().__init__()
# 1. Atom Embedding: Integer -> Vector
# Maps atomic numbers (e.g., 8 for Oxygen) to a dense vector
self.atom_embed = nn.Embedding(max_atom_type, hidden_dim)
# 2. Time Embedding: Scalar -> Vector
# Helps the model know if it's looking at pure noise (t=1) or a crystal (t=0)
self.time_embed = TimeEmbedding(hidden_dim)
# 3. Backbone: Stack of Equivariant GNN layers
# These update both features (h) and positions (x)
self.layers = nn.ModuleList([
EGNNLayer(c_in=hidden_dim, c_out=hidden_dim)
for _ in range(num_layers)
])
# Note: We don't need a final linear layer for positions because
# the EGNN layers update the coordinates 'x' directly at every step.
def forward(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the diffusion model.
Args:
x (Tensor): Noisy atom positions. Shape (N, 3).
z (Tensor): Atomic numbers. Shape (N,).
t (Tensor): Time step/Noise level. Shape (Batch_Size, 1).
edge_index (Tensor): Graph connectivity (Adjacency list). Shape (2, E).
Returns:
Tensor: Denoised atom positions. Shape (N, 3).
"""
# 1. Embed Inputs
h = self.atom_embed(z) # (N, hidden_dim)
t_emb = self.time_embed(t) # (Batch, hidden_dim)
# 2. Condition on Time
# Broadcast time embedding to all atoms in the batch
# (Assuming single batch or handled externally for simplicity)
h = h + t_emb.mean(dim=0, keepdim=True)
# 3. Message Passing (The "Brain")
for layer in self.layers:
# Update features (h) and positions (x) respecting symmetry
h, x = layer(h, x, edge_index)
# Return the updated (denoised) positions
return x