import torch

from transformers import XLMRobertaModel as XLMRobertaModelBase


class XLMRobertaModel(XLMRobertaModelBase):
    """XLM-RoBERTa dual encoder with separate projection heads for questions and answers."""

    def __init__(self, config):
        super().__init__(config)
        # Project the encoder's hidden states (config.hidden_size, 768 for
        # xlm-roberta-base) into a shared 512-dimensional embedding space,
        # with one head per side of the question-answer pair.
        self.question_projection = torch.nn.Linear(config.hidden_size, 512)
        self.answer_projection = torch.nn.Linear(config.hidden_size, 512)

    def _embed(self, input_ids, attention_mask, projection):
        # Run the base XLM-RoBERTa encoder; outputs[0] holds the per-token
        # hidden states with shape (batch, seq_len, hidden_size).
        outputs = super().__call__(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]

        # Mean-pool over tokens, counting only non-padding positions; the clamp
        # guards against division by zero when a row is entirely padding.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
        embeddings = torch.sum(sequence_output * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return torch.tanh(projection(embeddings))

    def question(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.question_projection)

    def answer(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.answer_projection)
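

# Usage sketch (illustrative, not part of the original file): encode a
# question/answer pair and score it with cosine similarity. The checkpoint
# name "xlm-roberta-base" is an assumption; note that from_pretrained will
# freshly initialize the projection heads, so they need training before the
# scores are meaningful.
if __name__ == "__main__":
    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    model = XLMRobertaModel.from_pretrained("xlm-roberta-base")
    model.eval()

    q = tokenizer("How tall is Mount Everest?", return_tensors="pt")
    a = tokenizer("Mount Everest rises 8,849 metres above sea level.", return_tensors="pt")

    with torch.no_grad():
        q_emb = model.question(q["input_ids"], q["attention_mask"])  # (1, 512)
        a_emb = model.answer(a["input_ids"], a["attention_mask"])    # (1, 512)

    print(torch.nn.functional.cosine_similarity(q_emb, a_emb).item())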