import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, AutoModel
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional

from .configuration import MultiHeadConfig


@dataclass
class MultiHeadOutput(ModelOutput):
    """Output type for MultiHeadModel: document-level and sentence-level logits."""

    loss: Optional[torch.FloatTensor] = None
    doc_logits: torch.FloatTensor = None
    sent_logits: torch.FloatTensor = None
    hidden_states: Optional[torch.FloatTensor] = None
    attentions: Optional[torch.FloatTensor] = None


class MultiHeadPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    config_class = MultiHeadConfig
    base_model_prefix = "multihead"
    supports_gradient_checkpointing = True


class MultiHeadModel(MultiHeadPreTrainedModel):
    """Shared encoder with attentive pooling feeding a document head and a sentence head."""

    def __init__(self, config: MultiHeadConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_pretrained(config.encoder_name)
        self.classifier_dropout = nn.Dropout(config.classifier_dropout)
        hidden_size = self.encoder.config.hidden_size
        self.doc_classifier = nn.Linear(hidden_size, config.num_labels)
        self.sent_classifier = nn.Linear(hidden_size, config.num_labels)
        # One scalar attention score per token, learned separately for each head.
        self.doc_attention = nn.Linear(hidden_size, 1)
        self.sent_attention = nn.Linear(hidden_size, 1)
        self.post_init()

    def attentive_pooling(self, hidden_states, mask, attention_layer, sentence_mode=False):
        """Pool token representations with learned attention weights.

        Document mode: `mask` is [batch, seq_len]; returns [batch, hidden].
        Sentence mode: `mask` is [batch, num_sentences, seq_len]; returns
        [batch, num_sentences, hidden].
        """
        if not sentence_mode:
            attention_scores = attention_layer(hidden_states).squeeze(-1)  # [batch, seq_len]
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=1)
            pooled_output = torch.bmm(attention_weights.unsqueeze(1), hidden_states)
            return pooled_output.squeeze(1)  # [batch, hidden]

        batch_size, num_sentences, seq_len = mask.size()
        attention_scores = attention_layer(hidden_states).squeeze(-1).unsqueeze(1)  # [batch, 1, seq_len]
        attention_scores = attention_scores.expand(batch_size, num_sentences, seq_len)
        attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
        attention_weights = torch.softmax(attention_scores, dim=2)
        # Rows for padded sentences have no valid token, so the softmax over
        # all -inf scores yields NaN; zero those rows out.
        attention_weights = torch.nan_to_num(attention_weights)
        pooled_output = torch.bmm(attention_weights, hidden_states)  # [batch, num_sentences, hidden]
        return pooled_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        document_labels=None,
        sentence_positions=None,
        sentence_labels=None,
        return_dict=True,
        **kwargs,
    ):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        # Document head: pool over all non-padding tokens.
        doc_repr = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=attention_mask.bool(),
            attention_layer=self.doc_attention,
            sentence_mode=False,
        )
        doc_repr = self.classifier_dropout(doc_repr)
        doc_logits = self.doc_classifier(doc_repr)

        # Sentence head: `sentence_positions` is [batch, max_sents] holding one
        # token index per sentence (padded with -1). Build a boolean mask of
        # shape [batch, max_sents, seq_len] that marks those tokens.
        batch_size, max_sents = sentence_positions.size()
        seq_len = attention_mask.size(1)
        valid_mask = sentence_positions != -1
        safe_positions = sentence_positions.masked_fill(~valid_mask, 0)
        sentence_tokens_mask = torch.zeros(
            batch_size, max_sents, seq_len, dtype=torch.bool, device=attention_mask.device
        )
        batch_idx = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)  # [batch, 1]
        sent_idx = torch.arange(max_sents, device=input_ids.device).unsqueeze(0)    # [1, max_sents]
        sentence_tokens_mask[batch_idx, sent_idx, safe_positions] = valid_mask

        sent_reprs = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=sentence_tokens_mask,
            attention_layer=self.sent_attention,
            sentence_mode=True,
        )
        sent_reprs = self.classifier_dropout(sent_reprs)
        sent_logits = self.sent_classifier(sent_reprs)

        loss = None
        if document_labels is not None:
            doc_loss_fct = CrossEntropyLoss()
            doc_loss = doc_loss_fct(doc_logits, document_labels)
            if sentence_labels is not None:
                # Padded sentences carry label -100 and are ignored.
                sent_loss_fct = CrossEntropyLoss(ignore_index=-100)
                sent_logits_flat = sent_logits.view(-1, sent_logits.size(-1))
                sentence_labels_flat = sentence_labels.view(-1)
                sent_loss = sent_loss_fct(sent_logits_flat, sentence_labels_flat)
                # Sentence loss is weighted 2x relative to the document loss.
                loss = doc_loss + (2 * sent_loss)
            else:
                loss = doc_loss

        if not return_dict:
            return (loss, doc_logits, sent_logits)

        return MultiHeadOutput(
            loss=loss,
            doc_logits=doc_logits,
            sent_logits=sent_logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
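# A minimal smoke test sketch, assuming `MultiHeadConfig` accepts the
# `encoder_name`, `classifier_dropout`, and `num_labels` kwargs used in
# `__init__` above. The checkpoint name, module path, and tensor shapes below
# are illustrative, not part of this repository's API. Because of the relative
# import, run it as a module, e.g. `python -m <package>.modeling`.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = MultiHeadConfig(
        encoder_name="prajjwal1/bert-tiny",  # small checkpoint chosen only for a quick check
        classifier_dropout=0.1,
        num_labels=2,
    )
    model = MultiHeadModel(config)
    tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)

    enc = tokenizer(
        "First sentence. Second sentence.",
        return_tensors="pt",
        padding="max_length",
        max_length=32,
    )
    # One token index per sentence (e.g. its first token), padded with -1.
    sentence_positions = torch.tensor([[1, 4, -1, -1]])
    out = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        token_type_ids=enc.get("token_type_ids"),
        document_labels=torch.tensor([1]),
        sentence_positions=sentence_positions,
        sentence_labels=torch.tensor([[1, 0, -100, -100]]),
    )
    print(out.loss, out.doc_logits.shape, out.sent_logits.shape)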