"""
BERTose model
Core glycan representation model with three modalities:
- Sequence (WURCS atomic tokenization)
- MS (mass spectrometry peaks, RT, intensity)
- 3D structure (VQ-VAE discrete tokens, 4 per residue)
Each modality has its own encoder, with cross-attention for sequence-structure alignment.
"""
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
import math
try:
from .bertose_layers import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
except ImportError:
from bertose_layers import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
class ConvGlycanBERTEmbeddings(nn.Module):
"""
    Convolutional front-end that mixes local WURCS context before the Transformer encoder.
    Key improvements over the original embedding module:
1. Position embeddings added BEFORE convolution (provides spatial context to conv)
2. Residual connection (conv enriches embeddings rather than replacing them)
3. Multi-scale convolutions (kernel sizes 3, 5, 7) for better receptive field
4. Proper layer normalization on the residual path
"""
def __init__(self, config):
super().__init__()
self.token_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
# Branch depth embeddings encode depth in the glycan tree.
max_branch_depth = getattr(config, "max_branch_depth", 8)
self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
# Linkage type embeddings encode glycosidic bond chemistry.
# 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
num_linkage_types = getattr(config, "num_linkage_types", 9)
self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
# Multi-scale convolutions for different receptive fields
kernel_size = getattr(config, "cnn_kernel_size", 3)
# Split channels evenly: 256 + 256 + 256 = 768 for hidden_size=768
channels_per_scale = config.hidden_size // 3
self.conv_layers = nn.ModuleList([
nn.Conv1d(
in_channels=config.hidden_size,
out_channels=channels_per_scale,
kernel_size=kernel_size + 2 * i, # Kernels: 3, 5, 7
padding=(kernel_size + 2 * i) // 2, # Same padding
)
for i in range(3)
])
self.conv_activation = nn.GELU()
self.conv_proj = nn.Linear(channels_per_scale * 3, config.hidden_size) # Project concatenated back
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.conv_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
)
self.hidden_size = config.hidden_size
def forward(self, input_ids, branch_depths=None, linkage_types=None):
seq_len = input_ids.shape[1]
# Step 1: Token + Position embeddings FIRST (provides spatial context to conv)
x = self.token_embeddings(input_ids) # (batch, seq, hidden)
position_ids = self.position_ids[:, :seq_len]
x = x + self.position_embeddings(position_ids)
# Add branch depth embeddings.
if branch_depths is not None:
# Clamp to valid range
branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
x = x + self.branch_embeddings(branch_depths)
# Add linkage type embeddings.
if linkage_types is not None:
linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
x = x + self.linkage_embeddings(linkage_types)
x = self.LayerNorm(x)
# Step 2: Multi-scale convolution with RESIDUAL connection
# Convolution expects (batch, hidden, seq)
conv_in = x.permute(0, 2, 1)
# Apply multi-scale convolutions and concatenate
conv_outputs = []
for conv in self.conv_layers:
conv_out = self.conv_activation(conv(conv_in))
conv_outputs.append(conv_out)
# Concatenate multi-scale features and project back
conv_out = torch.cat(conv_outputs, dim=1) # (batch, hidden, seq)
conv_out = conv_out.permute(0, 2, 1) # (batch, seq, hidden)
conv_out = self.conv_proj(conv_out) # Project to correct size
# Step 3: Residual connection - conv ENRICHES rather than replaces
x = self.conv_norm(x + self.dropout(conv_out))
return x
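# Illustrative sketch (defined but never called on import): exercising the convolutional
# front-end in isolation on dummy WURCS token IDs. The GlycanBERTConfig keyword arguments
# mirror the ones passed by MultimodalGlycanBERTConfig.to_seq_config() below, with smaller
# sizes assumed purely so the example runs quickly; `cnn_kernel_size` is attached the same
# way MultimodalGlycanBERT does, and the config is assumed to expose its constructor
# arguments as attributes (the embedding classes above rely on the same).
def _example_conv_frontend():
    cfg = GlycanBERTConfig(
        vocab_size=166, hidden_size=96, num_hidden_layers=2, num_attention_heads=4,
        intermediate_size=384, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
        max_position_embeddings=128, layer_norm_eps=1e-12, pad_token_id=0,
        mask_token_id=1, initializer_range=0.02,
    )
    cfg.cnn_kernel_size = 3
    embeddings = ConvGlycanBERTEmbeddings(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 32))
    branch_depths = torch.zeros(2, 32, dtype=torch.long)
    linkage_types = torch.zeros(2, 32, dtype=torch.long)
    return embeddings(input_ids, branch_depths, linkage_types).shape  # Expected: (2, 32, 96)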
def create_residue_level_mask(
seq_residue_ids: torch.Tensor, # (batch, N_seq)
struct_residue_ids: torch.Tensor # (batch, N_struct)
) -> torch.Tensor:
"""
Create residue-level attention mask for cross-attention.
Maps WURCS tokens to VQ-VAE structural tokens based on residue IDs.
A WURCS token with residue_id=0 can only attend to VQ-VAE tokens with residue_id=0.
Args:
seq_residue_ids: Residue IDs for sequence tokens (batch, N_seq)
struct_residue_ids: Residue IDs for structural tokens (batch, N_struct)
Returns:
Boolean mask (batch, N_seq, N_struct) where True = can attend
"""
# Expand dimensions for broadcasting
# seq: (batch, N_seq, 1)
# struct: (batch, 1, N_struct)
mask = seq_residue_ids.unsqueeze(2) == struct_residue_ids.unsqueeze(1)
# Shape: (batch, N_seq, N_struct)
    # Sequence positions with residue_id < 0 cannot attend to anything:
    # special tokens use residue_id = -1 and MS-derived tokens use residue_id = -2,
    # so only tokens with residue_id >= 0 participate in the alignment.
    mask &= (seq_residue_ids.unsqueeze(2) >= 0)
return mask # True = can attend, False = cannot attend
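# Worked example (illustrative only): one sample with two sequence tokens in residue 0,
# one in residue 1, and one special token (residue_id = -1), against four structural
# tokens belonging to residues 0, 0, 1, 1. Each WURCS token may attend only to structural
# tokens of the same residue; the special token attends to nothing.
def _example_residue_mask():
    seq_ids = torch.tensor([[0, 0, 1, -1]])
    struct_ids = torch.tensor([[0, 0, 1, 1]])
    mask = create_residue_level_mask(seq_ids, struct_ids)
    # mask[0] == [[True,  True,  False, False],
    #             [True,  True,  False, False],
    #             [False, False, True,  True ],
    #             [False, False, False, False]]
    return mask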
class MultimodalGlycanBERTConfig:
"""Configuration for the BERTose model."""
def __init__(
self,
# Sequence modality
seq_vocab_size: int = 166,
seq_hidden_size: int = 768,
seq_num_layers: int = 12,
seq_num_heads: int = 12,
seq_max_length: int = 512,
# MS modality
ms_vocab_size: int = 242,
ms_hidden_size: int = 384,
ms_num_layers: int = 6,
ms_num_heads: int = 6,
ms_max_length: int = 150,
# 3D structure modality
struct_vocab_size: int = 1024, # VQ-VAE codebook size
struct_hidden_size: int = 512,
struct_num_layers: int = 8,
struct_num_heads: int = 8,
struct_max_length: int = 200,
use_3d: bool = True,
# Cross-attention
use_cross_attention: bool = True,
cross_attn_num_heads: int = 8,
# Fusion
fusion_hidden_size: int = 768,
fusion_num_layers: int = 2,
# Training
hidden_dropout_prob: float = 0.1,
attention_probs_dropout_prob: float = 0.1,
layer_norm_eps: float = 1e-12,
initializer_range: float = 0.02,
# Conv front-end
use_cnn_frontend: bool = True,
cnn_kernel_size: int = 3,
# Loss weights
seq_loss_weight: float = 0.60,
ms_loss_weight: float = 0.15,
struct_loss_weight: float = 0.25,
# Token IDs
pad_token_id: int = 0,
mask_token_id: int = 1,
):
# Sequence config
self.seq_vocab_size = seq_vocab_size
self.seq_hidden_size = seq_hidden_size
self.seq_num_layers = seq_num_layers
self.seq_num_heads = seq_num_heads
self.seq_max_length = seq_max_length
# MS config
self.ms_vocab_size = ms_vocab_size
self.ms_vocab_offset = seq_vocab_size # MS tokens start at 166
self.ms_total_vocab_size = seq_vocab_size + ms_vocab_size # 408 total
self.ms_hidden_size = ms_hidden_size
self.ms_num_layers = ms_num_layers
self.ms_num_heads = ms_num_heads
self.ms_max_length = ms_max_length
# Structure config
self.struct_vocab_size = struct_vocab_size
self.struct_hidden_size = struct_hidden_size
self.struct_num_layers = struct_num_layers
self.struct_num_heads = struct_num_heads
self.struct_max_length = struct_max_length
self.use_3d = use_3d
# Cross-attention config
self.use_cross_attention = use_cross_attention
self.cross_attn_num_heads = cross_attn_num_heads
# Fusion config
self.fusion_hidden_size = fusion_hidden_size
self.fusion_num_layers = fusion_num_layers
# Training config
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
# Conv front-end
self.use_cnn_frontend = use_cnn_frontend
self.cnn_kernel_size = cnn_kernel_size
# Loss weights
self.seq_loss_weight = seq_loss_weight
self.ms_loss_weight = ms_loss_weight
self.struct_loss_weight = struct_loss_weight
        # Distance (topology) prediction loss weight; fixed here rather than exposed as a constructor argument.
        self.dist_loss_weight = 0.25
# Token IDs
self.pad_token_id = pad_token_id
self.mask_token_id = mask_token_id
def to_seq_config(self) -> GlycanBERTConfig:
"""Convert to sequence-only config."""
return GlycanBERTConfig(
vocab_size=self.seq_vocab_size,
hidden_size=self.seq_hidden_size,
num_hidden_layers=self.seq_num_layers,
num_attention_heads=self.seq_num_heads,
intermediate_size=self.seq_hidden_size * 4,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.seq_max_length,
layer_norm_eps=self.layer_norm_eps,
pad_token_id=self.pad_token_id,
mask_token_id=self.mask_token_id,
initializer_range=self.initializer_range,
)
def to_ms_config(self) -> GlycanBERTConfig:
"""Convert to MS-only config."""
return GlycanBERTConfig(
vocab_size=self.ms_total_vocab_size,
hidden_size=self.ms_hidden_size,
num_hidden_layers=self.ms_num_layers,
num_attention_heads=self.ms_num_heads,
intermediate_size=self.ms_hidden_size * 4,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.ms_max_length,
layer_norm_eps=self.layer_norm_eps,
pad_token_id=self.pad_token_id,
mask_token_id=self.mask_token_id,
initializer_range=self.initializer_range,
)
def to_struct_config(self) -> GlycanBERTConfig:
"""Convert to structure-only config."""
return GlycanBERTConfig(
vocab_size=self.struct_vocab_size,
hidden_size=self.struct_hidden_size,
num_hidden_layers=self.struct_num_layers,
num_attention_heads=self.struct_num_heads,
intermediate_size=self.struct_hidden_size * 4,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.struct_max_length,
layer_norm_eps=self.layer_norm_eps,
pad_token_id=self.pad_token_id,
mask_token_id=self.mask_token_id,
initializer_range=self.initializer_range,
)
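# Minimal sketch: how the multimodal config fans out into per-modality encoder configs.
# The values noted in the comments are the defaults defined above; nothing new is assumed.
def _example_config_fanout():
    cfg = MultimodalGlycanBERTConfig()
    seq_cfg = cfg.to_seq_config()        # WURCS vocab 166, hidden 768, 12 layers
    ms_cfg = cfg.to_ms_config()          # shared vocab 166 + 242 = 408, hidden 384, 6 layers
    struct_cfg = cfg.to_struct_config()  # VQ-VAE codebook 1024, hidden 512, 8 layers
    return seq_cfg, ms_cfg, struct_cfg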
# =============================================================================
# Improvement #1: Monosaccharide-Level Pooling
# =============================================================================
class MonosaccharidePooling(nn.Module):
"""
Pool token representations to monosaccharide level, then aggregate.
This bridges the gap between token-level BERT and monosaccharide-level CNNs/GNNs.
Uses monosaccharide_indices from the data to know where each residue starts.
"""
def __init__(self, hidden_size: int, num_attention_heads: int = 8, dropout: float = 0.1):
super().__init__()
self.hidden_size = hidden_size
# Attention pooling over monosaccharide representations
self.mono_attention = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=num_attention_heads,
dropout=dropout,
batch_first=True
)
self.mono_norm = nn.LayerNorm(hidden_size)
# Final aggregation to single glycan representation
self.glycan_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
self.glycan_attention = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=num_attention_heads,
dropout=dropout,
batch_first=True
)
self.glycan_norm = nn.LayerNorm(hidden_size)
def forward(
self,
hidden_states: torch.Tensor, # (batch, seq_len, hidden)
residue_ids: torch.Tensor, # (batch, seq_len) - which residue each token belongs to
attention_mask: torch.Tensor = None, # (batch, seq_len)
) -> torch.Tensor:
"""
Pool tokens to monosaccharide level, then to glycan level.
Returns:
Glycan representation: (batch, hidden_size)
"""
batch_size = hidden_states.size(0)
device = hidden_states.device
# Get unique residue IDs per sample (excluding -1 padding)
max_residues = 50 # Reasonable max for glycans
# Pool tokens within each residue using mean pooling
mono_reps = torch.zeros(batch_size, max_residues, self.hidden_size, device=device)
mono_mask = torch.zeros(batch_size, max_residues, dtype=torch.bool, device=device)
for b in range(batch_size):
unique_residues = torch.unique(residue_ids[b][residue_ids[b] >= 0])
for i, rid in enumerate(unique_residues):
if i >= max_residues:
break
token_mask = residue_ids[b] == rid
if attention_mask is not None:
token_mask = token_mask & (attention_mask[b] > 0)
if token_mask.sum() > 0:
mono_reps[b, i] = hidden_states[b][token_mask].mean(dim=0)
mono_mask[b, i] = True
# Apply attention over monosaccharide representations
# Convert mask for attention: True = valid, need to invert for PyTorch
key_padding_mask = ~mono_mask # True = ignore
mono_out, _ = self.mono_attention(
mono_reps, mono_reps, mono_reps,
key_padding_mask=key_padding_mask
)
mono_out = self.mono_norm(mono_reps + mono_out)
# Aggregate to single glycan representation using learned query
glycan_query = self.glycan_query.expand(batch_size, -1, -1)
glycan_out, _ = self.glycan_attention(
glycan_query, mono_out, mono_out,
key_padding_mask=key_padding_mask
)
glycan_out = self.glycan_norm(glycan_query + glycan_out)
return glycan_out.squeeze(1) # (batch, hidden)
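# Illustrative sketch of MonosaccharidePooling on dummy tensors. The 4-tokens-per-residue
# layout and the small hidden size are assumptions; the module itself assumes every sample
# has at least one token with residue_id >= 0, otherwise all attention keys would be masked.
def _example_mono_pooling():
    pool = MonosaccharidePooling(hidden_size=64, num_attention_heads=4)
    hidden = torch.randn(2, 16, 64)
    residue_ids = torch.div(torch.arange(16), 4, rounding_mode="floor").unsqueeze(0).expand(2, -1)
    attention_mask = torch.ones(2, 16)
    return pool(hidden, residue_ids, attention_mask).shape  # Expected: (2, 64)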
# =============================================================================
# Improvement #2: Residue Type Embeddings
# =============================================================================
# Common monosaccharide types vocabulary
MONOSACCHARIDE_VOCAB = {
'[PAD_MONO]': 0, '[UNK_MONO]': 1,
'Glc': 2, 'GlcNAc': 3, 'GlcA': 4, 'GlcN': 5,
'Gal': 6, 'GalNAc': 7, 'GalA': 8, 'GalN': 9,
'Man': 10, 'ManNAc': 11, 'ManA': 12, 'ManN': 13,
'Fuc': 14, 'Rha': 15, 'Xyl': 16, 'Ara': 17,
'Neu5Ac': 18, 'Neu5Gc': 19, 'Kdn': 20, 'Sia': 21,
'GalNAcA': 22, 'GlcNAcA': 23, 'IdoA': 24, 'GulA': 25,
'Rib': 26, 'Lyx': 27, 'All': 28, 'Alt': 29,
'Tal': 30, 'Ido': 31, 'Qui': 32, 'Oli': 33,
'Tyv': 34, 'Abe': 35, 'Par': 36, 'Dig': 37,
'Col': 38, 'Dha': 39, 'Kdo': 40, 'Hep': 41,
'NeuroGc': 42, 'Muramic': 43, 'LDManHep': 44, 'DDManHep': 45,
'Bac': 46, 'Pse': 47, 'Leg': 48, 'Aci': 49,
'6dTal': 50, 'Fru': 51, 'Tag': 52, 'Sor': 53,
'Psi': 54, 'Sed': 55, 'MurNAc': 56, 'MurNGc': 57,
'Api': 58, 'Erwiniose': 59, 'Yer': 60, 'Thre': 61,
# Add more as needed, up to ~70
}
class ResidueTypeEmbeddings(nn.Module):
"""
Learnable embeddings for monosaccharide types.
Instead of the model having to learn that 'a1221m' = Fucose from character patterns,
we explicitly add a Fucose embedding to all tokens belonging to that residue.
"""
def __init__(self, hidden_size: int, num_mono_types: int = 70):
super().__init__()
self.mono_embeddings = nn.Embedding(num_mono_types, hidden_size)
self.mono_vocab = MONOSACCHARIDE_VOCAB
self.hidden_size = hidden_size
def forward(
self,
token_embeddings: torch.Tensor, # (batch, seq_len, hidden)
residue_ids: torch.Tensor, # (batch, seq_len)
mono_type_ids: torch.Tensor = None, # (batch, max_residues) - monosaccharide type per residue
) -> torch.Tensor:
"""
Add residue type embeddings to token embeddings.
Args:
token_embeddings: Base token embeddings
residue_ids: Which residue each token belongs to (-1 for special tokens)
mono_type_ids: Monosaccharide type ID for each residue position
Returns:
Enhanced embeddings with residue type information
"""
if mono_type_ids is None:
return token_embeddings
        # Vectorized lookup: map each token to its residue's monosaccharide type and add
        # that type embedding. Tokens with residue_id < 0 (special tokens) or residues
        # beyond mono_type_ids are left unchanged.
        valid = (residue_ids >= 0) & (residue_ids < mono_type_ids.size(1))
        safe_rids = residue_ids.clamp(min=0, max=mono_type_ids.size(1) - 1)
        mono_ids = torch.gather(mono_type_ids, 1, safe_rids)  # (batch, seq_len)
        mono_emb = self.mono_embeddings(mono_ids) * valid.unsqueeze(-1).to(token_embeddings.dtype)
        return token_embeddings + mono_emb
@staticmethod
def get_mono_type_id(mono_name: str) -> int:
"""Convert monosaccharide name to type ID."""
return MONOSACCHARIDE_VOCAB.get(mono_name, MONOSACCHARIDE_VOCAB['[UNK_MONO]'])
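# Sketch of feeding per-residue monosaccharide types to ResidueTypeEmbeddings. The residue
# names and the 4-tokens-per-residue layout are illustrative assumptions; in practice
# mono_type_ids would come from the data pipeline.
def _example_residue_type_embeddings():
    rte = ResidueTypeEmbeddings(hidden_size=64)
    names = ["GlcNAc", "Man", "Fuc"]  # Three residues for one glycan
    mono_type_ids = torch.tensor([[ResidueTypeEmbeddings.get_mono_type_id(n) for n in names]])
    token_embeddings = torch.randn(1, 12, 64)
    residue_ids = torch.div(torch.arange(12), 4, rounding_mode="floor").unsqueeze(0)
    return rte(token_embeddings, residue_ids, mono_type_ids).shape  # Expected: (1, 12, 64)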
# =============================================================================
# Improvement #4: Relative Position Encoding for Glycan Trees
# =============================================================================
class RelativePositionBias(nn.Module):
"""
Compute relative position bias for attention based on residue IDs.
Tokens in the same residue get distance 0.
Tokens in adjacent residues get distance ±1.
This helps the model understand glycan tree structure.
"""
def __init__(self, num_heads: int, max_distance: int = 10):
super().__init__()
self.num_heads = num_heads
self.max_distance = max_distance
# Learnable bias for each relative distance (-max to +max)
num_distances = 2 * max_distance + 1
self.relative_bias = nn.Embedding(num_distances, num_heads)
def forward(self, residue_ids: torch.Tensor) -> torch.Tensor:
"""
Compute relative position bias.
Args:
residue_ids: (batch, seq_len)
Returns:
Bias to add to attention scores: (batch, num_heads, seq_len, seq_len)
"""
# Compute pairwise residue distances
# (batch, seq_len, 1) - (batch, 1, seq_len) = (batch, seq_len, seq_len)
distance = residue_ids.unsqueeze(2) - residue_ids.unsqueeze(1)
# Clamp to max distance range and shift to 0-indexed
distance_clamped = distance.clamp(-self.max_distance, self.max_distance)
distance_idx = distance_clamped + self.max_distance # Now 0 to 2*max_distance
# Look up bias: (batch, seq_len, seq_len, num_heads)
bias = self.relative_bias(distance_idx)
# Transpose to (batch, num_heads, seq_len, seq_len)
bias = bias.permute(0, 3, 1, 2)
return bias
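# Sketch: relative position bias derived from residue IDs for an 8-head attention layer.
# The residue layout is assumed; the returned bias would be added to raw attention scores
# before the softmax inside a Transformer layer.
def _example_relative_position_bias():
    rpb = RelativePositionBias(num_heads=8, max_distance=10)
    residue_ids = torch.div(torch.arange(20), 4, rounding_mode="floor").unsqueeze(0)
    return rpb(residue_ids).shape  # Expected: (1, 8, 20, 20)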
class CrossAttentionLayer(nn.Module):
"""
Cross-attention layer for sequence-structure alignment.
Allows sequence tokens to attend to structural atoms using attention masks.
"""
def __init__(self, config: MultimodalGlycanBERTConfig):
super().__init__()
self.num_heads = config.cross_attn_num_heads
self.hidden_size = config.seq_hidden_size
self.head_dim = self.hidden_size // self.num_heads
assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
# Query from sequence, Key/Value from structure (VQ-VAE tokens)
self.query = nn.Linear(config.seq_hidden_size, self.hidden_size)
self.key = nn.Linear(config.struct_hidden_size, self.hidden_size)
self.value = nn.Linear(config.struct_hidden_size, self.hidden_size)
self.output = nn.Linear(self.hidden_size, config.seq_hidden_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.layer_norm = nn.LayerNorm(config.seq_hidden_size, eps=config.layer_norm_eps)
def forward(
self,
seq_hidden: torch.Tensor, # (batch, seq_len, seq_hidden)
struct_hidden: torch.Tensor, # (batch, struct_len, struct_hidden)
attention_mask: Optional[torch.Tensor] = None, # (batch, seq_len, struct_len)
) -> torch.Tensor:
"""
Apply cross-attention from sequence to structure.
Args:
seq_hidden: Sequence hidden states
struct_hidden: Structure hidden states
attention_mask: Boolean mask (True = can attend, False = cannot attend)
Returns:
Updated sequence hidden states
"""
batch_size, seq_len, _ = seq_hidden.shape
struct_len = struct_hidden.shape[1]
# Project to Q, K, V
Q = self.query(seq_hidden) # (batch, seq_len, hidden)
K = self.key(struct_hidden) # (batch, struct_len, hidden)
V = self.value(struct_hidden) # (batch, struct_len, hidden)
# Reshape for multi-head attention
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, seq_len, head_dim)
K = K.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, struct_len, head_dim)
V = V.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, struct_len, head_dim)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (batch, heads, seq_len, struct_len)
# Apply attention mask
if attention_mask is not None:
# attention_mask: (batch, seq_len, struct_len) -> (batch, 1, seq_len, struct_len)
attention_mask = attention_mask.unsqueeze(1)
# Convert boolean mask to float: True -> 0.0, False -> -10000.0
attention_mask = (~attention_mask).float() * -10000.0
scores = scores + attention_mask
# Softmax and dropout
attn_weights = torch.softmax(scores, dim=-1) # (batch, heads, seq_len, struct_len)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
context = torch.matmul(attn_weights, V) # (batch, heads, seq_len, head_dim)
# Reshape back
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
# Output projection
output = self.output(context)
output = self.dropout(output)
# Residual connection + layer norm
output = self.layer_norm(seq_hidden + output)
return output
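# Sketch: sequence-to-structure cross-attention driven by the residue-level mask defined
# above. The default MultimodalGlycanBERTConfig is used; the 5-tokens-per-residue (WURCS)
# and 4-tokens-per-residue (VQ-VAE) layouts are assumptions for illustration.
def _example_cross_attention():
    cfg = MultimodalGlycanBERTConfig()
    layer = CrossAttentionLayer(cfg)
    seq_hidden = torch.randn(2, 10, cfg.seq_hidden_size)
    struct_hidden = torch.randn(2, 8, cfg.struct_hidden_size)
    seq_residue_ids = torch.div(torch.arange(10), 5, rounding_mode="floor").unsqueeze(0).expand(2, -1)
    struct_residue_ids = torch.div(torch.arange(8), 4, rounding_mode="floor").unsqueeze(0).expand(2, -1)
    mask = create_residue_level_mask(seq_residue_ids, struct_residue_ids)
    return layer(seq_hidden, struct_hidden, mask).shape  # Expected: (2, 10, 768)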
class MultimodalGlycanBERT(nn.Module):
"""
BERTose model for glycan representation learning.
Architecture:
1. Separate encoders for each modality (sequence, MS, 3D structure)
2. Cross-attention for sequence-structure alignment
3. Modality-specific MLM heads
4. Fusion layer for combined representation
"""
def __init__(self, config: MultimodalGlycanBERTConfig):
super().__init__()
self.config = config
# ===== Sequence Encoder =====
seq_config = config.to_seq_config()
seq_config.cnn_kernel_size = config.cnn_kernel_size
if config.use_cnn_frontend:
print(f"Enabled convolutional front-end (kernel={config.cnn_kernel_size})")
self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
else:
self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
self.seq_layers = nn.ModuleList([GlycanBERTLayer(seq_config) for _ in range(seq_config.num_hidden_layers)])
self.seq_mlm_head = nn.Linear(seq_config.hidden_size, seq_config.vocab_size)
# ===== MS Encoder =====
ms_config = config.to_ms_config()
self.ms_embeddings = GlycanBERTEmbeddings(ms_config)
self.ms_layers = nn.ModuleList([GlycanBERTLayer(ms_config) for _ in range(ms_config.num_hidden_layers)])
self.ms_mlm_head = nn.Linear(ms_config.hidden_size, ms_config.vocab_size)
# ===== Structure Encoder (VQ-VAE tokens) =====
if config.use_3d:
struct_config = config.to_struct_config()
self.struct_embeddings = GlycanBERTEmbeddings(struct_config)
self.struct_layers = nn.ModuleList([GlycanBERTLayer(struct_config) for _ in range(struct_config.num_hidden_layers)])
self.struct_mlm_head = nn.Linear(struct_config.hidden_size, struct_config.vocab_size)
# Cross-attention layer (sequence → VQ-VAE structural tokens)
if config.use_cross_attention:
self.cross_attention = CrossAttentionLayer(config)
# ===== Projection layers (align hidden sizes) =====
if config.ms_hidden_size != config.seq_hidden_size:
self.ms_projection = nn.Linear(config.ms_hidden_size, config.seq_hidden_size)
else:
self.ms_projection = nn.Identity()
if config.use_3d and config.struct_hidden_size != config.seq_hidden_size:
self.struct_projection = nn.Linear(config.struct_hidden_size, config.seq_hidden_size)
else:
self.struct_projection = nn.Identity()
# ===== Fusion Layer =====
# Concatenate seq + ms + struct
fusion_input_size = config.seq_hidden_size * (3 if config.use_3d else 2)
self.fusion_layer = nn.Sequential(
nn.Linear(fusion_input_size, config.fusion_hidden_size),
nn.LayerNorm(config.fusion_hidden_size, eps=config.layer_norm_eps),
nn.GELU(),
nn.Dropout(config.hidden_dropout_prob),
nn.Linear(config.fusion_hidden_size, config.fusion_hidden_size),
)
# ===== Distance Prediction Head (Topology) =====
        # Project down to 128 dimensions first to reduce memory use:
        # (batch, seq_len, seq_len, 768) -> (batch, seq_len, seq_len, 128), a ~6x smaller pairwise tensor.
self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
self.distance_head = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(
self,
seq_token_ids: torch.Tensor,
seq_attention_mask: torch.Tensor,
seq_residue_ids: torch.Tensor,
seq_branch_depths: Optional[torch.Tensor] = None,
seq_linkage_types: Optional[torch.Tensor] = None,
ms_token_ids: torch.Tensor = None,
ms_attention_mask: torch.Tensor = None,
has_ms: torch.Tensor = None,
struct_token_ids: Optional[torch.Tensor] = None,
struct_attention_mask: Optional[torch.Tensor] = None,
struct_residue_ids: Optional[torch.Tensor] = None,
has_3d: Optional[torch.Tensor] = None,
seq_labels: Optional[torch.Tensor] = None,
ms_labels: Optional[torch.Tensor] = None,
struct_labels: Optional[torch.Tensor] = None,
dist_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Dict[str, torch.Tensor]:
"""
Forward pass for BERTose.
Args:
seq_token_ids: (batch_size, seq_len) - Sequence token IDs
seq_attention_mask: (batch_size, seq_len) - Sequence attention mask
seq_residue_ids: (batch_size, seq_len) - Sequence token residue IDs
ms_token_ids: (batch_size, ms_len) - MS token IDs
ms_attention_mask: (batch_size, ms_len) - MS attention mask
has_ms: (batch_size,) - Boolean mask for samples with MS data
struct_token_ids: (batch_size, struct_len) - Structure VQ-VAE token IDs (optional)
struct_attention_mask: (batch_size, struct_len) - Structure attention mask (optional)
struct_residue_ids: (batch_size, struct_len) - Structure token residue IDs (optional)
has_3d: (batch_size,) - Boolean mask for samples with 3D data (optional)
seq_labels: (batch_size, seq_len) - Masked sequence labels (optional)
ms_labels: (batch_size, ms_len) - Masked MS labels (optional)
struct_labels: (batch_size, struct_len) - Masked structure labels (optional)
return_dict: Whether to return dict or tuple
Returns:
Dictionary containing logits, hidden states, losses, etc.
"""
batch_size = seq_token_ids.shape[0]
device = seq_token_ids.device
# ===== Sequence Encoder =====
# Pass branch_depths and linkage_types to embeddings for tree-aware encoding
seq_hidden = self.seq_embeddings(seq_token_ids, seq_branch_depths, seq_linkage_types)
for layer in self.seq_layers:
seq_hidden = layer(seq_hidden, seq_attention_mask)
seq_pooled = seq_hidden[:, 0, :] # [CLS] token
seq_logits = self.seq_mlm_head(seq_hidden)
# ===== Distance Predictions (Topology) =====
# Compute pairwise distance predictions
# MEMORY OPTIMIZATION: Project to 128-dim first
seq_hidden_small = self.dist_proj(seq_hidden) # (batch, seq_len, 128)
# Expand for pairwise: (batch, seq_len, 1, 128) - (batch, 1, seq_len, 128)
h_i = seq_hidden_small.unsqueeze(2)
h_j = seq_hidden_small.unsqueeze(1)
h_diff = torch.abs(h_i - h_j) # (batch, seq_len, seq_len, 128) - Much smaller!
dist_predictions = self.distance_head(h_diff) # (batch, seq_len, seq_len, 1)
# ===== MS Encoder =====
ms_hidden = None
ms_pooled = None
ms_logits = None
if ms_token_ids is not None:
ms_hidden = self.ms_embeddings(ms_token_ids)
for layer in self.ms_layers:
ms_hidden = layer(ms_hidden, ms_attention_mask)
ms_pooled = ms_hidden[:, 0, :] # [CLS] token
ms_logits = self.ms_mlm_head(ms_hidden)
# Zero out MS representations for samples without MS data
if has_ms is not None:
has_ms_expanded = has_ms.unsqueeze(1).float() # (batch, 1)
ms_pooled = ms_pooled * has_ms_expanded
# ===== Structure Encoder =====
struct_pooled = None
struct_logits = None
struct_hidden = None
if self.config.use_3d and struct_token_ids is not None:
struct_hidden = self.struct_embeddings(struct_token_ids)
for layer in self.struct_layers:
struct_hidden = layer(struct_hidden, struct_attention_mask)
struct_pooled = struct_hidden[:, 0, :] # [CLS] token
struct_logits = self.struct_mlm_head(struct_hidden)
# Zero out structure representations for samples without 3D data
if has_3d is not None:
has_3d_expanded = has_3d.unsqueeze(1).float() # (batch, 1)
struct_pooled = struct_pooled * has_3d_expanded
# ===== Cross-Attention (Sequence → VQ-VAE Structural Tokens) =====
# Use residue-level alignment between WURCS tokens and VQ-VAE tokens
if self.config.use_cross_attention and struct_residue_ids is not None:
# Create residue-level mask
# WURCS token with residue_id=0 → VQ-VAE tokens with residue_id=0
residue_mask = create_residue_level_mask(
seq_residue_ids=seq_residue_ids,
struct_residue_ids=struct_residue_ids,
) # (batch, N_seq, N_struct)
# Apply cross-attention: sequence tokens attend to VQ-VAE tokens
seq_hidden = self.cross_attention(
seq_hidden=seq_hidden,
struct_hidden=struct_hidden, # VQ-VAE token features
attention_mask=residue_mask, # Residue-based mask
)
# Update seq_pooled after cross-attention
seq_pooled = seq_hidden[:, 0, :]
        # ===== Fusion =====
        # Project each modality to the sequence hidden size. If a modality is absent for the
        # whole batch, substitute a zero vector so the fusion layer always sees its fixed-size input.
        if ms_pooled is not None:
            ms_pooled_projected = self.ms_projection(ms_pooled)
        else:
            ms_pooled_projected = torch.zeros(batch_size, self.config.seq_hidden_size, device=device, dtype=seq_pooled.dtype)
        if self.config.use_3d:
            if struct_pooled is not None:
                struct_pooled_projected = self.struct_projection(struct_pooled)
            else:
                struct_pooled_projected = torch.zeros(batch_size, self.config.seq_hidden_size, device=device, dtype=seq_pooled.dtype)
            combined = torch.cat([seq_pooled, ms_pooled_projected, struct_pooled_projected], dim=-1)
        else:
            combined = torch.cat([seq_pooled, ms_pooled_projected], dim=-1)
fused_repr = self.fusion_layer(combined)
# ===== Compute Losses =====
total_loss = None
seq_loss = None
ms_loss = None
struct_loss = None
dist_loss = None
if seq_labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
seq_loss = loss_fct(
seq_logits.view(-1, self.config.seq_vocab_size),
seq_labels.view(-1)
)
        if ms_labels is not None and ms_logits is not None:
            ms_labels_masked = ms_labels.clone()
            if has_ms is not None:
                ms_labels_masked[~has_ms] = -100
# Only compute loss if there are valid labels (not all -100)
if (ms_labels_masked != -100).any():
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
ms_loss = loss_fct(
ms_logits.view(-1, self.config.ms_total_vocab_size),
ms_labels_masked.view(-1)
)
else:
ms_loss = torch.tensor(0.0, device=seq_token_ids.device)
if self.config.use_3d and struct_labels is not None and struct_logits is not None:
struct_labels_masked = struct_labels.clone()
if has_3d is not None:
struct_labels_masked[~has_3d] = -100
# Only compute loss if there are valid labels (not all -100)
if (struct_labels_masked != -100).any():
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
struct_loss = loss_fct(
struct_logits.view(-1, self.config.struct_vocab_size),
struct_labels_masked.view(-1)
)
else:
struct_loss = torch.tensor(0.0, device=seq_token_ids.device)
# ===== Distance Loss (Topology) =====
if dist_labels is not None:
# dist_predictions: (Batch, Seq, Seq, 1) -> (Batch, Seq, Seq)
preds = dist_predictions.squeeze(-1)
# Create mask for valid distance pairs (label != -1)
# Also respect attention mask to avoid padding
valid_mask = (dist_labels != -1) & (seq_attention_mask.unsqueeze(1) * seq_attention_mask.unsqueeze(2) == 1)
# DEBUG: Print once
if not hasattr(self, '_dist_debug_printed'):
print(f"[DIST DEBUG] dist_labels shape: {dist_labels.shape}, valid_mask.sum: {valid_mask.sum().item()}")
self._dist_debug_printed = True
if valid_mask.sum() > 0:
# MSE loss on valid positions only
loss_fct = nn.MSELoss()
dist_loss = loss_fct(preds[valid_mask], dist_labels[valid_mask].float())
else:
dist_loss = torch.tensor(0.0, device=seq_token_ids.device)
else:
# DEBUG: dist_labels is None
if not hasattr(self, '_dist_none_printed'):
print("[DIST DEBUG] dist_labels is None!")
self._dist_none_printed = True
# Weighted combination
losses = []
if seq_loss is not None:
losses.append(self.config.seq_loss_weight * seq_loss)
if ms_loss is not None:
losses.append(self.config.ms_loss_weight * ms_loss)
if struct_loss is not None:
losses.append(self.config.struct_loss_weight * struct_loss)
if dist_loss is not None:
losses.append(self.config.dist_loss_weight * dist_loss)
if losses:
total_loss = sum(losses)
if return_dict:
return {
'loss': total_loss,
'seq_loss': seq_loss,
'ms_loss': ms_loss,
'struct_loss': struct_loss,
'dist_loss': dist_loss,
'seq_logits': seq_logits,
'ms_logits': ms_logits,
'struct_logits': struct_logits,
'dist_predictions': dist_predictions,
'seq_hidden': seq_hidden,
'ms_hidden': ms_hidden,
'struct_hidden': struct_hidden,
'seq_pooled': seq_pooled,
'ms_pooled': ms_pooled,
'struct_pooled': struct_pooled,
'fused_repr': fused_repr,
}
else:
return (total_loss, seq_logits, ms_logits, struct_logits, fused_repr)
def get_multimodal_representation(
self,
seq_token_ids: torch.Tensor,
seq_attention_mask: torch.Tensor,
seq_residue_ids: torch.Tensor,
ms_token_ids: torch.Tensor,
ms_attention_mask: torch.Tensor,
has_ms: torch.Tensor,
struct_token_ids: Optional[torch.Tensor] = None,
struct_attention_mask: Optional[torch.Tensor] = None,
struct_residue_ids: Optional[torch.Tensor] = None,
has_3d: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Get fused multimodal representation (for inference)."""
outputs = self.forward(
seq_token_ids=seq_token_ids,
seq_attention_mask=seq_attention_mask,
seq_residue_ids=seq_residue_ids,
ms_token_ids=ms_token_ids,
ms_attention_mask=ms_attention_mask,
has_ms=has_ms,
struct_token_ids=struct_token_ids,
struct_attention_mask=struct_attention_mask,
struct_residue_ids=struct_residue_ids,
has_3d=has_3d,
return_dict=True,
)
return outputs['fused_repr']
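# Sketch: a deliberately tiny configuration for CPU smoke tests. The sizes below are
# arbitrary small assumptions, not recommended training settings; the full-size defaults
# are exercised in the __main__ block that follows.
def _tiny_bertose_model() -> "MultimodalGlycanBERT":
    cfg = MultimodalGlycanBERTConfig(
        seq_hidden_size=96, seq_num_layers=2, seq_num_heads=4,
        ms_hidden_size=48, ms_num_layers=2, ms_num_heads=4,
        struct_hidden_size=64, struct_num_layers=2, struct_num_heads=4,
        fusion_hidden_size=96,
    )
    return MultimodalGlycanBERT(cfg)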
if __name__ == "__main__":
# Test the model
print("="*80)
print("Testing BERTose model")
print("="*80)
# Create config
config = MultimodalGlycanBERTConfig(
seq_vocab_size=166,
seq_hidden_size=768,
seq_num_layers=12,
seq_num_heads=12,
ms_vocab_size=242,
ms_hidden_size=384,
ms_num_layers=6,
ms_num_heads=6,
struct_vocab_size=1024,
struct_hidden_size=512,
struct_num_layers=8,
struct_num_heads=8,
use_3d=True,
use_cross_attention=True,
seq_loss_weight=0.60,
ms_loss_weight=0.15,
struct_loss_weight=0.25,
)
print(f"\nConfig:")
print(f" Sequence vocab: {config.seq_vocab_size}")
print(f" MS vocab: {config.ms_vocab_size}")
print(f" Structure vocab: {config.struct_vocab_size}")
print(f" Loss weights: seq={config.seq_loss_weight}, ms={config.ms_loss_weight}, struct={config.struct_loss_weight}")
# Create model
model = MultimodalGlycanBERT(config)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Parameters:")
print(f" Total: {total_params:,}")
print(f" Trainable: {trainable_params:,}")
# Test forward pass
print(f"\n{'='*80}")
print("Testing Forward Pass (with Conv front-end)")
print("="*80)
batch_size = 4
seq_len = 128
ms_len = 50
struct_len = 40
# Create dummy inputs
seq_token_ids = torch.randint(0, config.seq_vocab_size, (batch_size, seq_len))
seq_attention_mask = torch.ones(batch_size, seq_len)
# Approximate: ~5 tokens per residue
seq_residue_ids = torch.div(
torch.arange(seq_len), 5, rounding_mode="floor"
).unsqueeze(0).expand(batch_size, -1)
ms_token_ids = torch.randint(config.ms_vocab_offset, config.ms_total_vocab_size, (batch_size, ms_len))
ms_attention_mask = torch.ones(batch_size, ms_len)
struct_token_ids = torch.randint(0, config.struct_vocab_size, (batch_size, struct_len))
struct_attention_mask = torch.ones(batch_size, struct_len)
# Approximate: 4 tokens per residue for VQ-VAE tokens
struct_residue_ids = torch.div(
torch.arange(struct_len), 4, rounding_mode="floor"
).unsqueeze(0).expand(batch_size, -1)
has_ms = torch.tensor([True, True, False, True])
has_3d = torch.tensor([True, False, True, True])
    # Create MLM labels: mask ~15% of positions at random, keep the original token as the
    # label there, and set everything else to -100 so only masked positions contribute to the loss
    def make_mlm_inputs(token_ids):
        labels = token_ids.clone()
        masked = torch.rand(token_ids.shape) < 0.15
        labels[~masked] = -100
        inputs = token_ids.clone()
        inputs[masked] = config.mask_token_id
        return inputs, labels
    seq_token_ids, seq_labels = make_mlm_inputs(seq_token_ids)
    ms_token_ids, ms_labels = make_mlm_inputs(ms_token_ids)
    struct_token_ids, struct_labels = make_mlm_inputs(struct_token_ids)
# Forward pass
outputs = model(
seq_token_ids=seq_token_ids,
seq_attention_mask=seq_attention_mask,
seq_residue_ids=seq_residue_ids,
ms_token_ids=ms_token_ids,
ms_attention_mask=ms_attention_mask,
has_ms=has_ms,
struct_token_ids=struct_token_ids,
struct_attention_mask=struct_attention_mask,
struct_residue_ids=struct_residue_ids,
has_3d=has_3d,
seq_labels=seq_labels,
ms_labels=ms_labels,
struct_labels=struct_labels,
)
print(f"\nOutput shapes:")
print(f" seq_logits: {outputs['seq_logits'].shape}")
print(f" ms_logits: {outputs['ms_logits'].shape}")
print(f" struct_logits: {outputs['struct_logits'].shape}")
print(f" fused_repr: {outputs['fused_repr'].shape}")
print(f"\nLosses:")
print(f" Total loss: {outputs['loss'].item():.4f}")
print(f" Sequence loss: {outputs['seq_loss'].item():.4f}")
print(f" MS loss: {outputs['ms_loss'].item():.4f}")
print(f" Structure loss: {outputs['struct_loss'].item():.4f}")
print(f"\n{'='*80}")
print("Model Test Complete!")
print("="*80)