__author__ = "Yifan Zhang (yzhang@hbku.edu.qa)" __copyright__ = "Copyright (C) 2021, Qatar Computing Research Institute, HBKU, Doha" from dataclasses import dataclass from typing import Optional, Tuple import torch from torch import nn from torch.nn.functional import sigmoid from transformers import BertPreTrainedModel, BertModel from transformers.file_utils import ModelOutput TOKEN_TAGS = ( "", "O", "Name_Calling,Labeling", "Repetition", "Slogans", "Appeal_to_fear-prejudice", "Doubt", "Exaggeration,Minimisation", "Flag-Waving", "Loaded_Language", "Reductio_ad_hitlerum", "Bandwagon", "Causal_Oversimplification", "Obfuscation,Intentional_Vagueness,Confusion", "Appeal_to_Authority", "Black-and-White_Fallacy", "Thought-terminating_Cliches", "Red_Herring", "Straw_Men", "Whataboutism" ) SEQUENCE_TAGS = ("Non-prop", "Prop") @dataclass class TokenAndSequenceJointClassifierOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None token_logits: torch.FloatTensor = None sequence_logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None class BertForTokenAndSequenceJointClassification(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_token_labels = 20 self.num_sequence_labels = 2 self.token_tags = TOKEN_TAGS self.sequence_tags = SEQUENCE_TAGS self.alpha = 0.9 self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.ModuleList([ nn.Linear(config.hidden_size, self.num_token_labels), nn.Linear(config.hidden_size, self.num_sequence_labels), ]) self.masking_gate = nn.Linear(2, 1) self.init_weights() self.merge_classifier_1 = nn.Linear(self.num_token_labels + self.num_sequence_labels, self.num_token_labels) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=True, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) sequence_output = outputs[0] pooler_output = outputs[1] sequence_output = self.dropout(sequence_output) token_logits = self.classifier[0](sequence_output) pooler_output = self.dropout(pooler_output) sequence_logits = self.classifier[1](pooler_output) gate = torch.sigmoid(self.masking_gate(sequence_logits)) gates = gate.unsqueeze(1).repeat(1, token_logits.size()[1], token_logits.size()[2]) weighted_token_logits = torch.mul(gates, token_logits) logits = [weighted_token_logits, sequence_logits] loss = None if labels is not None: criterion = nn.CrossEntropyLoss(ignore_index=0) binary_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3932/14263]).cuda()) loss_fct = CrossEntropyLoss() weighted_token_logits = weighted_token_logits.view(-1, weighted_token_logits.shape[-1]) sequence_logits = sequence_logits.view(-1, sequence_logits.shape[-1]) token_loss = criterion(weighted_token_logits, labels) sequence_label = torch.LongTensor([1] if any([label > 0 for label in labels]) else [0]) sequence_loss = binary_criterion(sequence_logits, sequence_label) loss = self.alpha*loss[0] + (1-self.alpha)*loss[1] if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not 
None else output return TokenAndSequenceJointClassifierOutput( loss=loss, token_logits=weighted_token_logits, sequence_logits=sequence_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )
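

# Usage sketch (illustrative only, not part of the model definition): load a
# fine-tuned checkpoint and map the two heads back to TOKEN_TAGS and
# SEQUENCE_TAGS. The checkpoint path and the example sentence below are
# placeholders; substitute the directory or hub identifier that actually holds
# the trained weights.
if __name__ == "__main__":
    from transformers import BertTokenizerFast

    checkpoint = "path/to/fine-tuned-checkpoint"  # placeholder, assumed to exist
    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    model = BertForTokenAndSequenceJointClassification.from_pretrained(checkpoint)
    model.eval()

    inputs = tokenizer("Stop the lies and save our country!", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Sequence-level decision: Prop vs. Non-prop.
    sequence_pred = outputs.sequence_logits.argmax(dim=-1)
    print("sequence:", SEQUENCE_TAGS[sequence_pred.item()])

    # Token-level technique label for each word piece.
    token_preds = outputs.token_logits.argmax(dim=-1).squeeze(0)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0))
    for token, pred in zip(tokens, token_preds.tolist()):
        print(f"{token}\t{TOKEN_TAGS[pred]}")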