File size: 2,129 Bytes
2618b23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

from transformers import MistralModel, MistralConfig
from typing import Dict
from transformers.file_utils import ModelOutput
from typing import List, Optional, Tuple, Union
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch import nn, Tensor
from dataclasses import dataclass
from torch import nn
from typing import Dict
import torch
from transformers.file_utils import ModelOutput
import torch.nn.functional as F

def COSINE_DISTANCE(x, y, dim=1, eps=1e-8):
    """Return the cosine distance ``1 - cos_sim(x, y)`` computed along ``dim``.

    Args:
        x, y: tensors broadcastable to a common shape.
        dim: dimension along which the cosine similarity is taken (default 1,
            the same default as ``torch.nn.functional.cosine_similarity``).
        eps: small value guarding against division by zero norms (same
            default as ``torch.nn.functional.cosine_similarity``).

    Returns:
        Tensor of distances in ``[0, 2]`` with ``dim`` reduced away.
    """
    return 1 - F.cosine_similarity(x, y, dim=dim, eps=eps)

@dataclass
class EncoderOutput(ModelOutput):
    """Result container produced by ``MistralModelEmbedding.forward``.

    Attributes:
        loss: the mean training loss, or ``None`` when the forward pass was
            run without ``labels``.
    """

    # Only field carried by this output; stays None unless a loss was computed.
    loss: Optional[torch.Tensor] = None

class MistralModelEmbedding(MistralModel):
    """Mistral backbone that turns tokenised text into one embedding per sequence.

    The backbone's last hidden state is projected by ``dense_layer`` and then
    pooled into a single vector per sequence. Two attributes, expected to be
    set on ``config`` by the caller, control that step:

    * ``config.sentence_pooling_method`` -- ``'mean'``, ``'cls'`` or ``'last'``
      (see :meth:`sentence_embedding`).
    * ``config.normalized`` -- when truthy, embeddings are L2-normalised.

    An optional ``config.embedding_dim`` sets the projection width (768 when
    absent). :meth:`forward` computes a margin-based contrastive loss on the
    cosine distance between query and passage embeddings.
    """

    def __init__(self, config: MistralConfig, **kwargs):
        super().__init__(config, **kwargs)

        # Projection width; 768 preserves the previously hard-coded size.
        embedding_dim = getattr(self.config, "embedding_dim", 768)
        self.dense_layer = nn.Linear(self.config.hidden_size, embedding_dim)

    def sentence_embedding(self, hidden_state, mask):
        """Pool per-token states into one vector per sequence.

        Args:
            hidden_state: token states indexed as ``[batch, seq, dim]``.
            mask: attention mask indexed as ``[batch, seq]``; non-zero entries
                mark real (non-padding) tokens.

        Returns:
            Tensor indexed as ``[batch, dim]``.

        Raises:
            ValueError: if ``config.sentence_pooling_method`` is not one of
                ``'mean'``, ``'cls'`` or ``'last'``.
        """
        method = self.config.sentence_pooling_method
        if method == 'mean':
            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            # Clamp so an all-padding row yields 0 instead of NaN (0 / 0);
            # for any row with at least one real token this is a no-op.
            d = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
            return s / d
        if method == 'cls':
            # NOTE(review): with left padding position 0 may be a pad token.
            return hidden_state[:, 0]
        if method == 'last':
            # If every row's final column is a real token (left padding or no
            # padding), the last position is each sequence's final token.
            if bool(mask[:, -1].all()):
                return hidden_state[:, -1]
            # Otherwise assume right padding: last real token is at length - 1.
            last_idx = (mask.sum(dim=1) - 1).long()
            rows = torch.arange(hidden_state.size(0), device=hidden_state.device)
            return hidden_state[rows, last_idx]
        raise ValueError(
            f"Unsupported sentence_pooling_method {method!r}; "
            "expected 'mean', 'cls' or 'last'."
        )

    def encode(self, features):
        """Embed a tokenised batch.

        Args:
            features: mapping of keyword arguments for the backbone's
                ``forward`` (e.g. ``input_ids``, ``attention_mask``), or None.

        Returns:
            Contiguous tensor of pooled (optionally normalised) embeddings,
            or ``None`` when ``features`` is ``None``.
        """
        if features is None:
            return None
        psg_out = super().forward(**features, return_dict=True)
        # Project every token to the embedding width, then pool over tokens.
        output = self.dense_layer(psg_out.last_hidden_state)
        mask = features.get('attention_mask')
        if mask is None:
            # No mask supplied: treat every position as a real token.
            mask = torch.ones(output.shape[:2], dtype=torch.long,
                              device=output.device)
        p_reps = self.sentence_embedding(output, mask)
        if self.config.normalized:
            p_reps = F.normalize(p_reps, dim=-1)
        return p_reps.contiguous()

    def forward(self, query: Dict[str, Tensor] = None,
                passage: Dict[str, Tensor] = None, labels=None, margin=0.5):
        """Encode queries/passages and, given ``labels``, compute the loss.

        The per-pair loss is
        ``0.5 * (y * d**2 + (1 - y) * relu(margin - d)**2)`` where ``d`` is the
        cosine distance between the query and passage embeddings. Pairs with
        ``y = 1`` are pulled together; pairs with ``y = 0`` are pushed until
        their distance exceeds ``margin``.

        Args:
            query: tokenised queries, passed to :meth:`encode`.
            passage: tokenised passages, passed to :meth:`encode`.
            labels: per-pair targets (bool, int or float); optional.
            margin: distance beyond which negative pairs incur no loss.

        Returns:
            ``EncoderOutput`` whose ``loss`` is the mean pair loss, or ``None``
            when ``labels`` is ``None``.

        Raises:
            ValueError: if ``labels`` is given but ``query`` or ``passage`` is
                ``None``.
        """
        q_reps = self.encode(query)
        p_reps = self.encode(passage)

        loss = None
        if labels is not None:
            if q_reps is None or p_reps is None:
                raise ValueError(
                    "Both query and passage are required to compute the loss."
                )
            distances = COSINE_DISTANCE(q_reps, p_reps)
            # Cast once up front: `1 - bool_tensor` raises in PyTorch.
            labels = labels.float()
            # One label per pair: align e.g. (B, 1) labels with (B,) distances
            # so they do not silently broadcast to a (B, B) loss matrix.
            if labels.numel() == distances.numel():
                labels = labels.reshape(distances.shape)
            losses = 0.5 * (labels * distances.pow(2)
                            + (1 - labels) * F.relu(margin - distances).pow(2))
            loss = losses.mean()

        return EncoderOutput(
            loss=loss,
        )