riccorl commited on
Commit
91e262c
1 Parent(s): 28d9162

Upload models

Browse files
Files changed (24) hide show
  1. .gitattributes +1 -0
  2. app.py +3 -3
  3. models/relik-reader-aida-deberta-small/.gitattributes +35 -0
  4. models/relik-reader-aida-deberta-small/added_tokens.json +108 -0
  5. models/relik-reader-aida-deberta-small/config.json +18 -0
  6. models/relik-reader-aida-deberta-small/configuration_relik.py +33 -0
  7. models/relik-reader-aida-deberta-small/modeling_relik.py +983 -0
  8. models/relik-reader-aida-deberta-small/pytorch_model.bin +3 -0
  9. models/relik-reader-aida-deberta-small/special_tokens_map.json +112 -0
  10. models/relik-reader-aida-deberta-small/spm.model +3 -0
  11. models/relik-reader-aida-deberta-small/tokenizer.json +0 -0
  12. models/relik-reader-aida-deberta-small/tokenizer_config.json +970 -0
  13. models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/config.yaml +8 -0
  14. models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/documents.json +3 -0
  15. models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/embeddings.pt +3 -0
  16. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/added_tokens.json +7 -0
  17. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/config.json +28 -0
  18. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/hf.py +88 -0
  19. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/pytorch_model.bin +3 -0
  20. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/special_tokens_map.json +7 -0
  21. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/tokenizer.json +0 -0
  22. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/tokenizer_config.json +56 -0
  23. models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/vocab.txt +0 -0
  24. scripts/setup.sh +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/documents.json filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -180,9 +180,9 @@ def run_client():
180
  # submit = st.button("Run")
181
 
