import torch.nn as nn

from transformers import XLMRobertaModel
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class Smish(nn.Module):
    """Smish activation: x * tanh(ln(1 + sigmoid(x)))."""

    def forward(self, x):
        return x * (x.sigmoid() + 1).log().tanh()


class NoRefER(XLMRobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        hidden_size = 32  # width of the intermediate scoring layer

        self.config = config
        self.roberta = XLMRobertaModel(config)
        # Scoring head: pooled encoder output -> scalar quality logit.
        self.dense = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, hidden_size, bias=False),
            nn.Dropout(config.hidden_dropout_prob),
            Smish(),
            nn.Linear(hidden_size, 1, bias=False),
        )

        self.post_init()

    def forward(self, positive_input_ids, positive_attention_mask,
                negative_input_ids, negative_attention_mask, labels, weight=None):
        # Score the positive (preferred) hypothesis.
        positive_inputs = {
            "input_ids": positive_input_ids,
            "attention_mask": positive_attention_mask,
        }
        positive = self.dense(self.roberta(**positive_inputs).pooler_output).squeeze(-1)

        # Score the negative hypothesis.
        negative_inputs = {
            "input_ids": negative_input_ids,
            "attention_mask": negative_attention_mask,
        }
        negative = self.dense(self.roberta(**negative_inputs).pooler_output).squeeze(-1)

        # Pairwise objective: the logit margin (positive - negative) is pushed
        # toward the pair labels with (optionally weighted) BCE.
        if weight is None:
            bce = nn.BCEWithLogitsLoss()
        else:
            # Rescale per-sample weights so they sum to the batch size.
            bs = len(positive)
            weights = (weight.float() * bs) / weight.sum()
            bce = nn.BCEWithLogitsLoss(weight=weights)
        loss = bce(positive - negative, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=positive.sigmoid() - negative.sigmoid(),
        )

    def score(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        # Inference helper: map a single hypothesis to a quality score in (0, 1).
        h = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        ).pooler_output

        return self.dense(h).sigmoid().squeeze(-1)
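

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions:
    # "xlm-roberta-base" is a placeholder checkpoint; a fine-tuned NoRefER
    # checkpoint would be needed for meaningful scores, since loading a plain
    # encoder leaves the scoring head randomly initialized.
    import torch
    from transformers import AutoTokenizer

    checkpoint = "xlm-roberta-base"  # placeholder, not trained NoRefER weights
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = NoRefER.from_pretrained(checkpoint)
    model.eval()

    hypothesis = "this is a sample asr hypothesis"
    inputs = tokenizer(hypothesis, return_tensors="pt")
    with torch.no_grad():
        quality = model.score(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    print(quality)  # tensor of shape (1,), values in (0, 1); higher = better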