updating model peptriever_2023-06-23T16:07:24.508460
Browse files- bi_encoder.py +64 -0
- config.json +4 -0
bi_encoder.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel
|
2 |
+
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
|
3 |
+
|
4 |
+
from peptriever.model.bert_embedding import BertEmbeddingConfig, BertForEmbedding
|
5 |
+
|
6 |
+
|
7 |
+
class BiEncoderConfig(BertEmbeddingConfig):
    """Configuration for the two-tower ``BiEncoder`` models below.

    Extends the project's ``BertEmbeddingConfig`` with one maximum sequence
    length per tower; ``_replace_max_length`` maps each of these onto
    ``max_position_embeddings`` when building the per-tower configs.
    """

    # NOTE(review): bare class-level annotations — presumably the
    # BertEmbeddingConfig machinery turns these into real config fields
    # (its definition is outside this file); confirm against
    # peptriever.model.bert_embedding.
    # Max sequence length for tower 1 (self.bert1).
    max_length1: int
    # Max sequence length for tower 2 (self.bert2).
    max_length2: int
|
10 |
+
|
11 |
+
|
12 |
+
class BiEncoder(PreTrainedModel):
    """Two-tower encoder: one ``BertForEmbedding`` per input sequence.

    Each tower gets its own ``max_position_embeddings`` derived from the
    shared config's ``max_length1`` / ``max_length2``.
    """

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config)
        # Build one embedding tower per input, each with its own length limit.
        self.bert1 = BertForEmbedding(_replace_max_length(config, "max_length1"))
        self.bert2 = BertForEmbedding(_replace_max_length(config, "max_length2"))
        self.post_init()

    def forward(self, x1, x2):
        """Encode both inputs and return their embeddings keyed "y1"/"y2"."""
        return {"y1": self.forward1(x1), "y2": self.forward2(x2)}

    def forward2(self, x2):
        """Embed the second input; ``x2`` is a mapping with an "input_ids" entry."""
        return self.bert2(input_ids=x2["input_ids"])

    def forward1(self, x1):
        """Embed the first input; ``x1`` is a mapping with an "input_ids" entry."""
        return self.bert1(input_ids=x1["input_ids"])
|
35 |
+
|
36 |
+
|
37 |
+
class BiEncoderWithMaskedLM(PreTrainedModel):
    """Two-tower encoder that also emits masked-LM token scores per tower.

    Adds a ``BertOnlyMLMHead`` on top of each embedding tower so the model
    can be trained with a masked-language-modeling objective alongside the
    embedding outputs.
    """

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config=config)
        # Registration order matches the plain BiEncoder plus interleaved
        # LM heads: bert1, lm_head1, bert2, lm_head2.
        for tower in (1, 2):
            tower_config = _replace_max_length(config, f"max_length{tower}")
            setattr(self, f"bert{tower}", BertForEmbedding(tower_config))
            setattr(self, f"lm_head{tower}", BertOnlyMLMHead(config=tower_config))
        self.post_init()

    def forward(self, x1, x2):
        """Run both towers; return embeddings ("y1"/"y2") and MLM logits ("scores1"/"scores2")."""
        emb1, hidden1 = self.bert1.forward_with_state(input_ids=x1["input_ids"])
        emb2, hidden2 = self.bert2.forward_with_state(input_ids=x2["input_ids"])
        return {
            "y1": emb1,
            "y2": emb2,
            "scores1": self.lm_head1(hidden1),
            "scores2": self.lm_head2(hidden2),
        }
|
58 |
+
|
59 |
+
|
60 |
+
def _replace_max_length(config, length_key):
    """Derive a per-tower ``BertEmbeddingConfig`` from *config*.

    Copies every attribute off *config*, removes *length_key*
    ("max_length1" or "max_length2"), and re-uses its value as
    ``max_position_embeddings`` in the new config. The input config is
    left untouched.
    """
    params = dict(config.__dict__)
    params["max_position_embeddings"] = params.pop(length_key)
    return BertEmbeddingConfig(**params)
|
config.json
CHANGED
@@ -4,6 +4,10 @@
|
|
4 |
"BiEncoder"
|
5 |
],
|
6 |
"attention_probs_dropout_prob": 0.1,
|
|
|
|
|
|
|
|
|
7 |
"classifier_dropout": null,
|
8 |
"distance_func": "euclidean",
|
9 |
"hidden_act": "gelu",
|
|
|
4 |
"BiEncoder"
|
5 |
],
|
6 |
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "bi_encoder.BiEncoderConfig",
|
9 |
+
"AutoModel": "bi_encoder.BiEncoder"
|
10 |
+
},
|
11 |
"classifier_dropout": null,
|
12 |
"distance_func": "euclidean",
|
13 |
"hidden_act": "gelu",
|