import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import LongformerModel, LongformerPreTrainedModel


class AttentionPooling(nn.Module):
    """Learned attention pooling over the sequence dimension."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.energy = nn.Linear(hidden_dim, 1)

        # Xavier initialisation for the projection weights; biases start at zero.
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.energy.weight)
        self.query.bias.data.zero_()
        self.energy.bias.data.zero_()

    def forward(self, hidden_states, attention_mask=None):
        # Token-level attention scores: (batch, seq_len, hidden_dim) -> (batch, seq_len).
        transformed = torch.tanh(self.query(hidden_states))
        scores = self.energy(transformed).squeeze(-1)

        # Exclude padding positions from the softmax.
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)

        # Weighted sum over the sequence dimension: (batch, hidden_dim).
        pooled = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
        return pooled


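# Shape sketch for AttentionPooling (illustrative, not from the original code):
# with hidden_states of shape (batch, seq_len, hidden_dim) and attention_mask of
# shape (batch, seq_len) holding 1 for real tokens and 0 for padding, scores and
# weights have shape (batch, seq_len); masked positions receive weight 0 after
# the softmax, and the pooled output has shape (batch, hidden_dim).

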
class CustomLongformerForSequenceClassification(LongformerPreTrainedModel):
    """Longformer model with attention pooling for sequence classification.

    Instead of pooling the CLS token, each of the last four hidden layers is
    attention-pooled separately and the results are concatenated before the
    classification head.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.longformer = LongformerModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # One attention pooler per hidden layer used (the last four).
        self.attention_poolers = nn.ModuleList([
            AttentionPooling(config.hidden_size) for _ in range(4)
        ])

        # The classifier sees the concatenation of the four pooled vectors.
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # hidden_states holds the embedding output followed by one entry per layer.
        last_four_layers = outputs.hidden_states[-4:]

        # Pool each of the last four layers with its own attention pooler.
        pooled = []
        for layer, pooler in zip(last_four_layers, self.attention_poolers):
            pooled.append(pooler(layer, attention_mask=attention_mask))

        concatenated = torch.cat(pooled, dim=1)
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        loss = None
        if labels is not None:
            if hasattr(self, 'loss_fct'):
                # A user-supplied criterion (e.g. model.loss_fct = nn.CrossEntropyLoss())
                # takes precedence.
                loss = self.loss_fct(logits, labels)
            elif self.num_labels == 1:
                # Regression: a single output, mean squared error.
                loss = F.mse_loss(logits.view(-1), labels.float().view(-1))
            else:
                # Classification: cross-entropy over the class logits.
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        return {'loss': loss, 'logits': logits}
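

# Minimal smoke test (illustrative sketch, not part of the original module). It
# builds a tiny randomly initialised LongformerConfig purely to check tensor
# shapes and the loss path; every hyperparameter below is an assumption chosen
# for speed, not a value from any released checkpoint.
if __name__ == "__main__":
    from transformers import LongformerConfig

    config = LongformerConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=4,        # at least four layers so four hidden states exist
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=64,
        attention_window=8,         # must be even; inputs are padded to a multiple of it
        num_labels=2,
    )
    model = CustomLongformerForSequenceClassification(config)
    model.eval()

    batch_size, seq_len = 2, 16     # 16 is a multiple of the attention window
    input_ids = torch.randint(5, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    labels = torch.tensor([0, 1])

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    print("logits:", out['logits'].shape)   # expected: torch.Size([2, 2])
    print("loss:", out['loss'].item())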