sebastian-hofstaetter committed
Commit 04f9b13 · Parent(s): 0533b10
initial model & readme
Browse files
- README.md +241 -0
- config.json +23 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,241 @@
---
language: "en"
tags:
- knowledge-distillation
datasets:
- ms_marco
---

# Margin-MSE Trained PreTTR

We provide a retrieval-trained, DistilBERT-based PreTTR model (https://arxiv.org/abs/2004.14255). Our model is trained with Margin-MSE, using an ensemble of 3 BERT_Cat (concatenated BERT scoring) teachers on MSMARCO-Passage.

This instance can be used to **re-rank a candidate set**. The architecture is a 6-layer DistilBERT, split at layer 3, with an additional single linear layer at the end that scores the CLS token.

If you want to know more about our simple yet effective knowledge distillation method for efficient information retrieval models (it covers a variety of student architectures and is the method used for this model instance), check out our paper: https://arxiv.org/abs/2010.02666 🎉

For more information, training data, source code, and a minimal usage example, please visit: https://github.com/sebastian-hofstaetter/neural-ranking-kd

## Configuration

- We split the DistilBERT in half at layer 3

## Model Code
````python
from transformers import *
from transformers.models.distilbert.modeling_distilbert import *
import copy
import math
import torch
from torch import nn as nn

# the wildcard imports provide DistilBertConfig, DistilBertModel and AutoTokenizer,
# as well as TransformerBlock and create_sinusoidal_embeddings used below

class PreTTRConfig(DistilBertConfig):
    join_layer_idx = 3

class PreTTR(DistilBertModel):
    '''
    PreTTR adapts the Hugging Face DistilBERT model so that query and document are encoded
    separately up to a set (join) layer; we skipped the term-representation compression
    present in the original.

    from: Efficient Document Re-Ranking for Transformers by Precomputing Term Representations
    MacAvaney, et al. https://arxiv.org/abs/2004.14255
    '''
    config_class = PreTTRConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SplitTransformer(config)    # encoder; we override the classes, but the names stay the same -> so it gets properly initialized
        self.embeddings = PosOffsetEmbeddings(config)  # embeddings
        self._classification_layer = torch.nn.Linear(self.config.hidden_size, 1, bias=False)

        self.join_layer_idx = config.join_layer_idx

    def forward(
            self,
            query,
            document,
            use_fp16: bool = False) -> torch.Tensor:

        with torch.cuda.amp.autocast(enabled=use_fp16):

            query_input_ids = query["input_ids"]
            query_attention_mask = query["attention_mask"]

            # drop the document's [CLS] token; the query's [CLS] token scores the pair
            document_input_ids = document["input_ids"][:, 1:]
            document_attention_mask = document["attention_mask"][:, 1:]

            query_embs = self.embeddings(query_input_ids)                                   # (bs, seq_length, dim)
            document_embs = self.embeddings(document_input_ids, query_input_ids.shape[-1])  # (bs, seq_length, dim)

            tfmr_output = self.transformer(
                query_embs=query_embs,
                query_mask=query_attention_mask,
                doc_embs=document_embs,
                doc_mask=document_attention_mask,
                join_layer_idx=self.join_layer_idx
            )
            hidden_state = tfmr_output[0]

            score = self._classification_layer(hidden_state[:, 0, :]).squeeze()

            return score


class PosOffsetEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, pos_offset=0):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, max_seq_length)
            The token ids to embed.
        pos_offset: int
            Offset added to the position ids (used to shift document positions behind the query).

        Outputs
        -------
        embeddings: torch.tensor(bs, max_seq_length, dim)
            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + pos_offset          # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)             # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)             # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)               # (bs, max_seq_length, dim)
        return embeddings


class SplitTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(self, query_embs, query_mask, doc_embs, doc_mask, join_layer_idx, output_attentions=False, output_hidden_states=False):
        """
        Parameters
        ----------
        query_embs: torch.tensor(bs, query_seq_length, dim)
            Embedded query sequence.
        query_mask: torch.tensor(bs, query_seq_length)
            Attention mask on the query sequence.
        doc_embs: torch.tensor(bs, doc_seq_length, dim)
            Embedded document sequence.
        doc_mask: torch.tensor(bs, doc_seq_length)
            Attention mask on the document sequence.
        join_layer_idx: int
            Layer index at which the query and document representations are concatenated.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        all_hidden_states = ()
        all_attentions = ()

        #
        # query / doc separately
        #
        hidden_state_q = query_embs
        hidden_state_d = doc_embs
        for layer_module in self.layer[:join_layer_idx]:

            layer_outputs_q = layer_module(
                x=hidden_state_q, attn_mask=query_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state_q = layer_outputs_q[-1]

            layer_outputs_d = layer_module(
                x=hidden_state_d, attn_mask=doc_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state_d = layer_outputs_d[-1]

        #
        # combine query and document sequences
        #
        x = torch.cat([hidden_state_q, hidden_state_d], dim=1)
        attn_mask = torch.cat([query_mask, doc_mask], dim=1)

        #
        # combined layers
        #
        hidden_state = x
        for layer_module in self.layer[join_layer_idx:]:
            layer_outputs = layer_module(
                x=hidden_state, attn_mask=attn_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state = layer_outputs[-1]

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        outputs = (hidden_state,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)

#
# init the model & tokenizer (using the distilbert tokenizer)
#
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # honestly not sure if that is the best way to go, but it works :)
model = PreTTR.from_pretrained("sebastian-hofstaetter/prettr-distilbert-split_at_3-margin_mse-T2-msmarco")
````
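
Below is a minimal scoring sketch that continues from the block above (it reuses `tokenizer` and `model`); the example strings and `max_length` values are illustrative assumptions, not fixed requirements:

````python
import torch

# Tokenize query and passage separately -- the model encodes them independently up to the join layer.
query = tokenizer("what is knowledge distillation?",
                  return_tensors="pt", max_length=30, truncation=True)
passage = tokenizer("Knowledge distillation transfers knowledge from a large teacher model "
                    "to a smaller student model.",
                    return_tensors="pt", max_length=200, truncation=True)

model.eval()
with torch.no_grad():
    score = model(query, passage)  # a single relevance score per query-passage pair

print(float(score))
````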
## Effectiveness on MSMARCO Passage

We trained our model on the MSMARCO standard ("small", 400K queries) training triples with knowledge distillation, using a batch size of 32 on a single consumer-grade GPU (11 GB memory).
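
For reference, a minimal sketch of the Margin-MSE objective used for this distillation (tensor names and the helper function are illustrative, not the exact training code): the student is optimized so that its score margin between the relevant and the non-relevant passage of each triple matches the margin of the teacher ensemble.

````python
import torch

def margin_mse_loss(student_pos: torch.Tensor, student_neg: torch.Tensor,
                    teacher_pos: torch.Tensor, teacher_neg: torch.Tensor) -> torch.Tensor:
    # Margin-MSE (https://arxiv.org/abs/2010.02666): match the student's score margin
    # (relevant minus non-relevant passage) to the margin of the teacher ensemble.
    return torch.nn.functional.mse_loss(student_pos - student_neg,
                                        teacher_pos - teacher_neg)

# Illustrative usage with per-triple scores from the model above and pre-computed teacher scores:
# loss = margin_mse_loss(model(query, pos_doc), model(query, neg_doc), teacher_pos, teacher_neg)
````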
For re-ranking, we used the top-1000 BM25 results as the candidate set.
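
A hedged sketch of this re-ranking step, again reusing `tokenizer` and `model` from the model code above (candidate handling is simplified to one pair per forward pass; in practice you would batch the candidates):

````python
import torch

def rerank(query_text, candidate_passages, top_k=10):
    # Score every BM25 candidate with the PreTTR model and sort by descending score.
    query = tokenizer(query_text, return_tensors="pt", truncation=True, max_length=30)
    scored = []
    model.eval()
    with torch.no_grad():
        for passage_text in candidate_passages:
            passage = tokenizer(passage_text, return_tensors="pt", truncation=True, max_length=200)
            scored.append((passage_text, float(model(query, passage))))
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]
````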
### MSMARCO-DEV

Here we use the larger 49K-query DEV set (results are in the same range as on the smaller 7K-query DEV set; only minimal changes are possible).

|                                     | MRR@10 | NDCG@10 |
|-------------------------------------|--------|---------|
| BM25                                | .194   | .241    |
| **Margin-MSE PreTTR** (Re-ranking)  | .386   | .447    |

For more metrics, baselines, info and analysis, please see the paper: https://arxiv.org/abs/2010.02666
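
As a reading aid for the table above, a small illustrative sketch of how MRR@10 is computed (this is not the official MSMARCO evaluation script):

````python
def mrr_at_10(ranked_ids_per_query, relevant_ids_per_query):
    # Mean over queries of the reciprocal rank of the first relevant passage within the top 10.
    reciprocal_ranks = []
    for ranked_ids, relevant_ids in zip(ranked_ids_per_query, relevant_ids_per_query):
        rr = 0.0
        for rank, doc_id in enumerate(ranked_ids[:10], start=1):
            if doc_id in relevant_ids:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)
````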
## Limitations & Bias

- The model inherits social biases from both DistilBERT and MSMARCO.

- The model is only trained on relatively short passages of MSMARCO (avg. length: 60 words), so it might struggle with longer text.

## Citation

If you use our model checkpoint, please cite our work as:

```
@misc{hofstaetter2020_crossarchitecture_kd,
      title={Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation},
      author={Sebastian Hofst{\"a}tter and Sophia Althammer and Michael Schr{\"o}der and Mete Sertkan and Allan Hanbury},
      year={2020},
      eprint={2010.02666},
      archivePrefix={arXiv},
      primaryClass={cs.IR}
}
```
config.json
ADDED
@@ -0,0 +1,23 @@
{
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "PreTTR"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.8.1",
  "vocab_size": 30522
}
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e6090f1c5db97bc6f478e3fe66ddcce00c1dc3a006f155fe1639bf248d31b4e
size 265490231
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "distilbert-base-uncased", "tokenizer_class": "DistilBertTokenizer"}
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff