import torch
from transformers import XLMRobertaModel as XLMRobertaModelBase


class XLMRobertaModel(XLMRobertaModelBase):
    """Dual-encoder XLM-R: separate projection heads for questions and answers."""

    def __init__(self, config):
        super().__init__(config)
        # Project the 768-d XLM-R base hidden states down to 512-d embeddings.
        self.question_projection = torch.nn.Linear(768, 512)
        self.answer_projection = torch.nn.Linear(768, 512)

    def _embed(self, input_ids, attention_mask, projection):
        # Run the base encoder; outputs[0] is the last hidden state,
        # shaped (batch, seq_len, hidden).
        outputs = super().__call__(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        # Mean-pool over tokens, zeroing out padding positions; the clamp
        # guards against division by zero for all-padding rows.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
        )
        embeddings = torch.sum(sequence_output * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        return torch.tanh(projection(embeddings))

    def question(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.question_projection)

    def answer(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.answer_projection)
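
# Usage sketch (an assumption, not part of the original module): encode a
# question/answer pair into the shared 512-d space and score the match.
# "xlm-roberta-base" has hidden size 768, matching the projection layers
# above; the projection heads are freshly initialized by from_pretrained,
# so in practice they would be fine-tuned before scores are meaningful.
if __name__ == "__main__":
    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    model = XLMRobertaModel.from_pretrained("xlm-roberta-base")
    model.eval()

    # Example inputs are placeholders for illustration only.
    question = tokenizer("How do I reset my password?", return_tensors="pt")
    answer = tokenizer("Click 'Forgot password' on the login page.", return_tensors="pt")

    with torch.no_grad():
        q_emb = model.question(question["input_ids"], question["attention_mask"])
        a_emb = model.answer(answer["input_ids"], answer["attention_mask"])

    # Cosine similarity between the two embeddings; higher means a closer match.
    score = torch.nn.functional.cosine_similarity(q_emb, a_emb)
    print(score.item())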