"""
Latent Attention Pooling implementation for LLM2Vec4CXR.

Vendored to make the model self-contained (no external llm2vec dependency required).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class LatentAttentionPooling(nn.Module):
    """
    Latent attention pooling layer that uses a trainable latent dictionary
    to aggregate token embeddings into a fixed-size representation.
    """

    def __init__(self, d_model, num_latents=512, num_heads=8):
        """
        Args:
            d_model: Hidden size of the model (e.g., 2048 for Llama-3.2-1B)
            num_latents: Number of learnable latent vectors (default: 512)
            num_heads: Number of attention heads (default: 8)
        """
        super().__init__()
        self.num_latents = num_latents
        self.d_model = d_model

        # Trainable latent dictionary shared across all inputs
        self.latents = nn.Parameter(torch.randn(num_latents, d_model))

        # Cross-attention block: token embeddings attend over the latent dictionary
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True
        )

        # Position-wise MLP applied to the attention output before pooling
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model)
        )

    def forward(self, hidden_states, attention_mask=None):
        """
        Apply latent attention pooling to hidden states.

        Args:
            hidden_states: Token embeddings of shape (batch_size, seq_len, d_model)
            attention_mask: Optional mask of shape (batch_size, seq_len)

        Returns:
            Pooled embeddings of shape (batch_size, d_model)
        """
        batch_size, seq_len, d_model = hidden_states.shape
        device = hidden_states.device

        # Keep the pooling layer on the same device as the incoming hidden states
        if next(self.parameters()).device != device:
            self.to(device)

        # Broadcast the shared latent dictionary across the batch
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        # Cross-attention: each token embedding queries the latent dictionary
        attn_output, _ = self.multihead_attn(
            query=hidden_states,
            key=latents,
            value=latents
        )

        mlp_output = self.mlp(attn_output)

        if attention_mask is not None:
            # Masked mean pooling: average only over non-padding tokens
            mask_expanded = attention_mask.unsqueeze(-1).expand(mlp_output.size()).float()
            sum_embeddings = torch.sum(mlp_output * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_embeddings / sum_mask
        else:
            # No mask provided: plain mean over the sequence dimension
            pooled = mlp_output.mean(dim=1)

        return pooled
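

# --- Illustrative usage sketch (not part of the vendored module) ---
# A minimal smoke test showing the expected tensor shapes. The sizes below
# (batch of 2, sequence length 16, d_model 2048) are arbitrary assumptions
# chosen for demonstration, not values prescribed by LLM2Vec4CXR.
if __name__ == "__main__":
    pooler = LatentAttentionPooling(d_model=2048, num_latents=512, num_heads=8)

    batch_size, seq_len = 2, 16
    hidden_states = torch.randn(batch_size, seq_len, 2048)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[1, 10:] = 0  # simulate padding on the second sequence

    with torch.no_grad():
        pooled = pooler(hidden_states, attention_mask)
    print(pooled.shape)  # expected: torch.Size([2, 2048])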