182
  relik = Relik(
183
- question_encoder="riccorl/relik-retriever-small-aida-blink-pretrain-omniencoder",
184
- document_index="riccorl/index-relik-retriever-small-aida-blink-pretrain-omniencoder",
185
- reader="riccorl/relik-reader-aida-deberta-small",
186
  top_k=100,
187
  window_size=32,
188
  window_stride=16,
 
180
  # submit = st.button("Run")
181
 
182
  relik = Relik(
183
+ question_encoder=Path(__file__).parent / "models" / "relik-retriever-small-aida-blink-pretrain-omniencoder" / "question_encoder",
184
+ document_index=Path(__file__).parent / "models" / "relik-retriever-small-aida-blink-pretrain-omniencoder" / "document_index",
185
+ reader=Path(__file__).parent / "models" /"relik-reader-aida-deberta-small",
186
  top_k=100,
187
  window_size=32,
188
  window_stride=16,
models/relik-reader-aida-deberta-small/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
models/relik-reader-aida-deberta-small/added_tokens.json ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "--NME--": 128001,
3
+ "[CLS]": 1,
4
+ "[E-0]": 128002,
5
+ "[E-10]": 128012,
6
+ "[E-11]": 128013,
7
+ "[E-12]": 128014,
8
+ "[E-13]": 128015,
9
+ "[E-14]": 128016,
10
+ "[E-15]": 128017,
11
+ "[E-16]": 128018,
12
+ "[E-17]": 128019,
13
+ "[E-18]": 128020,
14
+ "[E-19]": 128021,
15
+ "[E-1]": 128003,
16
+ "[E-20]": 128022,
17
+ "[E-21]": 128023,
18
+ "[E-22]": 128024,
19
+ "[E-23]": 128025,
20
+ "[E-24]": 128026,
21
+ "[E-25]": 128027,
22
+ "[E-26]": 128028,
23
+ "[E-27]": 128029,
24
+ "[E-28]": 128030,
25
+ "[E-29]": 128031,
26
+ "[E-2]": 128004,
27
+ "[E-30]": 128032,
28
+ "[E-31]": 128033,
29
+ "[E-32]": 128034,
30
+ "[E-33]": 128035,
31
+ "[E-34]": 128036,
32
+ "[E-35]": 128037,
33
+ "[E-36]": 128038,
34
+ "[E-37]": 128039,
35
+ "[E-38]": 128040,
36
+ "[E-39]": 128041,
37
+ "[E-3]": 128005,
38
+ "[E-40]": 128042,
39
+ "[E-41]": 128043,
40
+ "[E-42]": 128044,
41
+ "[E-43]": 128045,
42
+ "[E-44]": 128046,
43
+ "[E-45]": 128047,
44
+ "[E-46]": 128048,
45
+ "[E-47]": 128049,
46
+ "[E-48]": 128050,
47
+ "[E-49]": 128051,
48
+ "[E-4]": 128006,
49
+ "[E-50]": 128052,
50
+ "[E-51]": 128053,
51
+ "[E-52]": 128054,
52
+ "[E-53]": 128055,
53
+ "[E-54]": 128056,
54
+ "[E-55]": 128057,
55
+ "[E-56]": 128058,
56
+ "[E-57]": 128059,
57
+ "[E-58]": 128060,
58
+ "[E-59]": 128061,
59
+ "[E-5]": 128007,
60
+ "[E-60]": 128062,
61
+ "[E-61]": 128063,
62
+ "[E-62]": 128064,
63
+ "[E-63]": 128065,
64
+ "[E-64]": 128066,
65
+ "[E-65]": 128067,
66
+ "[E-66]": 128068,
67
+ "[E-67]": 128069,
68
+ "[E-68]": 128070,
69
+ "[E-69]": 128071,
70
+ "[E-6]": 128008,
71
+ "[E-70]": 128072,
72
+ "[E-71]": 128073,
73
+ "[E-72]": 128074,
74
+ "[E-73]": 128075,
75
+ "[E-74]": 128076,
76
+ "[E-75]": 128077,
77
+ "[E-76]": 128078,
78
+ "[E-77]": 128079,
79
+ "[E-78]": 128080,
80
+ "[E-79]": 128081,
81
+ "[E-7]": 128009,
82
+ "[E-80]": 128082,
83
+ "[E-81]": 128083,
84
+ "[E-82]": 128084,
85
+ "[E-83]": 128085,
86
+ "[E-84]": 128086,
87
+ "[E-85]": 128087,
88
+ "[E-86]": 128088,
89
+ "[E-87]": 128089,
90
+ "[E-88]": 128090,
91
+ "[E-89]": 128091,
92
+ "[E-8]": 128010,
93
+ "[E-90]": 128092,
94
+ "[E-91]": 128093,
95
+ "[E-92]": 128094,
96
+ "[E-93]": 128095,
97
+ "[E-94]": 128096,
98
+ "[E-95]": 128097,
99
+ "[E-96]": 128098,
100
+ "[E-97]": 128099,
101
+ "[E-98]": 128100,
102
+ "[E-99]": 128101,
103
+ "[E-9]": 128011,
104
+ "[MASK]": 128000,
105
+ "[PAD]": 0,
106
+ "[SEP]": 2,
107
+ "[UNK]": 3
108
+ }
models/relik-reader-aida-deberta-small/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "additional_special_symbols": 101,
4
+ "architectures": [
5
+ "RelikReaderELModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoModel": "modeling_relik.RelikReaderELModel"
9
+ },
10
+ "linears_hidden_size": 512,
11
+ "model_type": "relik-reader",
12
+ "num_layers": null,
13
+ "torch_dtype": "float32",
14
+ "training": false,
15
+ "transformer_model": "microsoft/deberta-v3-small",
16
+ "transformers_version": "4.34.0",
17
+ "use_last_k_layers": 1
18
+ }
models/relik-reader-aida-deberta-small/configuration_relik.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from transformers import AutoConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class RelikReaderConfig(PretrainedConfig):
8
+ model_type = "relik-reader"
9
+
10
+ def __init__(
11
+ self,
12
+ transformer_model: str = "microsoft/deberta-v3-base",
13
+ additional_special_symbols: int = 101,
14
+ num_layers: Optional[int] = None,
15
+ activation: str = "gelu",
16
+ linears_hidden_size: Optional[int] = 512,
17
+ use_last_k_layers: int = 1,
18
+ training: bool = False,
19
+ default_reader_class: Optional[str] = None,
20
+ **kwargs
21
+ ) -> None:
22
+ self.transformer_model = transformer_model
23
+ self.additional_special_symbols = additional_special_symbols
24
+ self.num_layers = num_layers
25
+ self.activation = activation
26
+ self.linears_hidden_size = linears_hidden_size
27
+ self.use_last_k_layers = use_last_k_layers
28
+ self.training = training
29
+ self.default_reader_class = default_reader_class
30
+ super().__init__(**kwargs)
31
+
32
+
33
+ AutoConfig.register("relik-reader", RelikReaderConfig)
models/relik-reader-aida-deberta-small/modeling_relik.py ADDED
@@ -0,0 +1,983 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any
2
+
3
+ import torch
4
+ from transformers import AutoModel, PreTrainedModel
5
+ from transformers.activations import GELUActivation, ClippedGELUActivation
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_utils import PoolerEndLogits
8
+
9
+ from .configuration_relik import RelikReaderConfig
10
+
11
+
12
+ class RelikReaderSample:
13
+ def __init__(self, **kwargs):
14
+ super().__setattr__("_d", {})
15
+ self._d = kwargs
16
+
17
+ def __getattribute__(self, item):
18
+ return super(RelikReaderSample, self).__getattribute__(item)
19
+
20
+ def __getattr__(self, item):
21
+ if item.startswith("__") and item.endswith("__"):
22
+ # this is likely some python library-specific variable (such as __deepcopy__ for copy)
23
+ # better follow standard behavior here
24
+ raise AttributeError(item)
25
+ elif item in self._d:
26
+ return self._d[item]
27
+ else:
28
+ return None
29
+
30
+ def __setattr__(self, key, value):
31
+ if key in self._d:
32
+ self._d[key] = value
33
+ else:
34
+ super().__setattr__(key, value)
35
+
36
+
37
+ activation2functions = {
38
+ "relu": torch.nn.ReLU(),
39
+ "gelu": GELUActivation(),
40
+ "gelu_10": ClippedGELUActivation(-10, 10),
41
+ }
42
+
43
+
44
+ class PoolerEndLogitsBi(PoolerEndLogits):
45
+ def __init__(self, config: PretrainedConfig):
46
+ super().__init__(config)
47
+ self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
48
+
49
+ def forward(
50
+ self,
51
+ hidden_states: torch.FloatTensor,
52
+ start_states: Optional[torch.FloatTensor] = None,
53
+ start_positions: Optional[torch.LongTensor] = None,
54
+ p_mask: Optional[torch.FloatTensor] = None,
55
+ ) -> torch.FloatTensor:
56
+ if p_mask is not None:
57
+ p_mask = p_mask.unsqueeze(-1)
58
+ logits = super().forward(
59
+ hidden_states,
60
+ start_states,
61
+ start_positions,
62
+ p_mask,
63
+ )
64
+ return logits
65
+
66
+
67
+ class RelikReaderSpanModel(PreTrainedModel):
68
+ config_class = RelikReaderConfig
69
+
70
+ def __init__(self, config: RelikReaderConfig, *args, **kwargs):
71
+ super().__init__(config)
72
+ # Transformer model declaration
73
+ self.config = config
74
+ self.transformer_model = (
75
+ AutoModel.from_pretrained(self.config.transformer_model)
76
+ if self.config.num_layers is None
77
+ else AutoModel.from_pretrained(
78
+ self.config.transformer_model, num_hidden_layers=self.config.num_layers
79
+ )
80
+ )
81
+ self.transformer_model.resize_token_embeddings(
82
+ self.transformer_model.config.vocab_size
83
+ + self.config.additional_special_symbols
84
+ )
85
+
86
+ self.activation = self.config.activation
87
+ self.linears_hidden_size = self.config.linears_hidden_size
88
+ self.use_last_k_layers = self.config.use_last_k_layers
89
+
90
+ # named entity detection layers
91
+ self.ned_start_classifier = self._get_projection_layer(
92
+ self.activation, last_hidden=2, layer_norm=False
93
+ )
94
+ self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
95
+
96
+ # END entity disambiguation layer
97
+ self.ed_start_projector = self._get_projection_layer(self.activation)
98
+ self.ed_end_projector = self._get_projection_layer(self.activation)
99
+
100
+ self.training = self.config.training
101
+
102
+ # criterion
103
+ self.criterion = torch.nn.CrossEntropyLoss()
104
+
105
+ def _get_projection_layer(
106
+ self,
107
+ activation: str,
108
+ last_hidden: Optional[int] = None,
109
+ input_hidden=None,
110
+ layer_norm: bool = True,
111
+ ) -> torch.nn.Sequential:
112
+ head_components = [
113
+ torch.nn.Dropout(0.1),
114
+ torch.nn.Linear(
115
+ self.transformer_model.config.hidden_size * self.use_last_k_layers
116
+ if input_hidden is None
117
+ else input_hidden,
118
+ self.linears_hidden_size,
119
+ ),
120
+ activation2functions[activation],
121
+ torch.nn.Dropout(0.1),
122
+ torch.nn.Linear(
123
+ self.linears_hidden_size,
124
+ self.linears_hidden_size if last_hidden is None else last_hidden,
125
+ ),
126
+ ]
127
+
128
+ if layer_norm:
129
+ head_components.append(
130
+ torch.nn.LayerNorm(
131
+ self.linears_hidden_size if last_hidden is None else last_hidden,
132
+ self.transformer_model.config.layer_norm_eps,
133
+ )
134
+ )
135
+
136
+ return torch.nn.Sequential(*head_components)
137
+
138
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
139
+ mask = mask.unsqueeze(-1)
140
+ if next(self.parameters()).dtype == torch.float16:
141
+ logits = logits * (1 - mask) - 65500 * mask
142
+ else:
143
+ logits = logits * (1 - mask) - 1e30 * mask
144
+ return logits
145
+
146
+ def _get_model_features(
147
+ self,
148
+ input_ids: torch.Tensor,
149
+ attention_mask: torch.Tensor,
150
+ token_type_ids: Optional[torch.Tensor],
151
+ ):
152
+ model_input = {
153
+ "input_ids": input_ids,
154
+ "attention_mask": attention_mask,
155
+ "output_hidden_states": self.use_last_k_layers > 1,
156
+ }
157
+
158
+ if token_type_ids is not None:
159
+ model_input["token_type_ids"] = token_type_ids
160
+
161
+ model_output = self.transformer_model(**model_input)
162
+
163
+ if self.use_last_k_layers > 1:
164
+ model_features = torch.cat(
165
+ model_output[1][-self.use_last_k_layers :], dim=-1
166
+ )
167
+ else:
168
+ model_features = model_output[0]
169
+
170
+ return model_features
171
+
172
+ def compute_ned_end_logits(
173
+ self,
174
+ start_predictions,
175
+ start_labels,
176
+ model_features,
177
+ prediction_mask,
178
+ batch_size,
179
+ ) -> Optional[torch.Tensor]:
180
+ # todo: maybe when constraining on the spans,
181
+ # we should not use a prediction_mask for the end tokens.
182
+ # at least we should not during training imo
183
+ start_positions = start_labels if self.training else start_predictions
184
+ start_positions_indices = (
185
+ torch.arange(start_positions.size(1), device=start_positions.device)
186
+ .unsqueeze(0)
187
+ .expand(batch_size, -1)[start_positions > 0]
188
+ ).to(start_positions.device)
189
+
190
+ if len(start_positions_indices) > 0:
191
+ expanded_features = torch.cat(
192
+ [
193
+ model_features[i].unsqueeze(0).expand(x, -1, -1)
194
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
195
+ if x > 0
196
+ ],
197
+ dim=0,
198
+ ).to(start_positions_indices.device)
199
+
200
+ expanded_prediction_mask = torch.cat(
201
+ [
202
+ prediction_mask[i].unsqueeze(0).expand(x, -1)
203
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
204
+ if x > 0
205
+ ],
206
+ dim=0,
207
+ ).to(expanded_features.device)
208
+
209
+ end_logits = self.ned_end_classifier(
210
+ hidden_states=expanded_features,
211
+ start_positions=start_positions_indices,
212
+ p_mask=expanded_prediction_mask,
213
+ )
214
+
215
+ return end_logits
216
+
217
+ return None
218
+
219
+ def compute_classification_logits(
220
+ self,
221
+ model_features,
222
+ special_symbols_mask,
223
+ prediction_mask,
224
+ batch_size,
225
+ start_positions=None,
226
+ end_positions=None,
227
+ ) -> torch.Tensor:
228
+ if start_positions is None or end_positions is None:
229
+ start_positions = torch.zeros_like(prediction_mask)
230
+ end_positions = torch.zeros_like(prediction_mask)
231
+
232
+ model_start_features = self.ed_start_projector(model_features)
233
+ model_end_features = self.ed_end_projector(model_features)
234
+ model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
235
+
236
+ model_ed_features = torch.cat(
237
+ [model_start_features, model_end_features], dim=-1
238
+ )
239
+
240
+ # computing ed features
241
+ classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
242
+ special_symbols_representation = model_ed_features[special_symbols_mask].view(
243
+ batch_size, classes_representations, -1
244
+ )
245
+
246
+ logits = torch.bmm(
247
+ model_ed_features,
248
+ torch.permute(special_symbols_representation, (0, 2, 1)),
249
+ )
250
+
251
+ logits = self._mask_logits(logits, prediction_mask)
252
+
253
+ return logits
254
+
255
+ def forward(
256
+ self,
257
+ input_ids: torch.Tensor,
258
+ attention_mask: torch.Tensor,
259
+ token_type_ids: Optional[torch.Tensor] = None,
260
+ prediction_mask: Optional[torch.Tensor] = None,
261
+ special_symbols_mask: Optional[torch.Tensor] = None,
262
+ start_labels: Optional[torch.Tensor] = None,
263
+ end_labels: Optional[torch.Tensor] = None,
264
+ use_predefined_spans: bool = False,
265
+ *args,
266
+ **kwargs,
267
+ ) -> Dict[str, Any]:
268
+
269
+ batch_size, seq_len = input_ids.shape
270
+
271
+ model_features = self._get_model_features(
272
+ input_ids, attention_mask, token_type_ids
273
+ )
274
+
275
+ ned_start_labels = None
276
+
277
+ # named entity detection if required
278
+ if use_predefined_spans: # no need to compute spans
279
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
280
+ None,
281
+ None,
282
+ torch.clone(start_labels)
283
+ if start_labels is not None
284
+ else torch.zeros_like(input_ids),
285
+ )
286
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
287
+ None,
288
+ None,
289
+ torch.clone(end_labels)
290
+ if end_labels is not None
291
+ else torch.zeros_like(input_ids),
292
+ )
293
+
294
+ ned_start_predictions[ned_start_predictions > 0] = 1
295
+ ned_end_predictions[ned_end_predictions > 0] = 1
296
+
297
+ else: # compute spans
298
+ # start boundary prediction
299
+ ned_start_logits = self.ned_start_classifier(model_features)
300
+ ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
301
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
302
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
303
+
304
+ # end boundary prediction
305
+ ned_start_labels = (
306
+ torch.zeros_like(start_labels) if start_labels is not None else None
307
+ )
308
+
309
+ if ned_start_labels is not None:
310
+ ned_start_labels[start_labels == -100] = -100
311
+ ned_start_labels[start_labels > 0] = 1
312
+
313
+ ned_end_logits = self.compute_ned_end_logits(
314
+ ned_start_predictions,
315
+ ned_start_labels,
316
+ model_features,
317
+ prediction_mask,
318
+ batch_size,
319
+ )
320
+
321
+ if ned_end_logits is not None:
322
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
323
+ ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
324
+ else:
325
+ ned_end_logits, ned_end_probabilities = None, None
326
+ ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
327
+
328
+ # flattening end predictions
329
+ # (flattening can happen only if the
330
+ # end boundaries were not predicted using the gold labels)
331
+ if not self.training:
332
+ flattened_end_predictions = torch.clone(ned_start_predictions)
333
+ flattened_end_predictions[flattened_end_predictions > 0] = 0
334
+
335
+ batch_start_predictions = list()
336
+ for elem_idx in range(batch_size):
337
+ batch_start_predictions.append(
338
+ torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
339
+ )
340
+
341
+ # check that the total number of start predictions
342
+ # is equal to the end predictions
343
+ total_start_predictions = sum(map(len, batch_start_predictions))
344
+ total_end_predictions = len(ned_end_predictions)
345
+ assert (
346
+ total_start_predictions == 0
347
+ or total_start_predictions == total_end_predictions
348
+ ), (
349
+ f"Total number of start predictions = {total_start_predictions}. "
350
+ f"Total number of end predictions = {total_end_predictions}"
351
+ )
352
+
353
+ curr_end_pred_num = 0
354
+ for elem_idx, bsp in enumerate(batch_start_predictions):
355
+ for sp in bsp:
356
+ ep = ned_end_predictions[curr_end_pred_num].item()
357
+ if ep < sp:
358
+ ep = sp
359
+
360
+ # if we already set this span throw it (no overlap)
361
+ if flattened_end_predictions[elem_idx, ep] == 1:
362
+ ned_start_predictions[elem_idx, sp] = 0
363
+ else:
364
+ flattened_end_predictions[elem_idx, ep] = 1
365
+
366
+ curr_end_pred_num += 1
367
+
368
+ ned_end_predictions = flattened_end_predictions
369
+
370
+ start_position, end_position = (
371
+ (start_labels, end_labels)
372
+ if self.training
373
+ else (ned_start_predictions, ned_end_predictions)
374
+ )
375
+
376
+ # Entity disambiguation
377
+ ed_logits = self.compute_classification_logits(
378
+ model_features,
379
+ special_symbols_mask,
380
+ prediction_mask,
381
+ batch_size,
382
+ start_position,
383
+ end_position,
384
+ )
385
+ ed_probabilities = torch.softmax(ed_logits, dim=-1)
386
+ ed_predictions = torch.argmax(ed_probabilities, dim=-1)
387
+
388
+ # output build
389
+ output_dict = dict(
390
+ batch_size=batch_size,
391
+ ned_start_logits=ned_start_logits,
392
+ ned_start_probabilities=ned_start_probabilities,
393
+ ned_start_predictions=ned_start_predictions,
394
+ ned_end_logits=ned_end_logits,
395
+ ned_end_probabilities=ned_end_probabilities,
396
+ ned_end_predictions=ned_end_predictions,
397
+ ed_logits=ed_logits,
398
+ ed_probabilities=ed_probabilities,
399
+ ed_predictions=ed_predictions,
400
+ )
401
+
402
+ # compute loss if labels
403
+ if start_labels is not None and end_labels is not None and self.training:
404
+ # named entity detection loss
405
+
406
+ # start
407
+ if ned_start_logits is not None:
408
+ ned_start_loss = self.criterion(
409
+ ned_start_logits.view(-1, ned_start_logits.shape[-1]),
410
+ ned_start_labels.view(-1),
411
+ )
412
+ else:
413
+ ned_start_loss = 0
414
+
415
+ # end
416
+ if ned_end_logits is not None:
417
+ ned_end_labels = torch.zeros_like(end_labels)
418
+ ned_end_labels[end_labels == -100] = -100
419
+ ned_end_labels[end_labels > 0] = 1
420
+
421
+ ned_end_loss = self.criterion(
422
+ ned_end_logits,
423
+ (
424
+ torch.arange(
425
+ ned_end_labels.size(1), device=ned_end_labels.device
426
+ )
427
+ .unsqueeze(0)
428
+ .expand(batch_size, -1)[ned_end_labels > 0]
429
+ ).to(ned_end_labels.device),
430
+ )
431
+
432
+ else:
433
+ ned_end_loss = 0
434
+
435
+ # entity disambiguation loss
436
+ start_labels[ned_start_labels != 1] = -100
437
+ ed_labels = torch.clone(start_labels)
438
+ ed_labels[end_labels > 0] = end_labels[end_labels > 0]
439
+ ed_loss = self.criterion(
440
+ ed_logits.view(-1, ed_logits.shape[-1]),
441
+ ed_labels.view(-1),
442
+ )
443
+
444
+ output_dict["ned_start_loss"] = ned_start_loss
445
+ output_dict["ned_end_loss"] = ned_end_loss
446
+ output_dict["ed_loss"] = ed_loss
447
+
448
+ output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
449
+
450
+ return output_dict
451
+
452
+
453
+ class RelikReaderREModel(PreTrainedModel):
454
+ config_class = RelikReaderConfig
455
+
456
+ def __init__(self, config, *args, **kwargs):
457
+ super().__init__(config)
458
+ # Transformer model declaration
459
+ # self.transformer_model_name = transformer_model
460
+ self.config = config
461
+ self.transformer_model = (
462
+ AutoModel.from_pretrained(config.transformer_model)
463
+ if config.num_layers is None
464
+ else AutoModel.from_pretrained(
465
+ config.transformer_model, num_hidden_layers=config.num_layers
466
+ )
467
+ )
468
+ self.transformer_model.resize_token_embeddings(
469
+ self.transformer_model.config.vocab_size + config.additional_special_symbols
470
+ )
471
+
472
+ # named entity detection layers
473
+ self.ned_start_classifier = self._get_projection_layer(
474
+ config.activation, last_hidden=2, layer_norm=False
475
+ )
476
+
477
+ self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
478
+
479
+ self.entity_type_loss = (
480
+ config.entity_type_loss if hasattr(config, "entity_type_loss") else False
481
+ )
482
+ self.relation_disambiguation_loss = (
483
+ config.relation_disambiguation_loss
484
+ if hasattr(config, "relation_disambiguation_loss")
485
+ else False
486
+ )
487
+
488
+ input_hidden_ents = 2 * self.transformer_model.config.hidden_size
489
+
490
+ self.re_subject_projector = self._get_projection_layer(
491
+ config.activation, input_hidden=input_hidden_ents
492
+ )
493
+ self.re_object_projector = self._get_projection_layer(
494
+ config.activation, input_hidden=input_hidden_ents
495
+ )
496
+ self.re_relation_projector = self._get_projection_layer(config.activation)
497
+
498
+ if self.entity_type_loss or self.relation_disambiguation_loss:
499
+ self.re_entities_projector = self._get_projection_layer(
500
+ config.activation,
501
+ input_hidden=2 * self.transformer_model.config.hidden_size,
502
+ )
503
+ self.re_definition_projector = self._get_projection_layer(
504
+ config.activation,
505
+ )
506
+
507
+ self.re_classifier = self._get_projection_layer(
508
+ config.activation,
509
+ input_hidden=config.linears_hidden_size,
510
+ last_hidden=2,
511
+ layer_norm=False,
512
+ )
513
+
514
+ if self.entity_type_loss or self.relation_disambiguation_loss:
515
+ self.re_ed_classifier = self._get_projection_layer(
516
+ config.activation,
517
+ input_hidden=config.linears_hidden_size,
518
+ last_hidden=2,
519
+ layer_norm=False,
520
+ )
521
+
522
+ self.training = config.training
523
+
524
+ # criterion
525
+ self.criterion = torch.nn.CrossEntropyLoss()
526
+
527
+ def _get_projection_layer(
528
+ self,
529
+ activation: str,
530
+ last_hidden: Optional[int] = None,
531
+ input_hidden=None,
532
+ layer_norm: bool = True,
533
+ ) -> torch.nn.Sequential:
534
+ head_components = [
535
+ torch.nn.Dropout(0.1),
536
+ torch.nn.Linear(
537
+ self.transformer_model.config.hidden_size
538
+ * self.config.use_last_k_layers
539
+ if input_hidden is None
540
+ else input_hidden,
541
+ self.config.linears_hidden_size,
542
+ ),
543
+ activation2functions[activation],
544
+ torch.nn.Dropout(0.1),
545
+ torch.nn.Linear(
546
+ self.config.linears_hidden_size,
547
+ self.config.linears_hidden_size if last_hidden is None else last_hidden,
548
+ ),
549
+ ]
550
+
551
+ if layer_norm:
552
+ head_components.append(
553
+ torch.nn.LayerNorm(
554
+ self.config.linears_hidden_size
555
+ if last_hidden is None
556
+ else last_hidden,
557
+ self.transformer_model.config.layer_norm_eps,
558
+ )
559
+ )
560
+
561
+ return torch.nn.Sequential(*head_components)
562
+
563
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
564
+ mask = mask.unsqueeze(-1)
565
+ if next(self.parameters()).dtype == torch.float16:
566
+ logits = logits * (1 - mask) - 65500 * mask
567
+ else:
568
+ logits = logits * (1 - mask) - 1e30 * mask
569
+ return logits
570
+
571
+ def _get_model_features(
572
+ self,
573
+ input_ids: torch.Tensor,
574
+ attention_mask: torch.Tensor,
575
+ token_type_ids: Optional[torch.Tensor],
576
+ ):
577
+ model_input = {
578
+ "input_ids": input_ids,
579
+ "attention_mask": attention_mask,
580
+ "output_hidden_states": self.config.use_last_k_layers > 1,
581
+ }
582
+
583
+ if token_type_ids is not None:
584
+ model_input["token_type_ids"] = token_type_ids
585
+
586
+ model_output = self.transformer_model(**model_input)
587
+
588
+ if self.config.use_last_k_layers > 1:
589
+ model_features = torch.cat(
590
+ model_output[1][-self.config.use_last_k_layers :], dim=-1
591
+ )
592
+ else:
593
+ model_features = model_output[0]
594
+
595
+ return model_features
596
+
597
def compute_ned_end_logits(
    self,
    start_predictions,
    start_labels,
    model_features,
    prediction_mask,
    batch_size,
) -> Optional[torch.Tensor]:
    """Score every sequence position as an end token, once per start token.

    During training the gold ``start_labels`` drive the expansion (teacher
    forcing); otherwise the model's own ``start_predictions`` are used.
    Each start token receives its own copy of the example's features, and
    the end classifier scores all positions for that start.

    Returns None when the whole batch contains no start tokens.
    """
    # todo: maybe when constraining on the spans,
    # we should not use a prediction_mask for the end tokens.
    # at least we should not during training imo
    # Teacher forcing while training, model predictions otherwise.
    start_positions = start_labels if self.training else start_predictions
    # Flat 1-D tensor with the sequence index of every start token in the batch.
    start_positions_indices = (
        torch.arange(start_positions.size(1), device=start_positions.device)
        .unsqueeze(0)
        .expand(batch_size, -1)[start_positions > 0]
    ).to(start_positions.device)

    if len(start_positions_indices) > 0:
        # One copy of each example's features per start token it contains.
        expanded_features = torch.cat(
            [
                model_features[i].unsqueeze(0).expand(x, -1, -1)
                for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                if x > 0
            ],
            dim=0,
        ).to(start_positions_indices.device)

        # Matching per-start copies of the prediction mask.
        expanded_prediction_mask = torch.cat(
            [
                prediction_mask[i].unsqueeze(0).expand(x, -1)
                for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                if x > 0
            ],
            dim=0,
        ).to(expanded_features.device)

        # mask all tokens before start_positions_indices ie, mask all tokens with
        # indices < start_positions_indices with 1, ie. [range(x) for x in start_positions_indices]
        # NOTE(review): the `if x > 0` filter below drops rows whose start
        # index is 0, which would misalign this stack with
        # `expanded_features` if a start ever occurred at sequence position
        # 0 — confirm position 0 is always a special token here.
        expanded_prediction_mask = torch.stack(
            [
                torch.cat(
                    [
                        torch.ones(x, device=expanded_features.device),
                        expanded_prediction_mask[i, x:],
                    ]
                )
                for i, x in enumerate(start_positions_indices)
                if x > 0
            ],
            dim=0,
        ).to(expanded_features.device)

        end_logits = self.ned_end_classifier(
            hidden_states=expanded_features,
            start_positions=start_positions_indices,
            p_mask=expanded_prediction_mask,
        )

        return end_logits

    # No start tokens anywhere in the batch: nothing to score.
    return None
