# models/parallel_bert_deberta.py
import torch
import torch.nn as nn
from transformers import BertModel, DebertaModel

from config import DROPOUT_RATE, BERT_MODEL_NAME, DEBERTA_MODEL_NAME  # Import model names
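
# NOTE: config.py is not shown in this file. A minimal sketch of what it might
# contain (the exact model names and dropout rate below are assumptions, not
# confirmed by this repository):
#
#     BERT_MODEL_NAME = "bert-base-uncased"
#     DEBERTA_MODEL_NAME = "microsoft/deberta-base"
#     DROPOUT_RATE = 0.3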


class Attention(nn.Module):
    """
    Simple attention layer that computes a context vector from a sequence of hidden states.
    It learns a single score for each hidden state in the sequence, normalizes the scores
    with softmax, and returns the weighted sum of the hidden states.
    """

    def __init__(self, hidden_size):
        """
        Initializes the Attention layer.

        Args:
            hidden_size (int): The dimensionality of the input hidden states.
        """
        super(Attention, self).__init__()
        # A linear layer that projects each hidden state to a single scalar (attention score)
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, encoder_output):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            encoder_output (torch.Tensor): Tensor of hidden states from an encoder.
                Shape: (batch_size, sequence_length, hidden_size)

        Returns:
            torch.Tensor: The context vector, a weighted sum of the hidden states.
                Shape: (batch_size, hidden_size)
        """
        # Compute attention scores and normalize them with softmax along the sequence dimension.
        # self.attn(encoder_output) -> (batch_size, sequence_length, 1)
        # .squeeze(-1) removes the last dimension, giving (batch_size, sequence_length)
        attn_weights = torch.softmax(self.attn(encoder_output).squeeze(-1), dim=1)
        # Compute the context vector as a weighted sum of encoder_output.
        # attn_weights.unsqueeze(-1) adds a dimension for broadcasting: (batch_size, sequence_length, 1),
        # which allows element-wise multiplication with encoder_output.
        # torch.sum(..., dim=1) sums along the sequence_length dimension.
        context_vector = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)
        return context_vector
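
# Shape walkthrough (illustrative sketch, not part of the original file): with a
# hidden size of 768 and a batch of 2 sequences of length 10, data flows through
# Attention as
#
#     attn = Attention(hidden_size=768)
#     x = torch.randn(2, 10, 768)   # (batch_size, sequence_length, hidden_size)
#     ctx = attn(x)                 # (batch_size, hidden_size) == (2, 768)
#
# Note that padding positions are not masked out here: the softmax runs over the
# full sequence length, so padded tokens also receive nonzero attention weight.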


class ParallelMultiOutputModel(nn.Module):
    """
    Hybrid model that runs BERT and DeBERTa in parallel.
    It extracts token-level features from both encoders, applies an attention mechanism
    to each encoder's output, projects the attended features to a common dimension,
    concatenates them, and uses the combined representation for multi-output classification.
    """

    # Statically set the tokenizer name to BERT's for this combined model
    # (assuming BERT's tokenizer is compatible or primary for the shared input).
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels):
        """
        Initializes the ParallelMultiOutputModel.

        Args:
            num_labels (list): A list where each element is the number of classes
                for a corresponding label column.
        """
        super(ParallelMultiOutputModel, self).__init__()
        # Load pre-trained BERT and DeBERTa backbones
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.deberta = DebertaModel.from_pretrained(DEBERTA_MODEL_NAME)
        # One attention layer per backbone
        self.attn_bert = Attention(self.bert.config.hidden_size)
        self.attn_deberta = Attention(self.deberta.config.hidden_size)
        # Projection layers reduce the dimensionality of the context vectors
        # before concatenation, keeping the combined feature size manageable.
        self.proj_bert = nn.Linear(self.bert.config.hidden_size, 256)
        self.proj_deberta = nn.Linear(self.deberta.config.hidden_size, 256)
        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization
        # Classification heads. The input feature size is the sum of the
        # projected sizes from BERT and DeBERTa (256 + 256 = 512).
        self.classifiers = nn.ModuleList([
            nn.Linear(512, n_classes) for n_classes in num_labels
        ])

    def forward(self, input_ids, attention_mask):
        """
        Performs the forward pass of the parallel model.

        Args:
            input_ids (torch.Tensor): Tensor of token IDs.
            attention_mask (torch.Tensor): Mask marking real tokens (1) vs. padding (0).

        Returns:
            list: A list of logit tensors, one per classification head.
        """
        # Get the last hidden states (one vector per token) from both BERT and DeBERTa.
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        deberta_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Apply attention to collapse each token sequence into a single context vector
        context_bert = self.attn_bert(bert_output)
        context_deberta = self.attn_deberta(deberta_output)
        # Project the context vectors to their reduced dimensions
        reduced_bert = self.proj_bert(context_bert)
        reduced_deberta = self.proj_deberta(context_deberta)
        # Concatenate the reduced feature vectors from both backbones
        combined = torch.cat((reduced_bert, reduced_deberta), dim=1)
        combined = self.dropout(combined)  # Apply dropout to the combined features
        # Pass the combined features through each classification head
        return [classifier(combined) for classifier in self.classifiers]
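
# Illustrative usage sketch (not part of the original file; the tokenizer choice
# and label counts below are assumptions for demonstration only):
#
#     from transformers import BertTokenizer
#
#     tokenizer = BertTokenizer.from_pretrained(ParallelMultiOutputModel.tokenizer_name)
#     model = ParallelMultiOutputModel(num_labels=[3, 5])   # two label columns
#     enc = tokenizer(["an example sentence"], padding=True, truncation=True,
#                     return_tensors="pt")
#     logits = model(enc["input_ids"], enc["attention_mask"])
#     # logits[0]: shape (1, 3) for the first head; logits[1]: shape (1, 5) for the second
#
# Note that both backbones receive the same input_ids, so this design assumes the
# shared tokenizer's IDs are acceptable input for DeBERTa as well.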