import torch.nn as nn
from transformers import XLMRobertaModel
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class Smish(nn.Module):
    """Smish activation: x * tanh(log(1 + sigmoid(x)))."""

    def forward(self, x):
        return x * (x.sigmoid() + 1).log().tanh()


class NoRefER(XLMRobertaPreTrainedModel):
    """Pairwise ranking model on top of XLM-RoBERTa.

    The encoder's pooled output is fed through a small scoring head that
    produces one scalar per input sequence. Training compares the scores of
    a "positive" and a "negative" input with a BCE loss on their difference.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden_size = 32  # width of the scoring head's hidden layer
        self.config = config
        self.roberta = XLMRobertaModel(config)
        self.dense = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, hidden_size, bias=False),
            nn.Dropout(config.hidden_dropout_prob),
            Smish(),
            nn.Linear(hidden_size, 1, bias=False),
        )
        self.post_init()

    def forward(
        self,
        positive_input_ids,
        positive_attention_mask,
        negative_input_ids,
        negative_attention_mask,
        labels,
        weight=None,
    ):
        # Score the positive inputs.
        positive_inputs = {
            "input_ids": positive_input_ids
            # , "attention_mask": positive_attention_mask
        }
        positive = self.dense(self.roberta(**positive_inputs).pooler_output).squeeze(-1)

        # Score the negative inputs.
        negative_inputs = {
            "input_ids": negative_input_ids
            # , "attention_mask": negative_attention_mask
        }
        negative = self.dense(self.roberta(**negative_inputs).pooler_output).squeeze(-1)

        # Pairwise BCE loss on the score difference, optionally weighted per example.
        if weight is None:
            bce = nn.BCEWithLogitsLoss()
        else:
            bs = len(positive)
            weights = (weight.float() * bs) / weight.sum()
            bce = nn.BCEWithLogitsLoss(weight=weights)
        loss = bce(positive - negative, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=positive.sigmoid() - negative.sigmoid(),
        )

    def score(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        # Single-sequence scoring: pooled encoder output -> scoring head -> sigmoid.
        h = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        ).pooler_output
        return self.dense(h).sigmoid().squeeze(-1)
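

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# It shows one way the model above could be instantiated and used to score a
# batch of candidate sentences. The checkpoint name "xlm-roberta-base" and
# the example sentences are assumptions; the scoring head is freshly
# initialized here and would need to be fine-tuned before the scores carry
# meaning.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    model = NoRefER.from_pretrained("xlm-roberta-base")  # encoder weights only; head is new
    model.eval()

    candidates = ["this is the first candidate", "this iz the ferst candidate"]
    batch = tokenizer(candidates, padding=True, return_tensors="pt")

    with torch.no_grad():
        scores = model.score(batch["input_ids"], attention_mask=batch["attention_mask"])

    # One score in (0, 1) per input sequence; higher means ranked better by the model.
    print(scores)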