from typing import Optional, Tuple

import torch
from transformers import BertConfig, BertModel, BertPreTrainedModel, PreTrainedModel
from transformers.models.bert.modeling_bert import BertOnlyMLMHead


class BertEmbeddingConfig(BertConfig):
    """BertConfig extended with the fields used by BertForEmbedding: the output
    size of the embedding head and the distance function ("euclidean" or
    "angular") that selects its activation."""

    n_output_dims: int
    distance_func: str = "euclidean"


class BiEncoderConfig(BertEmbeddingConfig):
    """Adds a separate maximum sequence length for each encoder; these become
    the max_position_embeddings of the two towers (see _replace_max_length)."""

    max_length1: int
    max_length2: int


class BiEncoder(PreTrainedModel):
    """Bi-encoder with two independent BERT embedding towers, one per input side."""

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config)
        # Each tower gets its own config copy whose max_position_embeddings
        # is taken from max_length1 / max_length2 respectively.
        config1 = _replace_max_length(config, "max_length1")
        self.bert1 = BertForEmbedding(config1)
        config2 = _replace_max_length(config, "max_length2")
        self.bert2 = BertForEmbedding(config2)
        self.post_init()

    def forward(self, x1, x2):
        """Embed both sides; each input is a dict with at least "input_ids"."""
        y1 = self.forward1(x1)
        y2 = self.forward2(x2)
        return {"y1": y1, "y2": y2}

    def forward1(self, x1):
        y1 = self.bert1(input_ids=x1["input_ids"])
        return y1

    def forward2(self, x2):
        y2 = self.bert2(input_ids=x2["input_ids"])
        return y2


class BiEncoderWithMaskedLM(PreTrainedModel):
    """BiEncoder variant whose forward pass also returns masked-LM token scores
    for each tower, computed from the per-token hidden states."""

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config=config)
        config1 = _replace_max_length(config, "max_length1")
        self.bert1 = BertForEmbedding(config1)
        self.lm_head1 = BertOnlyMLMHead(config=config1)

        config2 = _replace_max_length(config, "max_length2")
        self.bert2 = BertForEmbedding(config2)
        self.lm_head2 = BertOnlyMLMHead(config=config2)
        self.post_init()

    def forward(self, x1, x2):
        y1, state1 = self.bert1.forward_with_state(input_ids=x1["input_ids"])
        y2, state2 = self.bert2.forward_with_state(input_ids=x2["input_ids"])
        # The MLM heads score every token position from the last hidden states.
        scores1 = self.lm_head1(state1)
        scores2 = self.lm_head2(state2)
        return {"y1": y1, "y2": y2, "scores1": scores1, "scores2": scores2}


def _replace_max_length(config, length_key):
    """Copy the config, using the value stored under `length_key` (e.g.
    "max_length1") as the copy's max_position_embeddings."""
    c1 = config.__dict__.copy()
    c1["max_position_embeddings"] = c1.pop(length_key)
    return BertEmbeddingConfig(**c1)


class L2Norm:
    """Normalizes vectors to unit L2 norm (used for the "angular" distance)."""

    def __call__(self, x):
        return x / torch.norm(x, p=2, dim=-1, keepdim=True)


class BertForEmbedding(BertPreTrainedModel):
    """BERT encoder followed by a linear projection to n_output_dims and an
    activation chosen to match the configured distance function."""

    config_class = BertEmbeddingConfig

    def __init__(self, config: BertEmbeddingConfig):
        super().__init__(config)
        self.bert = BertModel(config)
        self.fc = torch.nn.Linear(config.hidden_size, config.n_output_dims)
        self.activation = _get_activation(config.distance_func)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Return only the fixed-size embedding; see forward_with_state."""
        embedding, _ = self.forward_with_state(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return embedding

    def forward_with_state(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (embedding, last_hidden_state): the pooled-and-projected
        embedding plus the per-token states consumed by the MLM heads."""
        encoded = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Project the [CLS] pooler output and apply the distance-specific activation.
        pooler_output = encoded.pooler_output
        logits = self.fc(pooler_output)
        embedding = self.activation(logits)
        return embedding, encoded.last_hidden_state


def _get_activation(distance_func: str):
    # Tanh keeps euclidean embeddings bounded; L2 normalization makes the
    # angular (cosine) distance depend only on direction.
    if distance_func == "euclidean":
        activation = torch.nn.Tanh()
    elif distance_func == "angular":
        activation = L2Norm()  # type: ignore
    else:
        raise NotImplementedError(f"Unsupported distance_func: {distance_func!r}")
    return activation
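

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): builds a tiny
    # BiEncoder and embeds a random batch on each side. All hyper-parameter
    # values below are illustrative assumptions; only n_output_dims,
    # distance_func, max_length1 and max_length2 are required by the classes
    # above, the rest are ordinary BertConfig fields.
    config = BiEncoderConfig(
        vocab_size=30522,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
        n_output_dims=64,
        distance_func="angular",
        max_length1=32,
        max_length2=64,
    )
    model = BiEncoder(config)
    x1 = {"input_ids": torch.randint(0, config.vocab_size, (2, 16))}
    x2 = {"input_ids": torch.randint(0, config.vocab_size, (2, 16))}
    outputs = model(x1, x2)
    # With distance_func="angular" both embeddings are unit-normalized.
    print(outputs["y1"].shape, outputs["y2"].shape)  # torch.Size([2, 64]) twice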