secilozksen commited on
Commit
a2357b3
1 Parent(s): fc50456

Upload model

Browse files
Files changed (4) hide show
  1. config.json +13 -0
  2. configuration_dpr.py +7 -0
  3. modeling_dpr.py +66 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/secilsen/Desktop/DPR-contrastive-finetuned/DPR-model-contrastive-finetuned/checkpoint-168",
3
+ "architectures": [
4
+ "OBSSDPRModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_dpr.CustomDPRConfig",
8
+ "AutoModel": "modeling_dpr.OBSSDPRModel"
9
+ },
10
+ "model_type": "dpr",
11
+ "torch_dtype": "float32",
12
+ "transformers_version": "4.24.0"
13
+ }
configuration_dpr.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class CustomDPRConfig(PretrainedConfig):
5
+ model_type = 'dpr'
6
+ def __init__(self, **kwargs):
7
+ super().__init__(**kwargs)
modeling_dpr.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, AutoModel, AutoTokenizer
2
+ import torch
3
+ import torch.nn as nn
4
+ from .configuration_dpr import CustomDPRConfig
5
+ from typing import Union, List, Dict
6
+
7
+
8
+ class OBSSDPRModel(PreTrainedModel):
9
+ config_class = CustomDPRConfig
10
+
11
+ def __init__(self, config):
12
+ super().__init__(config)
13
+ self.config = config
14
+ self.model = DPRModel()
15
+
16
+ def forward(self, input):
17
+ return self.model(input)
18
+
19
+
20
+ class DPRModel(nn.Module):
21
+ def __init__(self,
22
+ question_model_name='facebook/contriever-msmarco',
23
+ context_model_name='facebook/contriever-msmarco'):
24
+ super(DPRModel, self).__init__()
25
+ self.question_model = AutoModel.from_pretrained(question_model_name)
26
+ self.context_model = AutoModel.from_pretrained(context_model_name)
27
+
28
+ def freeze_layers(self, freeze_params):
29
+ num_layers_context = sum(1 for _ in self.context_model.parameters())
30
+ num_layers_question = sum(1 for _ in self.question_model.parameters())
31
+
32
+ for parameters in list(self.context_model.parameters())[:int(freeze_params * num_layers_context)]:
33
+ parameters.requires_grad = False
34
+
35
+ for parameters in list(self.context_model.parameters())[int(freeze_params * num_layers_context):]:
36
+ parameters.requires_grad = True
37
+
38
+ for parameters in list(self.question_model.parameters())[:int(freeze_params * num_layers_question)]:
39
+ parameters.requires_grad = False
40
+
41
+ for parameters in list(self.question_model.parameters())[int(freeze_params * num_layers_question):]:
42
+ parameters.requires_grad = True
43
+
44
+ def batch_dot_product(self, context_output, question_output):
45
+ mat1 = torch.unsqueeze(question_output, dim=1)
46
+ mat2 = torch.unsqueeze(context_output, dim=2)
47
+ result = torch.bmm(mat1, mat2)
48
+ result = torch.squeeze(result, dim=1)
49
+ result = torch.squeeze(result, dim=1)
50
+ return result
51
+
52
+ ##FOR CONTRIEVER
53
+ def mean_pooling(self, token_embeddings, mask):
54
+ token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
55
+ sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
56
+ return sentence_embeddings
57
+
58
+ def forward(self, batch: Union[List[Dict], Dict]):
59
+ context_tensor = batch['context_tensor']
60
+ question_tensor = batch['question_tensor']
61
+ context_model_output = self.context_model(**context_tensor)
62
+ question_model_output = self.question_model(**question_tensor)
63
+ embeddings_context = self.mean_pooling(context_model_output[0], context_tensor['attention_mask'])
64
+ embeddings_question = self.mean_pooling(question_model_output[0], question_tensor['attention_mask'])
65
+ scores = self.batch_dot_product(embeddings_context, embeddings_question) # self.scale
66
+ return scores
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6cdb51315917885d11d8224ca54cc176be86cdc1a62145c2452ec4d6a0feb3e
3
+ size 876003341