659
+
660
def compute_relation_logits(
    self,
    model_entity_features,
    special_symbols_features,
) -> torch.Tensor:
    """Score every (subject, object, relation-symbol) triple.

    Projects the span features into subject and object spaces and the
    special relation symbols into a relation space, combines all three with
    an einsum over the shared feature axis, then classifies each triple.
    """
    subjects = self.re_subject_projector(model_entity_features)
    objects = self.re_object_projector(model_entity_features)
    relations = self.re_relation_projector(special_symbols_features)
    # b: batch, s: subjects, d: objects, f: relation symbols, e: features.
    triple_features = torch.einsum(
        "bse,bde,bfe->bsdfe", subjects, objects, relations
    )
    return self.re_classifier(triple_features)
679
+
680
def compute_entity_logits(
    self,
    model_entity_features,
    special_symbols_features,
) -> torch.Tensor:
    """Score each candidate span against every special (type/entity) symbol.

    Projects spans and symbols into a shared space, combines them pairwise
    with an einsum, and classifies each (span, symbol) pair.
    """
    model_ed_features = self.re_entities_projector(model_entity_features)
    special_symbols_ed_representation = self.re_definition_projector(
        special_symbols_features
    )
    # b: batch, c: candidate spans, d: special symbols, e: features.
    logits = torch.einsum(
        "bce,bde->bcde",
        model_ed_features,
        special_symbols_ed_representation,
    )
    logits = self.re_ed_classifier(logits)
    # NOTE(review): `start_logits` is computed but never returned — the
    # function returns the unmasked `logits` below. This looks like a bug
    # (the masked tensor was presumably the intended result); confirm
    # before relying on padding spans being masked out here.
    # NOTE(review): the repeat count `torch.sum(model_entity_features,
    # dim=1)[0].item()` yields a vector before `.item()` unless the feature
    # dimension is 1 — verify this line against real inputs.
    # The mask flags padding rows, i.e. spans whose features are all -100
    # (the pad value used by the caller's pad_sequence).
    start_logits = self._mask_logits(
        logits,
        (model_entity_features == -100)
        .all(2)
        .long()
        .unsqueeze(2)
        .repeat(1, 1, torch.sum(model_entity_features, dim=1)[0].item()),
    )

    return logits
705
+
706
def compute_loss(self, logits, labels, mask=None):
    """Flatten logits/labels and apply the configured criterion.

    Args:
        logits: scores of shape (..., num_classes).
        labels: matching targets, flattened and cast to long.
        mask: optional boolean selector applied after flattening.
    """
    flat_logits = logits.view(-1, logits.shape[-1])
    flat_labels = labels.view(-1).long()
    if mask is None:
        return self.criterion(flat_logits, flat_labels)
    return self.criterion(flat_logits[mask], flat_labels[mask])
712
+
713
def compute_ned_end_loss(self, ned_end_logits, end_labels):
    """Binary end-boundary loss; 0 when no end logits were produced."""
    if ned_end_logits is None:
        return 0
    # Collapse entity ids to a binary "is end token" target, preserving
    # the -100 ignore-index positions.
    binary_targets = torch.zeros_like(end_labels)
    binary_targets[end_labels == -100] = -100
    binary_targets[end_labels > 0] = 1
    return self.compute_loss(ned_end_logits, binary_targets)
720
+
721
def compute_ned_type_loss(
    self,
    disambiguation_labels,
    re_ned_entities_logits,
    ned_type_logits,
    re_entities_logits,
    entity_types,
):
    """Disambiguation loss, dispatched on which auxiliary losses are on.

    NOTE(review): every branch passes the labels as the first argument to
    ``compute_loss`` (whose signature is ``(logits, labels, ...)``) —
    confirm the argument order is intentional.
    """
    if self.entity_type_loss and self.relation_disambiguation_loss:
        return self.compute_loss(disambiguation_labels, re_ned_entities_logits)
    if self.entity_type_loss:
        type_labels = disambiguation_labels[:, :, :entity_types]
        return self.compute_loss(type_labels, ned_type_logits)
    if self.relation_disambiguation_loss:
        return self.compute_loss(disambiguation_labels, re_entities_logits)
    # Neither auxiliary loss is enabled.
    return 0
