# models/parallel_bert_deberta.py
import torch
import torch.nn as nn
from transformers import BertModel, DebertaModel
from config import DROPOUT_RATE, BERT_MODEL_NAME, DEBERTA_MODEL_NAME  # Dropout rate and backbone model names
class Attention(nn.Module):
"""
Simple Attention layer to compute a context vector from a sequence of hidden states.
It learns a single weight for each hidden state in the sequence, then uses softmax
to normalize these weights and compute a weighted sum of the hidden states.
"""
def __init__(self, hidden_size):
"""
Initializes the Attention layer.
Args:
hidden_size (int): The dimensionality of the input hidden states.
"""
super(Attention, self).__init__()
# A linear layer to project the hidden state to a single scalar (attention score)
self.attn = nn.Linear(hidden_size, 1)
def forward(self, encoder_output):
"""
Performs the forward pass of the attention mechanism.
Args:
encoder_output (torch.Tensor): Tensor of hidden states from an encoder.
Shape: (batch_size, sequence_length, hidden_size)
Returns:
torch.Tensor: The context vector, a weighted sum of the hidden states.
Shape: (batch_size, hidden_size)
"""
        # Compute a scalar score per token and normalize the scores across the
        # sequence with softmax.
        # self.attn(encoder_output) -> (batch_size, sequence_length, 1)
        # .squeeze(-1) removes the last dimension, giving (batch_size, sequence_length)
        attn_weights = torch.softmax(self.attn(encoder_output).squeeze(-1), dim=1)
# Compute the context vector as a weighted sum of encoder_output.
# attn_weights.unsqueeze(-1) adds a dimension for broadcasting: (batch_size, sequence_length, 1)
# This allows element-wise multiplication with encoder_output.
# torch.sum(..., dim=1) sums along the sequence_length dimension.
context_vector = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)
return context_vector
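# Illustrative usage sketch for the Attention layer (not executed by the model;
# the hidden size of 768 is an assumption matching bert-base-style encoders):
#
#   attn = Attention(hidden_size=768)
#   hidden_states = torch.randn(2, 128, 768)   # (batch_size, seq_len, hidden_size)
#   context = attn(hidden_states)              # -> shape (2, 768)
#
# Because the weights are softmax-normalized over the sequence, the context
# vector is a convex combination of the token hidden states.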
class ParallelMultiOutputModel(nn.Module):
"""
Hybrid model that leverages both BERT and DeBERTa in parallel.
It extracts features from both models, applies an attention mechanism to their outputs,
projects these attended features to a common dimension, concatenates them, and then
uses this combined representation for multi-output classification.
"""
    # Tokenizer name is statically set to BERT's for this combined model: the same
    # token IDs are fed to both backbones, so BERT's tokenizer is treated as the
    # shared/primary one for the combined input.
tokenizer_name = BERT_MODEL_NAME
def __init__(self, num_labels):
"""
Initializes the ParallelMultiOutputModel.
Args:
num_labels (list): A list where each element is the number of classes
for a corresponding label column.
"""
super(ParallelMultiOutputModel, self).__init__()
# Load pre-trained BERT and DeBERTa models
self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
self.deberta = DebertaModel.from_pretrained(DEBERTA_MODEL_NAME)
# Initialize attention layers for each backbone model
self.attn_bert = Attention(self.bert.config.hidden_size)
self.attn_deberta = Attention(self.deberta.config.hidden_size)
# Projection layers to reduce dimensionality of the context vectors
# before concatenation. This helps manage the combined feature size.
self.proj_bert = nn.Linear(self.bert.config.hidden_size, 256)
self.proj_deberta = nn.Linear(self.deberta.config.hidden_size, 256)
self.dropout = nn.Dropout(DROPOUT_RATE) # Dropout layer for regularization
# Define classification heads. The input feature size is the sum of
# the projected sizes from BERT and DeBERTa (256 + 256 = 512).
self.classifiers = nn.ModuleList([
nn.Linear(512, n_classes) for n_classes in num_labels
])
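        # Dimension flow through the model (assuming base-size backbones with a
        # hidden size of 768; actual sizes depend on BERT_MODEL_NAME and
        # DEBERTA_MODEL_NAME in config):
        #   encoder output:    (batch, seq_len, 768) per backbone
        #   attention pooling: (batch, 768) per backbone
        #   projection:        (batch, 256) per backbone
        #   concatenation:     (batch, 512) -> each classification head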
def forward(self, input_ids, attention_mask):
"""
Performs the forward pass of the parallel model.
Args:
input_ids (torch.Tensor): Tensor of token IDs.
attention_mask (torch.Tensor): Tensor indicating attention.
Returns:
list: A list of logit tensors, one for each classification head.
"""
        # Get the last hidden states (one vector per token) from both BERT and DeBERTa;
        # the attention layers below pool each sequence into a single context vector.
bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
deberta_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
# Apply attention to get a single context vector from each model's output
context_bert = self.attn_bert(bert_output)
context_deberta = self.attn_deberta(deberta_output)
# Project the context vectors to their reduced dimensions
reduced_bert = self.proj_bert(context_bert)
reduced_deberta = self.proj_deberta(context_deberta)
# Concatenate the reduced feature vectors from both models
combined = torch.cat((reduced_bert, reduced_deberta), dim=1)
combined = self.dropout(combined) # Apply dropout to the combined features
# Pass the combined features through each classification head
return [classifier(combined) for classifier in self.classifiers]
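# Minimal smoke test: a sketch of how the model might be instantiated and run.
# It assumes config provides valid BERT_MODEL_NAME, DEBERTA_MODEL_NAME, and
# DROPOUT_RATE, and it uses BERT's tokenizer for both backbones, as noted above.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    # Two hypothetical label columns with 3 and 5 classes respectively
    model = ParallelMultiOutputModel(num_labels=[3, 5])
    model.eval()

    encoded = tokenizer(
        ["an example input sentence", "another example"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(encoded["input_ids"], encoded["attention_mask"])
    # One logits tensor per head: shapes (2, 3) and (2, 5)
    print([tuple(t.shape) for t in logits])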