ronig committed on
Commit 007d7a5
Parent: 9cf2a8a

updating model peptriever_2023-06-23T16:07:24.508460

Files changed (2):
  1. bi_encoder.py +64 -0
  2. config.json +4 -0
bi_encoder.py ADDED
@@ -0,0 +1,64 @@
+ from transformers import PreTrainedModel
+ from transformers.models.bert.modeling_bert import BertOnlyMLMHead
+
+ from peptriever.model.bert_embedding import BertEmbeddingConfig, BertForEmbedding
+
+
+ class BiEncoderConfig(BertEmbeddingConfig):
+     max_length1: int
+     max_length2: int
+
+
+ class BiEncoder(PreTrainedModel):
+     config_class = BiEncoderConfig
+
+     def __init__(self, config: BiEncoderConfig):
+         super().__init__(config)
+         config1 = _replace_max_length(config, "max_length1")
+         self.bert1 = BertForEmbedding(config1)
+         config2 = _replace_max_length(config, "max_length2")
+         self.bert2 = BertForEmbedding(config2)
+         self.post_init()
+
+     def forward(self, x1, x2):
+         y1 = self.forward1(x1)
+         y2 = self.forward2(x2)
+         return {"y1": y1, "y2": y2}
+
+     def forward2(self, x2):
+         y2 = self.bert2(input_ids=x2["input_ids"])
+         return y2
+
+     def forward1(self, x1):
+         y1 = self.bert1(input_ids=x1["input_ids"])
+         return y1
+
+
+ class BiEncoderWithMaskedLM(PreTrainedModel):
+     config_class = BiEncoderConfig
+
+     def __init__(self, config: BiEncoderConfig):
+         super().__init__(config=config)
+         config1 = _replace_max_length(config, "max_length1")
+         self.bert1 = BertForEmbedding(config1)
+         self.lm_head1 = BertOnlyMLMHead(config=config1)
+
+         config2 = _replace_max_length(config, "max_length2")
+         self.bert2 = BertForEmbedding(config2)
+         self.lm_head2 = BertOnlyMLMHead(config=config2)
+         self.post_init()
+
+     def forward(self, x1, x2):
+         y1, state1 = self.bert1.forward_with_state(input_ids=x1["input_ids"])
+         y2, state2 = self.bert2.forward_with_state(input_ids=x2["input_ids"])
+         scores1 = self.lm_head1(state1)
+         scores2 = self.lm_head2(state2)
+         outputs = {"y1": y1, "y2": y2, "scores1": scores1, "scores2": scores2}
+         return outputs
+
+
+ def _replace_max_length(config, length_key):
+     c1 = config.__dict__.copy()
+     c1["max_position_embeddings"] = c1.pop(length_key)
+     config1 = BertEmbeddingConfig(**c1)
+     return config1
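
For orientation, a minimal usage sketch of the new BiEncoder follows. It assumes BertEmbeddingConfig behaves like a standard transformers PretrainedConfig (extra kwargs such as max_length1/max_length2 are stored as attributes, which _replace_max_length relies on when it pops them from config.__dict__); the concrete config fields and the exact output of BertForEmbedding live in peptriever.model.bert_embedding and are not part of this commit.

import torch

from bi_encoder import BiEncoder, BiEncoderConfig

# Hypothetical values; real defaults come from BertEmbeddingConfig.
config = BiEncoderConfig(
    vocab_size=30522,  # assumed BertConfig-style field
    max_length1=512,   # becomes max_position_embeddings of bert1
    max_length2=64,    # becomes max_position_embeddings of bert2
)
model = BiEncoder(config)

# Each encoder consumes its own tokenized batch; only input_ids are used.
x1 = {"input_ids": torch.randint(0, 30522, (2, 30))}
x2 = {"input_ids": torch.randint(0, 30522, (2, 20))}
out = model(x1, x2)  # {"y1": ..., "y2": ...}: one embedding per encoder

Keeping two independent encoders lets each input type have its own position-embedding budget, which is what max_length1 and max_length2 control.
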
config.json CHANGED
@@ -4,6 +4,10 @@
  "BiEncoder"
  ],
  "attention_probs_dropout_prob": 0.1,
+ "auto_map": {
+   "AutoConfig": "bi_encoder.BiEncoderConfig",
+   "AutoModel": "bi_encoder.BiEncoder"
+ },
  "classifier_dropout": null,
  "distance_func": "euclidean",
  "hidden_act": "gelu",