# SimpleAES/models/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import LongformerModel, LongformerPreTrainedModel

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.energy = nn.Linear(hidden_dim, 1)

        # Initialize weights
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.energy.weight)
        self.query.bias.data.zero_()
        self.energy.bias.data.zero_()

    def forward(self, hidden_states, attention_mask=None):
        # Compute attention scores
        transformed = torch.tanh(self.query(hidden_states))  # (batch_size, seq_len, hidden_dim)
        scores = self.energy(transformed).squeeze(-1)  # (batch_size, seq_len)

        # Apply attention mask if provided
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        # Compute attention weights
        weights = F.softmax(scores, dim=-1)  # (batch_size, seq_len)

        # Apply attention pooling (weighted sum over the sequence)
        pooled = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)  # (batch_size, hidden_dim)
        return pooled
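
# Minimal smoke-test sketch for AttentionPooling. The tensor sizes below are
# arbitrary illustration values; this helper is not used by the model itself.
def _demo_attention_pooling():
    batch_size, seq_len, hidden_dim = 2, 6, 16
    pooler = AttentionPooling(hidden_dim)
    hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

    # Mask out the last two positions of the second sequence (0 = padding)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask[1, -2:] = 0

    pooled = pooler(hidden_states, attention_mask=attention_mask)
    assert pooled.shape == (batch_size, hidden_dim)
    return pooled
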

class CustomLongformerForSequenceClassification(LongformerPreTrainedModel):
    """Longformer model with attention pooling for sequence classification.

    Uses attention pooling over the last four hidden layers instead of CLS token pooling.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Longformer backbone
        self.longformer = LongformerModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Attention pooling for each layer
        self.attention_poolers = nn.ModuleList([
            AttentionPooling(config.hidden_size) for _ in range(4)
        ])

        # Final classifier
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)

        # Initialize weights
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # Get the last four hidden layers
        last_four_layers = outputs.hidden_states[-4:]

        # Apply attention pooling to each layer
        pooled = []
        for layer, pooler in zip(last_four_layers, self.attention_poolers):
            pooled.append(pooler(layer, attention_mask=attention_mask))

        # Concatenate pooled representations
        concatenated = torch.cat(pooled, dim=1)
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            if hasattr(self, 'loss_fct'):
                # Use a loss module attached to the model after construction, if any
                loss = self.loss_fct(logits, labels)
            else:
                # Default to MSE regression; flatten so (batch, 1) logits match (batch,) labels
                loss = F.mse_loss(logits.view(-1), labels.float().view(-1))

        return {'loss': loss, 'logits': logits}
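

if __name__ == "__main__":
    # Minimal smoke test with a deliberately tiny, randomly initialized config.
    # Every hyperparameter value below is an illustrative placeholder, not the
    # configuration of any released SimpleAES checkpoint.
    from transformers import LongformerConfig

    config = LongformerConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=64,
        attention_window=8,  # must be even; inputs are padded to a multiple of it
        num_labels=1,        # single regression output scored with the MSE fallback
    )
    model = CustomLongformerForSequenceClassification(config)
    model.eval()

    # A custom loss module could be attached instead of the MSE fallback,
    # e.g. model.loss_fct = nn.SmoothL1Loss() (illustrative choice).
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([1.0, 3.0])

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    print("logits:", outputs["logits"].shape, "loss:", outputs["loss"].item())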