738
+
739
def compute_relation_loss(self, relation_labels, re_logits):
    """Relation classification loss, skipping positions labelled -100."""
    valid_positions = relation_labels.view(-1) != -100
    return self.compute_loss(re_logits, relation_labels, valid_positions)
743
+
744
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: torch.Tensor,
    prediction_mask: Optional[torch.Tensor] = None,
    special_symbols_mask: Optional[torch.Tensor] = None,
    special_symbols_mask_entities: Optional[torch.Tensor] = None,
    start_labels: Optional[torch.Tensor] = None,
    end_labels: Optional[torch.Tensor] = None,
    disambiguation_labels: Optional[torch.Tensor] = None,
    relation_labels: Optional[torch.Tensor] = None,
    is_validation: bool = False,
    is_prediction: bool = False,
    *args,
    **kwargs,
) -> Dict[str, Any]:
    """Full pipeline: entity boundary detection, optional entity typing /
    disambiguation, and relation classification.

    Returns a dict holding, per stage, its logits / probabilities /
    predictions, plus ``loss`` (and the per-stage losses) whenever
    ``start_labels``, ``end_labels`` and ``relation_labels`` are all given.
    """

    batch_size = input_ids.shape[0]

    # Token-level features from the underlying transformer.
    model_features = self._get_model_features(
        input_ids, attention_mask, token_type_ids
    )

    # named entity detection
    if is_prediction and start_labels is not None:
        # Gold boundaries supplied at prediction time: bypass the boundary
        # classifiers and derive binary predictions from the labels.
        ned_start_logits, ned_start_probabilities, ned_start_predictions = (
            None,
            None,
            torch.zeros_like(start_labels),
        )
        ned_end_logits, ned_end_probabilities, ned_end_predictions = (
            None,
            None,
            torch.zeros_like(end_labels),
        )

        ned_start_predictions[start_labels > 0] = 1
        ned_end_predictions[end_labels > 0] = 1
        # Keep only end rows that are not fully padded (-100 everywhere).
        ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
    else:
        # start boundary prediction
        ned_start_logits = self.ned_start_classifier(model_features)
        ned_start_logits = self._mask_logits(
            ned_start_logits, prediction_mask
        )  # why?
        ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
        ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

        # end boundary prediction
        ned_start_labels = (
            torch.zeros_like(start_labels) if start_labels is not None else None
        )

        # start_labels contain entity id at their position, we just need 1 for start of entity
        if ned_start_labels is not None:
            ned_start_labels[start_labels > 0] = 1

        # compute end logits only if there are any start predictions.
        # For each start prediction, n end predictions are made
        ned_end_logits = self.compute_ned_end_logits(
            ned_start_predictions,
            ned_start_labels,
            model_features,
            prediction_mask,
            batch_size,
        )
        # For each start prediction, n end predictions are made based on
        # binary classification ie. argmax at each position.
        # NOTE(review): compute_ned_end_logits returns None when the batch
        # has no start tokens, and torch.softmax(None, ...) would raise
        # here — confirm that case cannot reach this line.
        ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
        ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
        if is_prediction or is_validation:
            end_preds_count = ned_end_predictions.sum(1)
            # If there are no end predictions for a start prediction, remove the start prediction
            ned_start_predictions[ned_start_predictions == 1] = (
                end_preds_count != 0
            ).long()
            ned_end_predictions = ned_end_predictions[end_preds_count != 0]

    if end_labels is not None:
        # Drop fully-padded end rows so the labels align with the expanded
        # per-start rows produced above.
        end_labels = end_labels[~(end_labels == -100).all(2)]

    # Teacher forcing during training; model predictions otherwise.
    start_position, end_position = (
        (start_labels, end_labels)
        if (not is_prediction and not is_validation)
        else (ned_start_predictions, ned_end_predictions)
    )

    # Regroup the flat per-start rows back into per-example chunks.
    start_counts = (start_position > 0).sum(1)
    ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

    # We can only predict relations if we have start and end predictions
    if (end_position > 0).sum() > 0:
        ends_count = (end_position > 0).sum(1)
        # One row per (start, end) span: concatenation of the start-token
        # and end-token features along the feature axis.
        model_subject_features = torch.cat(
            [
                torch.repeat_interleave(
                    model_features[start_position > 0], ends_count, dim=0
                ),  # start position features
                torch.repeat_interleave(model_features, start_counts, dim=0)[
                    end_position > 0
                ],  # end position features
            ],
            dim=-1,
        )
        # Number of candidate spans per batch example.
        ents_count = torch.nn.utils.rnn.pad_sequence(
            torch.split(ends_count, start_counts.tolist()),
            batch_first=True,
            padding_value=0,
        ).sum(1)
        # Pad to (batch, max_spans, feat); padded rows are filled with -100
        # and detected as padding downstream (compute_entity_logits).
        model_subject_features = torch.nn.utils.rnn.pad_sequence(
            torch.split(model_subject_features, ents_count.tolist()),
            batch_first=True,
            padding_value=-100,
        )

        if is_validation or is_prediction:
            # Caps the candidate spans at 30 during evaluation —
            # presumably a memory/latency guard; confirm against decoding.
            model_subject_features = model_subject_features[:, :30, :]

        # entity disambiguation. Here relation_disambiguation_loss would only be useful to
        # reduce the number of candidate relations for the next step, but currently unused.
        if self.entity_type_loss or self.relation_disambiguation_loss:
            (re_ned_entities_logits) = self.compute_entity_logits(
                model_subject_features,
                model_features[
                    special_symbols_mask | special_symbols_mask_entities
                ].view(batch_size, -1, model_features.shape[-1]),
            )
            # Leading symbols are entity types; the rest are candidate entities.
            entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
            ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
            re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

            if self.entity_type_loss:
                ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
                ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
                # NOTE(review): argmax is applied twice in a row here —
                # verify the second reduction is intentional.
                ned_type_predictions = ned_type_predictions.argmax(dim=-1)

            # NOTE(review): if only relation_disambiguation_loss is set,
            # ned_type_probabilities / ned_type_predictions are never
            # assigned in this branch — the output dict below would raise
            # NameError; confirm that flag combination is unused.
            re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
            re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
        else:
            (
                ned_type_logits,
                ned_type_probabilities,
                re_entities_logits,
                re_entities_probabilities,
            ) = (None, None, None, None)
            ned_type_predictions, re_entities_predictions = (
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            )

        # Compute relation logits
        re_logits = self.compute_relation_logits(
            model_subject_features,
            model_features[special_symbols_mask].view(
                batch_size, -1, model_features.shape[-1]
            ),
        )

        re_probabilities = torch.softmax(re_logits, dim=-1)
        # A fixed 0.5 threshold on the positive class is used instead of
        # argmax so that it can be tweaked later.
        re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
        # re_predictions = re_probabilities.argmax(dim=-1)
        re_probabilities = re_probabilities[:, :, :, :, 1]

    else:
        # No predicted spans anywhere: emit zero placeholders for every
        # disambiguation / relation output.
        (
            ned_type_logits,
            ned_type_probabilities,
            re_entities_logits,
            re_entities_probabilities,
        ) = (None, None, None, None)
        ned_type_predictions, re_entities_predictions = (
            torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
        )
        re_logits, re_probabilities, re_predictions = (
            torch.zeros(
                [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
            ).to(input_ids.device),
            torch.zeros(
                [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
            ).to(input_ids.device),
            torch.zeros(
                [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
            ).to(input_ids.device),
        )

    # output build
    output_dict = dict(
        batch_size=batch_size,
        ned_start_logits=ned_start_logits,
        ned_start_probabilities=ned_start_probabilities,
        ned_start_predictions=ned_start_predictions,
        ned_end_logits=ned_end_logits,
        ned_end_probabilities=ned_end_probabilities,
        ned_end_predictions=ned_end_predictions,
        ned_type_logits=ned_type_logits,
        ned_type_probabilities=ned_type_probabilities,
        ned_type_predictions=ned_type_predictions,
        re_entities_logits=re_entities_logits,
        re_entities_probabilities=re_entities_probabilities,
        re_entities_predictions=re_entities_predictions,
        re_logits=re_logits,
        re_probabilities=re_probabilities,
        re_predictions=re_predictions,
    )

    # Losses are computed only when every label tensor is available.
    if (
        start_labels is not None
        and end_labels is not None
        and relation_labels is not None
    ):
        # NOTE(review): ned_start_labels is only assigned in the non-gold
        # branch above; with is_prediction=True and labels supplied this
        # line would raise NameError — confirm that combination is unused.
        ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
        ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
        if self.entity_type_loss or self.relation_disambiguation_loss:
            # NOTE(review): re_ned_entities_logits / entity_types exist
            # only when the batch produced at least one span — a batch
            # with no spans would raise NameError here.
            ned_type_loss = self.compute_ned_type_loss(
                disambiguation_labels,
                re_ned_entities_logits,
                ned_type_logits,
                re_entities_logits,
                entity_types,
            )
        relation_loss = self.compute_relation_loss(relation_labels, re_logits)
        # compute loss. We can skip the relation loss if we are in the first epochs (optional)
        if self.entity_type_loss or self.relation_disambiguation_loss:
            output_dict["loss"] = (
                ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
            ) / 4
            output_dict["ned_type_loss"] = ned_type_loss
        else:
            output_dict["loss"] = (
                ned_start_loss + ned_end_loss + relation_loss
            ) / 3

        output_dict["ned_start_loss"] = ned_start_loss
        output_dict["ned_end_loss"] = ned_end_loss
        output_dict["re_loss"] = relation_loss

    return output_dict
models/relik-reader-aida-deberta-small/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06ecdbcc11050fe88db21ad7b1e032ff2f28a5a819cb7ed6b6b3a62937c67637
3
+ size 577138490
models/relik-reader-aida-deberta-small/special_tokens_map.json ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "--NME--",
4
+ "[E-0]",
5
+ "[E-1]",
6
+ "[E-2]",
7
+ "[E-3]",
8
+ "[E-4]",
9
+ "[E-5]",
10
+ "[E-6]",
11
+ "[E-7]",
12
+ "[E-8]",
13
+ "[E-9]",
14
+ "[E-10]",
15
+ "[E-11]",
16
+ "[E-12]",
17
+ "[E-13]",
18
+ "[E-14]",
19
+ "[E-15]",
20
+ "[E-16]",
21
+ "[E-17]",
22
+ "[E-18]",
23
+ "[E-19]",
24
+ "[E-20]",
25
+ "[E-21]",
26
+ "[E-22]",
27
+ "[E-23]",
28
+ "[E-24]",
29
+ "[E-25]",
30
+ "[E-26]",
31
+ "[E-27]",
32
+ "[E-28]",
33
+ "[E-29]",
34
+ "[E-30]",
35
+ "[E-31]",
36
+ "[E-32]",
37
+ "[E-33]",
38
+ "[E-34]",
39
+ "[E-35]",
40
+ "[E-36]",
41
+ "[E-37]",
42
+ "[E-38]",
43
+ "[E-39]",
44
+ "[E-40]",
45
+ "[E-41]",
46
+ "[E-42]",
47
+ "[E-43]",
48
+ "[E-44]",
49
+ "[E-45]",
50
+ "[E-46]",
51
+ "[E-47]",
52
+ "[E-48]",
53
+ "[E-49]",
54
+ "[E-50]",
55
+ "[E-51]",
56
+ "[E-52]",
57
+ "[E-53]",
58
+ "[E-54]",
59
+ "[E-55]",
60
+ "[E-56]",
61
+ "[E-57]",
62
+ "[E-58]",
63
+ "[E-59]",
64
+ "[E-60]",
65
+ "[E-61]",
66
+ "[E-62]",
67
+ "[E-63]",
68
+ "[E-64]",
69
+ "[E-65]",
70
+ "[E-66]",
71
+ "[E-67]",
72
+ "[E-68]",
73
+ "[E-69]",
74
+ "[E-70]",
75
+ "[E-71]",
76
+ "[E-72]",
77
+ "[E-73]",
78
+ "[E-74]",
79
+ "[E-75]",
80
+ "[E-76]",
81
+ "[E-77]",
82
+ "[E-78]",
83
+ "[E-79]",
84
+ "[E-80]",
85
+ "[E-81]",
86
+ "[E-82]",
87
+ "[E-83]",
88
+ "[E-84]",
89
+ "[E-85]",
90
+ "[E-86]",
91
+ "[E-87]",
92
+ "[E-88]",
93
+ "[E-89]",
94
+ "[E-90]",
95
+ "[E-91]",
96
+ "[E-92]",
97
+ "[E-93]",
98
+ "[E-94]",
99
+ "[E-95]",
100
+ "[E-96]",
101
+ "[E-97]",
102
+ "[E-98]",
103
+ "[E-99]"
104
+ ],
105
+ "bos_token": "[CLS]",
106
+ "cls_token": "[CLS]",
107
+ "eos_token": "[SEP]",
108
+ "mask_token": "[MASK]",
109
+ "pad_token": "[PAD]",
110
+ "sep_token": "[SEP]",
111
+ "unk_token": "[UNK]"
112
+ }
models/relik-reader-aida-deberta-small/spm.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
3
+ size 2464616
models/relik-reader-aida-deberta-small/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/relik-reader-aida-deberta-small/tokenizer_config.json ADDED
@@ -0,0 +1,970 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "[PAD]",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "[CLS]",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "[SEP]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "[UNK]",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "128000": {
37
+ "content": "[MASK]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "128001": {
45
+ "content": "--NME--",
46
+ "lstrip": true,
47
+ "normalized": false,
48
+ "rstrip": true,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "128002": {
53
+ "content": "[E-0]",
54
+ "lstrip": true,
55
+ "normalized": false,
56
+ "rstrip": true,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "128003": {
61
+ "content": "[E-1]",
62
+ "lstrip": true,
63
+ "normalized": false,
64
+ "rstrip": true,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "128004": {
69
+ "content": "[E-2]",
70
+ "lstrip": true,
71
+ "normalized": false,
72
+ "rstrip": true,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "128005": {
77
+ "content": "[E-3]",
78
+ "lstrip": true,
79
+ "normalized": false,
80
+ "rstrip": true,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "128006": {
85
+ "content": "[E-4]",
86
+ "lstrip": true,
87
+ "normalized": false,
88
+ "rstrip": true,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "128007": {
93
+ "content": "[E-5]",
94
+ "lstrip": true,
95
+ "normalized": false,
96
+ "rstrip": true,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "128008": {
101
+ "content": "[E-6]",
102
+ "lstrip": true,
103
+ "normalized": false,
104
+ "rstrip": true,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "128009": {
109
+ "content": "[E-7]",
110
+ "lstrip": true,
111
+ "normalized": false,
112
+ "rstrip": true,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "128010": {
117
+ "content": "[E-8]",
118
+ "lstrip": true,
119
+ "normalized": false,
120
+ "rstrip": true,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "128011": {
125
+ "content": "[E-9]",
126
+ "lstrip": true,
127
+ "normalized": false,
128
+ "rstrip": true,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "128012": {
133
+ "content": "[E-10]",
134
+ "lstrip": true,
135
+ "normalized": false,
136
+ "rstrip": true,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "128013": {
141
+ "content": "[E-11]",
142
+ "lstrip": true,
143
+ "normalized": false,
144
+ "rstrip": true,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "128014": {
149
+ "content": "[E-12]",
150
+ "lstrip": true,
151
+ "normalized": false,
152
+ "rstrip": true,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "128015": {
157
+ "content": "[E-13]",
158
+ "lstrip": true,
159
+ "normalized": false,
160
+ "rstrip": true,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "128016": {
165
+ "content": "[E-14]",
166
+ "lstrip": true,
167
+ "normalized": false,
168
+ "rstrip": true,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "128017": {
173
+ "content": "[E-15]",
174
+ "lstrip": true,
175
+ "normalized": false,
176
+ "rstrip": true,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "128018": {
181
+ "content": "[E-16]",
182
+ "lstrip": true,
183
+ "normalized": false,
184
+ "rstrip": true,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "128019": {
189
+ "content": "[E-17]",
190
+ "lstrip": true,
191
+ "normalized": false,
192
+ "rstrip": true,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "128020": {
197
+ "content": "[E-18]",
198
+ "lstrip": true,
199
+ "normalized": false,
200
+ "rstrip": true,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "128021": {
205
+ "content": "[E-19]",
206
+ "lstrip": true,
207
+ "normalized": false,
208
+ "rstrip": true,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "128022": {
213
+ "content": "[E-20]",
214
+ "lstrip": true,
215
+ "normalized": false,
216
+ "rstrip": true,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "128023": {
221
+ "content": "[E-21]",
222
+ "lstrip": true,
223
+ "normalized": false,
224
+ "rstrip": true,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "128024": {
229
+ "content": "[E-22]",
230
+ "lstrip": true,
231
+ "normalized": false,
232
+ "rstrip": true,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "128025": {
237
+ "content": "[E-23]",
238
+ "lstrip": true,
239
+ "normalized": false,
240
+ "rstrip": true,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "128026": {
245
+ "content": "[E-24]",
246
+ "lstrip": true,
247
+ "normalized": false,
248
+ "rstrip": true,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "128027": {
253
+ "content": "[E-25]",
254
+ "lstrip": true,
255
+ "normalized": false,
256
+ "rstrip": true,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "128028": {
261
+ "content": "[E-26]",
262
+ "lstrip": true,
263
+ "normalized": false,
264
+ "rstrip": true,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "128029": {
269
+ "content": "[E-27]",
270
+ "lstrip": true,
271
+ "normalized": false,
272
+ "rstrip": true,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "128030": {
277
+ "content": "[E-28]",
278
+ "lstrip": true,
279
+ "normalized": false,
280
+ "rstrip": true,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "128031": {
285
+ "content": "[E-29]",
286
+ "lstrip": true,
287
+ "normalized": false,
288
+ "rstrip": true,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "128032": {
293
+ "content": "[E-30]",
294
+ "lstrip": true,
295
+ "normalized": false,
296
+ "rstrip": true,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "128033": {
301
+ "content": "[E-31]",
302
+ "lstrip": true,
303
+ "normalized": false,
304
+ "rstrip": true,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "128034": {
309
+ "content": "[E-32]",
310
+ "lstrip": true,
311
+ "normalized": false,
312
+ "rstrip": true,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "128035": {
317
+ "content": "[E-33]",
318
+ "lstrip": true,
319
+ "normalized": false,
320
+ "rstrip": true,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "128036": {
325
+ "content": "[E-34]",
326
+ "lstrip": true,
327
+ "normalized": false,
328
+ "rstrip": true,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "128037": {
333
+ "content": "[E-35]",
334
+ "lstrip": true,
335
+ "normalized": false,
336
+ "rstrip": true,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "128038": {
341
+ "content": "[E-36]",
342
+ "lstrip": true,
343
+ "normalized": false,
344
+ "rstrip": true,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "128039": {
349
+ "content": "[E-37]",
350
+ "lstrip": true,
351
+ "normalized": false,
352
+ "rstrip": true,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "128040": {
357
+ "content": "[E-38]",
358
+ "lstrip": true,
359
+ "normalized": false,
360
+ "rstrip": true,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "128041": {
365
+ "content": "[E-39]",
366
+ "lstrip": true,
367
+ "normalized": false,
368
+ "rstrip": true,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "128042": {
373
+ "content": "[E-40]",
374
+ "lstrip": true,
375
+ "normalized": false,
376
+ "rstrip": true,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "128043": {
381
+ "content": "[E-41]",
382
+ "lstrip": true,
383
+ "normalized": false,
384
+ "rstrip": true,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "128044": {
389
+ "content": "[E-42]",
390
+ "lstrip": true,
391
+ "normalized": false,
392
+ "rstrip": true,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "128045": {
397
+ "content": "[E-43]",
398
+ "lstrip": true,
399
+ "normalized": false,
400
+ "rstrip": true,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "128046": {
405
+ "content": "[E-44]",
406
+ "lstrip": true,
407
+ "normalized": false,
408
+ "rstrip": true,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "128047": {
413
+ "content": "[E-45]",
414
+ "lstrip": true,
415
+ "normalized": false,
416
+ "rstrip": true,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "128048": {
421
+ "content": "[E-46]",
422
+ "lstrip": true,
423
+ "normalized": false,
424
+ "rstrip": true,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "128049": {
429
+ "content": "[E-47]",
430
+ "lstrip": true,
431
+ "normalized": false,
432
+ "rstrip": true,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "128050": {
437
+ "content": "[E-48]",
438
+ "lstrip": true,
439
+ "normalized": false,
440
+ "rstrip": true,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "128051": {
445
+ "content": "[E-49]",
446
+ "lstrip": true,
447
+ "normalized": false,
448
+ "rstrip": true,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "128052": {
453
+ "content": "[E-50]",
454
+ "lstrip": true,
455
+ "normalized": false,
456
+ "rstrip": true,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "128053": {
461
+ "content": "[E-51]",
462
+ "lstrip": true,
463
+ "normalized": false,
464
+ "rstrip": true,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "128054": {
469
+ "content": "[E-52]",
470
+ "lstrip": true,
471
+ "normalized": false,
472
+ "rstrip": true,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "128055": {
477
+ "content": "[E-53]",
478
+ "lstrip": true,
479
+ "normalized": false,
480
+ "rstrip": true,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "128056": {
485
+ "content": "[E-54]",
486
+ "lstrip": true,
487
+ "normalized": false,
488
+ "rstrip": true,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "128057": {
493
+ "content": "[E-55]",
494
+ "lstrip": true,
495
+ "normalized": false,
496
+ "rstrip": true,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "128058": {
501
+ "content": "[E-56]",
502
+ "lstrip": true,
503
+ "normalized": false,
504
+ "rstrip": true,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "128059": {
509
+ "content": "[E-57]",
510
+ "lstrip": true,
511
+ "normalized": false,
512
+ "rstrip": true,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "128060": {
517
+ "content": "[E-58]",
518
+ "lstrip": true,
519
+ "normalized": false,
520
+ "rstrip": true,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "128061": {
525
+ "content": "[E-59]",
526
+ "lstrip": true,
527
+ "normalized": false,
528
+ "rstrip": true,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "128062": {
533
+ "content": "[E-60]",
534
+ "lstrip": true,
535
+ "normalized": false,
536
+ "rstrip": true,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "128063": {
541
+ "content": "[E-61]",
542
+ "lstrip": true,
543
+ "normalized": false,
544
+ "rstrip": true,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "128064": {
549
+ "content": "[E-62]",
550
+ "lstrip": true,
551
+ "normalized": false,
552
+ "rstrip": true,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "128065": {
557
+ "content": "[E-63]",
558
+ "lstrip": true,
559
+ "normalized": false,
560
+ "rstrip": true,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "128066": {
565
+ "content": "[E-64]",
566
+ "lstrip": true,
567
+ "normalized": false,
568
+ "rstrip": true,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "128067": {
573
+ "content": "[E-65]",
574
+ "lstrip": true,
575
+ "normalized": false,
576
+ "rstrip": true,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "128068": {
581
+ "content": "[E-66]",
582
+ "lstrip": true,
583
+ "normalized": false,
584
+ "rstrip": true,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "128069": {
589
+ "content": "[E-67]",
590
+ "lstrip": true,
591
+ "normalized": false,
592
+ "rstrip": true,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "128070": {
597
+ "content": "[E-68]",
598
+ "lstrip": true,
599
+ "normalized": false,
600
+ "rstrip": true,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "128071": {
605
+ "content": "[E-69]",
606
+ "lstrip": true,
607
+ "normalized": false,
608
+ "rstrip": true,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "128072": {
613
+ "content": "[E-70]",
614
+ "lstrip": true,
615
+ "normalized": false,
616
+ "rstrip": true,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "128073": {
621
+ "content": "[E-71]",
622
+ "lstrip": true,
623
+ "normalized": false,
624
+ "rstrip": true,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "128074": {
629
+ "content": "[E-72]",
630
+ "lstrip": true,
631
+ "normalized": false,
632
+ "rstrip": true,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "128075": {
637
+ "content": "[E-73]",
638
+ "lstrip": true,
639
+ "normalized": false,
640
+ "rstrip": true,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "128076": {
645
+ "content": "[E-74]",
646
+ "lstrip": true,
647
+ "normalized": false,
648
+ "rstrip": true,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "128077": {
653
+ "content": "[E-75]",
654
+ "lstrip": true,
655
+ "normalized": false,
656
+ "rstrip": true,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "128078": {
661
+ "content": "[E-76]",
662
+ "lstrip": true,
663
+ "normalized": false,
664
+ "rstrip": true,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "128079": {
669
+ "content": "[E-77]",
670
+ "lstrip": true,
671
+ "normalized": false,
672
+ "rstrip": true,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "128080": {
677
+ "content": "[E-78]",
678
+ "lstrip": true,
679
+ "normalized": false,
680
+ "rstrip": true,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "128081": {
685
+ "content": "[E-79]",
686
+ "lstrip": true,
687
+ "normalized": false,
688
+ "rstrip": true,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "128082": {
693
+ "content": "[E-80]",
694
+ "lstrip": true,
695
+ "normalized": false,
696
+ "rstrip": true,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "128083": {
701
+ "content": "[E-81]",
702
+ "lstrip": true,
703
+ "normalized": false,
704
+ "rstrip": true,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "128084": {
709
+ "content": "[E-82]",
710
+ "lstrip": true,
711
+ "normalized": false,
712
+ "rstrip": true,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "128085": {
717
+ "content": "[E-83]",
718
+ "lstrip": true,
719
+ "normalized": false,
720
+ "rstrip": true,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "128086": {
725
+ "content": "[E-84]",
726
+ "lstrip": true,
727
+ "normalized": false,
728
+ "rstrip": true,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "128087": {
733
+ "content": "[E-85]",
734
+ "lstrip": true,
735
+ "normalized": false,
736
+ "rstrip": true,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "128088": {
741
+ "content": "[E-86]",
742
+ "lstrip": true,
743
+ "normalized": false,
744
+ "rstrip": true,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "128089": {
749
+ "content": "[E-87]",
750
+ "lstrip": true,
751
+ "normalized": false,
752
+ "rstrip": true,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "128090": {
757
+ "content": "[E-88]",
758
+ "lstrip": true,
759
+ "normalized": false,
760
+ "rstrip": true,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "128091": {
765
+ "content": "[E-89]",
766
+ "lstrip": true,
767
+ "normalized": false,
768
+ "rstrip": true,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "128092": {
773
+ "content": "[E-90]",
774
+ "lstrip": true,
775
+ "normalized": false,
776
+ "rstrip": true,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "128093": {
781
+ "content": "[E-91]",
782
+ "lstrip": true,
783
+ "normalized": false,
784
+ "rstrip": true,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "128094": {
789
+ "content": "[E-92]",
790
+ "lstrip": true,
791
+ "normalized": false,
792
+ "rstrip": true,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "128095": {
797
+ "content": "[E-93]",
798
+ "lstrip": true,
799
+ "normalized": false,
800
+ "rstrip": true,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "128096": {
805
+ "content": "[E-94]",
806
+ "lstrip": true,
807
+ "normalized": false,
808
+ "rstrip": true,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "128097": {
813
+ "content": "[E-95]",
814
+ "lstrip": true,
815
+ "normalized": false,
816
+ "rstrip": true,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "128098": {
821
+ "content": "[E-96]",
822
+ "lstrip": true,
823
+ "normalized": false,
824
+ "rstrip": true,
825
+ "single_word": false,
826
+ "special": true
827
+ },
828
+ "128099": {
829
+ "content": "[E-97]",
830
+ "lstrip": true,
831
+ "normalized": false,
832
+ "rstrip": true,
833
+ "single_word": false,
834
+ "special": true
835
+ },
836
+ "128100": {
837
+ "content": "[E-98]",
838
+ "lstrip": true,
839
+ "normalized": false,
840
+ "rstrip": true,
841
+ "single_word": false,
842
+ "special": true
843
+ },
844
+ "128101": {
845
+ "content": "[E-99]",
846
+ "lstrip": true,
847
+ "normalized": false,
848
+ "rstrip": true,
849
+ "single_word": false,
850
+ "special": true
851
+ }
852
+ },
853
+ "additional_special_tokens": [
854
+ "--NME--",
855
+ "[E-0]",
856
+ "[E-1]",
857
+ "[E-2]",
858
+ "[E-3]",
859
+ "[E-4]",
860
+ "[E-5]",
861
+ "[E-6]",
862
+ "[E-7]",
863
+ "[E-8]",
864
+ "[E-9]",
865
+ "[E-10]",
866
+ "[E-11]",
867
+ "[E-12]",
868
+ "[E-13]",
869
+ "[E-14]",
870
+ "[E-15]",
871
+ "[E-16]",
872
+ "[E-17]",
873
+ "[E-18]",
874
+ "[E-19]",
875
+ "[E-20]",
876
+ "[E-21]",
877
+ "[E-22]",
878
+ "[E-23]",
879
+ "[E-24]",
880
+ "[E-25]",
881
+ "[E-26]",
882
+ "[E-27]",
883
+ "[E-28]",
884
+ "[E-29]",
885
+ "[E-30]",
886
+ "[E-31]",
887
+ "[E-32]",
888
+ "[E-33]",
889
+ "[E-34]",
890
+ "[E-35]",
891
+ "[E-36]",
892
+ "[E-37]",
893
+ "[E-38]",
894
+ "[E-39]",
895
+ "[E-40]",
896
+ "[E-41]",
897
+ "[E-42]",
898
+ "[E-43]",
899
+ "[E-44]",
900
+ "[E-45]",
901
+ "[E-46]",
902
+ "[E-47]",
903
+ "[E-48]",
904
+ "[E-49]",
905
+ "[E-50]",
906
+ "[E-51]",
907
+ "[E-52]",
908
+ "[E-53]",
909
+ "[E-54]",
910
+ "[E-55]",
911
+ "[E-56]",
912
+ "[E-57]",
913
+ "[E-58]",
914
+ "[E-59]",
915
+ "[E-60]",
916
+ "[E-61]",
917
+ "[E-62]",
918
+ "[E-63]",
919
+ "[E-64]",
920
+ "[E-65]",
921
+ "[E-66]",
922
+ "[E-67]",
923
+ "[E-68]",
924
+ "[E-69]",
925
+ "[E-70]",
926
+ "[E-71]",
927
+ "[E-72]",
928
+ "[E-73]",
929
+ "[E-74]",
930
+ "[E-75]",
931
+ "[E-76]",
932
+ "[E-77]",
933
+ "[E-78]",
934
+ "[E-79]",
935
+ "[E-80]",
936
+ "[E-81]",
937
+ "[E-82]",
938
+ "[E-83]",
939
+ "[E-84]",
940
+ "[E-85]",
941
+ "[E-86]",
942
+ "[E-87]",
943
+ "[E-88]",
944
+ "[E-89]",
945
+ "[E-90]",
946
+ "[E-91]",
947
+ "[E-92]",
948
+ "[E-93]",
949
+ "[E-94]",
950
+ "[E-95]",
951
+ "[E-96]",
952
+ "[E-97]",
953
+ "[E-98]",
954
+ "[E-99]"
955
+ ],
956
+ "bos_token": "[CLS]",
957
+ "clean_up_tokenization_spaces": true,
958
+ "cls_token": "[CLS]",
959
+ "do_lower_case": false,
960
+ "eos_token": "[SEP]",
961
+ "mask_token": "[MASK]",
962
+ "model_max_length": 1000000000000000019884624838656,
963
+ "pad_token": "[PAD]",
964
+ "sep_token": "[SEP]",
965
+ "sp_model_kwargs": {},
966
+ "split_by_punct": false,
967
+ "tokenizer_class": "DebertaV2Tokenizer",
968
+ "unk_token": "[UNK]",
969
+ "vocab_type": "spm"
970
+ }
models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/config.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ _target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex
2
+ documents:
3
+ _target_: relik.retriever.data.labels.Labels
4
+ embeddings:
5
+ _target_: torch.Tensor
6
+ name_or_dir: /media/data/EL/models/experiments/e5-small-15hard-400inbatch-64maxlen-32words-topics/2023-06-04/07-22-35/wandb/run-20230604_072319-3ql9q8oa/files/retriever/index
7
+ device: cpu
8
+ precision: null
models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/documents.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d367a0db7f8959d0d23f78d0af229856929a552d0195079422bf8afaaad2d70
3
+ size 2813615153
models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index/embeddings.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fde55d5649350819a04dcbc242114486ccb31030df10f64b6b7213a983eecc0a
3
+ size 4533909983
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "[CLS]": 101,
3
+ "[MASK]": 103,
4
+ "[PAD]": 0,
5
+ "[SEP]": 102,
6
+ "[UNK]": 100
7
+ }
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "intfloat/e5-small-v2",
3
+ "architectures": [
4
+ "GoldenRetrieverModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoModel": "hf.GoldenRetrieverModel"
9
+ },
10
+ "classifier_dropout": null,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 384,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 1536,
16
+ "layer_norm_eps": 1e-12,
17
+ "max_position_embeddings": 512,
18
+ "model_type": "bert",
19
+ "num_attention_heads": 12,
20
+ "num_hidden_layers": 12,
21
+ "pad_token_id": 0,
22
+ "position_embedding_type": "absolute",
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.34.0",
25
+ "type_vocab_size": 2,
26
+ "use_cache": true,
27
+ "vocab_size": 30522
28
+ }
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/hf.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ from transformers import PretrainedConfig
5
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
6
+ from transformers.models.bert.modeling_bert import BertModel
7
+
8
+
9
class GoldenRetrieverConfig(PretrainedConfig):
    """Configuration for the GoldenRetriever question encoder.

    Mirrors the standard BERT configuration fields one-for-one so that a
    BERT-style checkpoint can be loaded under this config class
    (``model_type`` is deliberately kept as ``"bert"``).
    """

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Architecture / capacity.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        # Dropout / regularisation.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout = classifier_dropout
        # Embedding layout.
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.position_embedding_type = position_embedding_type
        # Initialisation, numerics, and caching.
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
49
+
50
+
51
class GoldenRetrieverModel(BertModel):
    """BERT encoder whose pooled output is a masked mean of token embeddings.

    When an ``attention_mask`` is supplied, the default [CLS]-tanh pooler
    output is replaced by the mean of the last hidden states over
    non-padding tokens, followed by a LayerNorm; this vector is the
    sentence embedding used by the retriever. Without a mask, the standard
    BERT pooler output is normalised instead.
    """

    config_class = GoldenRetrieverConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        # Normalises the pooled sentence embedding.
        self.layer_norm_layer = torch.nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(
        self, **kwargs
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        attention_mask = kwargs.get("attention_mask", None)
        # Bug fix: always request a ModelOutput from the base class. The
        # attribute accesses below (.last_hidden_state, .pooler_output, ...)
        # would raise AttributeError on the plain tuple BertModel returns
        # when a caller forwards return_dict=False; the caller's flag is
        # honoured only when shaping this method's own return value.
        return_dict = kwargs.pop("return_dict", True)
        kwargs["return_dict"] = True
        model_outputs = super().forward(**kwargs)
        if attention_mask is None:
            # No mask available: fall back to the standard BERT pooler.
            pooler_output = model_outputs.pooler_output
        else:
            # Mean-pool token embeddings, ignoring padding positions.
            token_embeddings = model_outputs.last_hidden_state
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            # clamp guards against division by zero on all-padding rows.
            pooler_output = torch.sum(
                token_embeddings * input_mask_expanded, 1
            ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        pooler_output = self.layer_norm_layer(pooler_output)

        if not return_dict:
            # Tuple form mirrors BertModel: (sequence_output, pooled_output, *rest).
            return (model_outputs[0], pooler_output) + model_outputs[2:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=model_outputs.last_hidden_state,
            pooler_output=pooler_output,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            cross_attentions=model_outputs.cross_attentions,
        )
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:201092855fe86eff5afb1b68ea9cdaf0af98579fbb7191ad87d9726bb95e5d1f
3
+ size 133508078
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "additional_special_tokens": [],
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "[CLS]",
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
scripts/setup.sh CHANGED
File without changes