sebastian-hofstaetter committed on
Commit 04f9b13
1 Parent(s): 0533b10

initial model & readme

README.md ADDED
---
language: "en"
tags:
- knowledge-distillation
datasets:
- ms_marco
---

# Margin-MSE Trained PreTTR

We provide a retrieval-trained, DistilBERT-based PreTTR model (https://arxiv.org/abs/2004.14255). Our model is trained with Margin-MSE, using a 3-teacher BERT_Cat (concatenated BERT scoring) ensemble, on MSMARCO-Passage.
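
Margin-MSE optimizes the student to reproduce the teacher ensemble's score *margin* between a relevant and a non-relevant passage. A minimal sketch of that objective (variable names are illustrative; see the paper and the GitHub repository below for the full training setup):

````python
import torch

def margin_mse_loss(student_pos, student_neg, teacher_pos, teacher_neg):
    # Match the student's score margin (pos - neg) to the teacher ensemble's margin;
    # all inputs are per-sample relevance scores of shape (batch,).
    return torch.nn.functional.mse_loss(student_pos - student_neg,
                                        teacher_pos - teacher_neg)
````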

This instance can be used to **re-rank a candidate set**. The architecture is a 6-layer DistilBERT, split at layer 3, with a single additional linear layer at the end that scores the CLS token.

If you want to know more about our simple yet effective knowledge distillation method for efficient information retrieval models, usable with a variety of student architectures and used to train this model instance, check out our paper: https://arxiv.org/abs/2010.02666 🎉

For more information, training data, source code, and a minimal usage example please visit: https://github.com/sebastian-hofstaetter/neural-ranking-kd. A short usage sketch also follows the model code below.

## Configuration

- We split the 6-layer DistilBERT in half at layer 3: query and passage are encoded separately in the first 3 layers and jointly (concatenated) in the last 3.

## Model Code
````python
from transformers import *
from transformers.models.distilbert.modeling_distilbert import *
import copy
import math
import torch
from torch import nn as nn

class PreTTRConfig(DistilBertConfig):
    join_layer_idx = 3

class PreTTR(DistilBertModel):
    '''
    PreTTR changes the distilbert model from huggingface to be able to split query and document up to a set layer;
    we skipped the compression step present in the original.

    from: Efficient Document Re-Ranking for Transformers by Precomputing Term Representations
    MacAvaney, et al. https://arxiv.org/abs/2004.14255
    '''
    config_class = PreTTRConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SplitTransformer(config)   # Encoder; we override the classes, but the names stay the same -> so it gets properly initialized
        self.embeddings = PosOffsetEmbeddings(config)  # Embeddings
        self._classification_layer = torch.nn.Linear(self.config.hidden_size, 1, bias=False)

        self.join_layer_idx = config.join_layer_idx

    def forward(
            self,
            query,
            document,
            use_fp16: bool = False) -> torch.Tensor:

        with torch.cuda.amp.autocast(enabled=use_fp16):

            query_input_ids = query["input_ids"]
            query_attention_mask = query["attention_mask"]

            document_input_ids = document["input_ids"][:, 1:]
            document_attention_mask = document["attention_mask"][:, 1:]

            query_embs = self.embeddings(query_input_ids)                                   # (bs, seq_length, dim)
            document_embs = self.embeddings(document_input_ids, query_input_ids.shape[-1])  # (bs, seq_length, dim)

            tfmr_output = self.transformer(
                query_embs=query_embs,
                query_mask=query_attention_mask,
                doc_embs=document_embs,
                doc_mask=document_attention_mask,
                join_layer_idx=self.join_layer_idx
            )
            hidden_state = tfmr_output[0]

            score = self._classification_layer(hidden_state[:, 0, :]).squeeze()

            return score


class PosOffsetEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, pos_offset=0):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, max_seq_length)
            The token ids to embed.
        pos_offset: int
            Offset added to the position ids (used for the document, which continues after the query positions).

        Outputs
        -------
        embeddings: torch.tensor(bs, max_seq_length, dim)
            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + pos_offset          # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)              # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)   # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)             # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)               # (bs, max_seq_length, dim)
        return embeddings


class SplitTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(self, query_embs, query_mask, doc_embs, doc_mask, join_layer_idx, output_attentions=False, output_hidden_states=False):
        """
        Parameters
        ----------
        query_embs / doc_embs: torch.tensor(bs, seq_length, dim)
            Embedded query / document sequences.
        query_mask / doc_mask: torch.tensor(bs, seq_length)
            Attention masks on the query / document sequences.
        join_layer_idx: int
            Index of the layer at which the query and document representations are concatenated.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        all_hidden_states = ()
        all_attentions = ()

        #
        # query / doc sep.
        #
        hidden_state_q = query_embs
        hidden_state_d = doc_embs
        for layer_module in self.layer[:join_layer_idx]:

            layer_outputs_q = layer_module(
                x=hidden_state_q, attn_mask=query_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state_q = layer_outputs_q[-1]

            layer_outputs_d = layer_module(
                x=hidden_state_d, attn_mask=doc_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state_d = layer_outputs_d[-1]

        #
        # combine
        #
        x = torch.cat([hidden_state_q, hidden_state_d], dim=1)
        attn_mask = torch.cat([query_mask, doc_mask], dim=1)

        #
        # combined
        #
        hidden_state = x
        for layer_module in self.layer[join_layer_idx:]:
            layer_outputs = layer_module(
                x=hidden_state, attn_mask=attn_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state = layer_outputs[-1]

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        outputs = (hidden_state,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)

#
# init the model & tokenizer (using the distilbert tokenizer)
#
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # honestly not sure if that is the best way to go, but it works :)
model = PreTTR.from_pretrained("sebastian-hofstaetter/prettr-distilbert-split_at_3-margin_mse-T2-msmarco")
````
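
A minimal usage sketch, building on the `model` and `tokenizer` created at the end of the block above. The query/passage strings are made up for illustration; query and document are tokenized separately, because the forward pass takes them as two inputs:

````python
query = tokenizer("what is margin-mse training", return_tensors="pt")
passage = tokenizer("Margin-MSE distills the score margin of a teacher ranking model into a student model.",
                    truncation=True, max_length=200, return_tensors="pt")

model.eval()
with torch.no_grad():
    score = model(query, passage)  # higher score = estimated higher relevance

print(float(score))
````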

## Effectiveness on MSMARCO Passage

We trained our model on the standard MSMARCO ("small", 400K-query) training triples via knowledge distillation, with a batch size of 32 on a single consumer-grade GPU (11 GB memory).

For re-ranking we used the top-1000 BM25 results.
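
A sketch of that re-ranking step, using the `model` and `tokenizer` from the Model Code section above and assuming `candidates` is the list of BM25 top-1000 passage strings for one query (the helper below is illustrative and unbatched for clarity; it is not part of the original training or evaluation code):

````python
def rerank(query_text, candidates, top_k=10):
    # Score every BM25 candidate with the PreTTR model and sort by descending score.
    query = tokenizer(query_text, return_tensors="pt")
    scored = []
    with torch.no_grad():
        for passage_text in candidates:
            passage = tokenizer(passage_text, truncation=True, max_length=200, return_tensors="pt")
            scored.append((float(model(query, passage)), passage_text))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]
````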

### MSMARCO-DEV

Here we use the larger 49K-query DEV set (results are in the same range as on the smaller 7K-query DEV set; only minimal differences are to be expected).

|                                     | MRR@10 | NDCG@10 |
|-------------------------------------|--------|---------|
| BM25                                | .194   | .241    |
| **Margin-MSE PreTTR** (re-ranking)  | .386   | .447    |

For more metrics, baselines, information, and analysis, please see the paper: https://arxiv.org/abs/2010.02666

## Limitations & Bias

- The model inherits social biases from both DistilBERT and MSMARCO.

- The model is only trained on relatively short passages of MSMARCO (avg. 60 words in length), so it might struggle with longer text.

## Citation

If you use our model checkpoint please cite our work as:

```
@misc{hofstaetter2020_crossarchitecture_kd,
      title={Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation},
      author={Sebastian Hofst{\"a}tter and Sophia Althammer and Michael Schr{\"o}der and Mete Sertkan and Allan Hanbury},
      year={2020},
      eprint={2010.02666},
      archivePrefix={arXiv},
      primaryClass={cs.IR}
}
```
config.json ADDED
{
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "PreTTR"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.8.1",
  "vocab_size": 30522
}
pytorch_model.bin ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:8e6090f1c5db97bc6f478e3fe66ddcce00c1dc3a006f155fe1639bf248d31b4e
size 265490231
special_tokens_map.json ADDED
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render.
tokenizer_config.json ADDED
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "distilbert-base-uncased", "tokenizer_class": "DistilBertTokenizer"}
vocab.txt ADDED
The diff for this file is too large to render.