import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, num_token_labels_dict):
        super().__init__(config)
        self.num_token_labels_dict = num_token_labels_dict
        self.config = config

        # Load the backbone with its pretrained weights; config.pretrained_config
        # holds the backbone's architecture config.
        # self.bert = AutoModel.from_config(config)
        self.bert = AutoModel.from_pretrained(
            config.name_or_path, config=config.pretrained_config
        )

        # Fall back to a default dropout of 0.1 if the config does not define one.
        if "classifier_dropout" not in config.__dict__:
            classifier_dropout = 0.1
        else:
            classifier_dropout = (
                config.classifier_dropout
                if config.classifier_dropout is not None
                else config.hidden_dropout_prob
            )
        self.dropout = nn.Dropout(classifier_dropout)

        # Additional transformer layers stacked on top of the backbone
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # For token classification, create a classifier head for each task
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in num_token_labels_dict.items()
            }
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels for computing the token classification loss. Keys should match the tasks.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        bert_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        # Some backbones (e.g. Llama, DeBERTa) do not accept these arguments.
        if any(
            keyword in self.config.name_or_path.lower()
            for keyword in ["llama", "deberta"]
        ):
            bert_kwargs.pop("token_type_ids")
            bert_kwargs.pop("head_mask")

        outputs = self.bert(**bert_kwargs)

        # For token classification, use the per-token hidden states
        token_output = outputs[0]
        token_output = self.dropout(token_output)

        # Pass through the additional transformer layers
        # (nn.TransformerEncoder expects (seq_len, batch, hidden) by default)
        token_output = self.transformer_encoder(
            token_output.transpose(0, 1)
        ).transpose(0, 1)

        # Collect the logits and compute the loss for each task
        task_logits = {}
        total_loss = 0
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss += loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            return ((total_loss,) + output) if total_loss != 0 else output

        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )