import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class AttentionPool(nn.Module):
    """Attention-based pooling layer."""

    def __init__(self, hidden_size):
        super(AttentionPool, self).__init__()
        # Learnable projection that scores each token's hidden state
        self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
        nn.init.xavier_uniform_(self.attention_weights)

    def forward(self, hidden_states):
        # (batch, seq_len, hidden_size) -> (batch, seq_len, 1) raw scores
        attention_scores = torch.matmul(hidden_states, self.attention_weights)
        # Normalize scores over the sequence dimension
        attention_scores = torch.softmax(attention_scores, dim=1)
        # Weighted sum over tokens -> (batch, hidden_size)
        pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
        return pooled_output


class GeneformerMultiTask(nn.Module):
    def __init__(
        self,
        pretrained_path,
        num_labels_list,
        dropout_rate=0.1,
        use_task_weights=False,
        task_weights=None,
        max_layers_to_freeze=0,
        use_attention_pooling=False,
    ):
        super(GeneformerMultiTask, self).__init__()
        self.config = BertConfig.from_pretrained(pretrained_path)
        # BERT backbone built from the pretrained checkpoint's configuration
        self.bert = BertModel(self.config)
        self.num_labels_list = num_labels_list
        self.use_task_weights = use_task_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.use_attention_pooling = use_attention_pooling

        if use_task_weights and (
            task_weights is None or len(task_weights) != len(num_labels_list)
        ):
            raise ValueError(
                "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
            )
        self.task_weights = (
            task_weights if use_task_weights else [1.0] * len(num_labels_list)
        )

        # Freeze the first `max_layers_to_freeze` encoder layers
        for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
            for param in layer.parameters():
                param.requires_grad = False

        self.attention_pool = (
            AttentionPool(self.config.hidden_size) if use_attention_pooling else None
        )

        # One linear classification head per task
        self.classification_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, num_labels)
                for num_labels in num_labels_list
            ]
        )
        for head in self.classification_heads:
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        try:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            raise RuntimeError(f"Error during BERT forward pass: {e}") from e

        sequence_output = outputs.last_hidden_state

        try:
            # Attention pooling over all tokens, or the [CLS] token embedding
            pooled_output = (
                self.attention_pool(sequence_output)
                if self.use_attention_pooling
                else sequence_output[:, 0, :]
            )
            pooled_output = self.dropout(pooled_output)
        except Exception as e:
            raise RuntimeError(f"Error during pooling and dropout: {e}") from e

        total_loss = 0
        logits = []
        losses = []

        for task_id, (head, num_labels) in enumerate(
            zip(self.classification_heads, self.num_labels_list)
        ):
            try:
                task_logits = head(pooled_output)
            except Exception as e:
                raise RuntimeError(
                    f"Error during forward pass of classification head {task_id}: {e}"
                ) from e

            logits.append(task_logits)

            if labels is not None:
                try:
                    loss_fct = nn.CrossEntropyLoss()
                    task_loss = loss_fct(
                        task_logits.view(-1, num_labels), labels[task_id].view(-1)
                    )
                    if self.use_task_weights:
                        task_loss *= self.task_weights[task_id]
                    total_loss += task_loss
                    losses.append(task_loss.item())
                except Exception as e:
                    raise RuntimeError(
                        f"Error during loss computation for task {task_id}: {e}"
                    ) from e

        # Parenthesize the tuple so the conditional applies to the whole return
        # value, not just `losses`
        return (total_loss, logits, losses) if labels is not None else logits
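

# Minimal usage sketch (smoke test), assuming a local Geneformer-style BERT
# checkpoint directory at "path/to/pretrained" (hypothetical; replace with a
# real checkpoint). Random token ids are used only to exercise the forward
# pass, pooling, and per-task loss plumbing.
if __name__ == "__main__":
    pretrained_path = "path/to/pretrained"  # hypothetical checkpoint directory
    num_labels_list = [3, 5]  # two illustrative classification tasks

    model = GeneformerMultiTask(
        pretrained_path,
        num_labels_list,
        use_attention_pooling=True,
    )

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    labels = [torch.randint(0, n, (batch_size,)) for n in num_labels_list]

    total_loss, logits, losses = model(input_ids, attention_mask, labels)
    print("total loss:", float(total_loss))
    print("logit shapes:", [tuple(t.shape) for t in logits])
    print("per-task losses:", losses)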