Add DPR Bert model
- dpr_bert/__init__.py +2 -0
- dpr_bert/config.py +29 -0
- dpr_bert/model.py +20 -0
dpr_bert/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .config import DprConfig
+from .model import DprModel
dpr_bert/config.py
ADDED
@@ -0,0 +1,29 @@
+from transformers import PretrainedConfig
+
+
+# From https://huggingface.co/klue/bert-base
+_default_config = {
+    "architectures": ["BertForMaskedLM"],
+    "attention_probs_dropout_prob": 0.1,
+    "hidden_act": "gelu",
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 768,
+    "initializer_range": 0.02,
+    "intermediate_size": 3072,
+    "layer_norm_eps": 1e-12,
+    "max_position_embeddings": 512,
+    "model_type": "bert",
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+    "pad_token_id": 0,
+    "type_vocab_size": 2,
+    "vocab_size": 32000
+}
+
+class DprConfig(PretrainedConfig):
+    model_type = "dpr"
+
+    def __init__(self, qst_config=_default_config, ctx_config=_default_config, **kwargs):
+        self.qst_config = qst_config
+        self.ctx_config = ctx_config
+        super().__init__(**kwargs)
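For reference (not part of the diff): DprConfig keeps the question and context encoder settings as plain dicts, so an encoder can be customized by copying the default dict and overriding fields. The 6-layer override and the output directory below are illustrative assumptions, not values used by the commit.

# Illustrative sketch only, not part of the commit.
from dpr_bert import DprConfig

cfg = DprConfig()                                       # both encoders use the klue/bert-base defaults
small_ctx = dict(cfg.ctx_config, num_hidden_layers=6)   # copy the dict, override one field (assumed example)
custom = DprConfig(ctx_config=small_ctx)                # shallower context encoder, default question encoder
custom.save_pretrained("dpr-custom")                    # standard PretrainedConfig serialization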
dpr_bert/model.py
ADDED
@@ -0,0 +1,20 @@
+import torch
+from transformers import PreTrainedModel, BertConfig, BertModel
+
+from .config import DprConfig
+
+
+class DprModel(PreTrainedModel):
+    config_class = DprConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        qst_config = BertConfig(**config.qst_config)
+        ctx_config = BertConfig(**config.ctx_config)
+        self.qst_encoder = BertModel(qst_config)
+        self.ctx_encoder = BertModel(ctx_config)
+
+    def forward(self, qst_input_ids, qst_attention_mask, ctx_input_ids, ctx_attention_mask):
+        qst_outputs = self.qst_encoder(input_ids=qst_input_ids, attention_mask=qst_attention_mask)
+        ctx_outputs = self.ctx_encoder(input_ids=ctx_input_ids, attention_mask=ctx_attention_mask)
+        return torch.einsum("ih,jh->ij", qst_outputs.pooler_output, ctx_outputs.pooler_output)
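As a quick sanity check (again not part of the commit): the forward pass returns a [num_questions, num_contexts] matrix of dot products between the two encoders' pooler outputs. A minimal sketch follows; the tokenizer checkpoint (klue/bert-base, matching the default config above) and the toy input strings are assumptions made for illustration.

# Illustrative usage sketch only; tokenizer choice and example texts are assumptions.
import torch
from transformers import AutoTokenizer

from dpr_bert import DprConfig, DprModel

tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")
model = DprModel(DprConfig())   # randomly initialized question and context encoders

questions = ["example question"]
contexts = ["first candidate passage", "second candidate passage"]

qst = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
ctx = tokenizer(contexts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    # scores[i, j] = dot product of question i and context j pooler outputs
    scores = model(qst["input_ids"], qst["attention_mask"],
                   ctx["input_ids"], ctx["attention_mask"])
print(scores.shape)  # torch.Size([1, 2])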