---
language: "en"
tags:
- knowledge-distillation
datasets:
- ms_marco
---

# Margin-MSE Trained PreTTR

We provide a retrieval-trained, DistilBERT-based PreTTR model (https://arxiv.org/abs/2004.14255). Our model is trained with Margin-MSE using a 3-teacher BERT_Cat (concatenated BERT scoring) ensemble on MSMARCO-Passage.

This instance can be used to **re-rank a candidate set**. The architecture is a 6-layer DistilBERT, split at layer 3, with a single additional linear layer at the end that scores the CLS token.

If you want to know more about our simple yet effective knowledge distillation method for efficient information retrieval models, which works for a variety of student architectures and was used to train this model instance, check out our paper: https://arxiv.org/abs/2010.02666 🎉

For more information, training data, source code, and a minimal usage example, please visit: https://github.com/sebastian-hofstaetter/neural-ranking-kd
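
As background, the Margin-MSE objective trains the student to reproduce the teachers' score margin between a relevant and a non-relevant passage for the same query. A minimal sketch of the loss (function and variable names are illustrative, not taken from the actual training code in the repository above):

```python
import torch

# Margin-MSE: match the student's score margin (pos - neg) to the teacher ensemble's margin.
def margin_mse_loss(student_pos, student_neg, teacher_pos, teacher_neg):
    student_margin = student_pos - student_neg  # (batch,)
    teacher_margin = teacher_pos - teacher_neg  # (batch,) e.g. averaged over the 3-teacher ensemble
    return torch.nn.functional.mse_loss(student_margin, teacher_margin)
```
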
## Configuration

- We split the DistilBERT in half at layer 3

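As a rough illustration of this configuration (the split point is the `join_layer_idx` attribute of the `PreTTRConfig` class shown in the Model Code section below), the first three transformer layers encode query and document independently, and the remaining layers encode the concatenated pair:

```python
# Illustrative only: how a split at layer 3 partitions the 6 DistilBERT layers
# (mirrors the slicing self.layer[:join_layer_idx] / self.layer[join_layer_idx:] in the model code below).
n_layers = 6
join_layer_idx = 3

separate_layers = list(range(join_layer_idx))         # [0, 1, 2]: query and document encoded independently
joint_layers = list(range(join_layer_idx, n_layers))  # [3, 4, 5]: concatenated query + document sequence
```
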
## Model Code

````python
from transformers import DistilBertModel, AutoTokenizer
from transformers.models.distilbert.modeling_distilbert import *
import copy
import math
import torch
from torch import nn as nn

class PreTTRConfig(DistilBertConfig):
    join_layer_idx = 3  # query and document are encoded separately up to this layer

class PreTTR(DistilBertModel):
    '''
    PreTTR changes the DistilBERT model from Hugging Face to split query and document encoding up to a set layer;
    we skipped the compression step present in the original.

    from: Efficient Document Re-Ranking for Transformers by Precomputing Term Representations
          MacAvaney et al., https://arxiv.org/abs/2004.14255
    '''
    config_class = PreTTRConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SplitTransformer(config)  # Encoder; we override the classes, but the names stay the same -> so it gets properly initialized
        self.embeddings = PosOffsetEmbeddings(config)  # Embeddings
        self._classification_layer = torch.nn.Linear(self.config.hidden_size, 1, bias=False)

        self.join_layer_idx = config.join_layer_idx

    def forward(
            self,
            query,
            document,
            use_fp16: bool = False) -> torch.Tensor:

        with torch.cuda.amp.autocast(enabled=use_fp16):

            query_input_ids = query["input_ids"]
            query_attention_mask = query["attention_mask"]

            # drop the document's [CLS] token; the query's [CLS] token is used for scoring
            document_input_ids = document["input_ids"][:, 1:]
            document_attention_mask = document["attention_mask"][:, 1:]

            query_embs = self.embeddings(query_input_ids)  # (bs, seq_length, dim)
            document_embs = self.embeddings(document_input_ids, query_input_ids.shape[-1])  # (bs, seq_length, dim); document positions continue after the query

            tfmr_output = self.transformer(
                query_embs=query_embs,
                query_mask=query_attention_mask,
                doc_embs=document_embs,
                doc_mask=document_attention_mask,
                join_layer_idx=self.join_layer_idx
            )
            hidden_state = tfmr_output[0]

            score = self._classification_layer(hidden_state[:, 0, :]).squeeze()

            return score


class PosOffsetEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, pos_offset=0):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, max_seq_length)
            The token ids to embed.
        pos_offset: int
            Offset added to the position ids (so document positions continue after the query positions).

        Outputs
        -------
        embeddings: torch.tensor(bs, max_seq_length, dim)
            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + pos_offset  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings


class SplitTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(self, query_embs, query_mask, doc_embs, doc_mask, join_layer_idx, output_attentions=False, output_hidden_states=False):
        """
        Parameters
        ----------
        query_embs: torch.tensor(bs, query_length, dim)
            Embedded query sequence.
        query_mask: torch.tensor(bs, query_length)
            Attention mask on the query sequence.
        doc_embs: torch.tensor(bs, doc_length, dim)
            Embedded document sequence.
        doc_mask: torch.tensor(bs, doc_length)
            Attention mask on the document sequence.
        join_layer_idx: int
            Layer index at which the query and document representations are concatenated.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        all_hidden_states = ()
        all_attentions = ()

        #
        # query / doc separately, up to the join layer
        #
        hidden_state_q = query_embs
        hidden_state_d = doc_embs
        for layer_module in self.layer[:join_layer_idx]:

            layer_outputs_q = layer_module(
                x=hidden_state_q, attn_mask=query_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state_q = layer_outputs_q[-1]

            layer_outputs_d = layer_module(
                x=hidden_state_d, attn_mask=doc_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state_d = layer_outputs_d[-1]

        #
        # combine query and document sequences
        #
        x = torch.cat([hidden_state_q, hidden_state_d], dim=1)
        attn_mask = torch.cat([query_mask, doc_mask], dim=1)

        #
        # combined layers
        #
        hidden_state = x
        for layer_module in self.layer[join_layer_idx:]:
            layer_outputs = layer_module(
                x=hidden_state, attn_mask=attn_mask, head_mask=None, output_attentions=output_attentions
            )
            hidden_state = layer_outputs[-1]

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        outputs = (hidden_state,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)

#
# init the model & tokenizer (using the distilbert tokenizer)
#
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # honestly not sure if that is the best way to go, but it works :)
model = PreTTR.from_pretrained("sebastian-hofstaetter/prettr-distilbert-split_at_3-margin_mse-T2-msmarco")
````
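
A minimal scoring sketch using the `tokenizer` and `model` created above (the query/passage strings and maximum lengths are placeholders; see the repository linked above for the full usage example):

```python
# Illustrative only: score a single query-passage pair with the re-ranker.
query_enc = tokenizer(["what is knowledge distillation"],
                      return_tensors="pt", truncation=True, max_length=30)
passage_enc = tokenizer(["Knowledge distillation transfers knowledge from a large teacher model to a smaller student model."],
                        return_tensors="pt", truncation=True, max_length=200)

model.eval()
with torch.no_grad():
    score = model(query_enc, passage_enc)  # higher score = more relevant

print(score)
```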

## Effectiveness on MSMARCO Passage

We trained our model with knowledge distillation on the standard MSMARCO ("small", 400K queries) training triples, using a batch size of 32 on a single consumer-grade GPU (11 GB memory).

For re-ranking we used the top-1000 BM25 results as the candidate set.

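As a rough sketch (the query, candidate list, batch size, and maximum lengths below are placeholders), re-ranking a BM25 candidate list with the model from the Model Code section could look like this:

```python
# Illustrative only: re-rank a BM25 candidate list and return passages sorted by model score.
def rerank(query, candidates, batch_size=32):
    model.eval()
    scored = []
    with torch.no_grad():
        for i in range(0, len(candidates), batch_size):
            batch = candidates[i:i + batch_size]
            q_enc = tokenizer([query] * len(batch), return_tensors="pt",
                              padding=True, truncation=True, max_length=30)
            d_enc = tokenizer(batch, return_tensors="pt",
                              padding=True, truncation=True, max_length=200)
            scores = model(q_enc, d_enc).reshape(-1)  # one score per candidate
            scored.extend(zip(batch, scores.tolist()))
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

bm25_candidates = ["first retrieved passage ...", "second retrieved passage ..."]  # your top-1000 BM25 results
reranked = rerank("example query", bm25_candidates)
```
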
### MSMARCO-DEV

Here we use the larger DEV set with 49K queries (results are in the same range as on the smaller 7K-query DEV set; only minimal differences are to be expected).

| Model                              | MRR@10 | NDCG@10 |
|------------------------------------|--------|---------|
| BM25                               | .194   | .241    |
| **Margin-MSE PreTTR** (re-ranking) | .386   | .447    |

For more metrics, baselines, further information, and analysis, please see the paper: https://arxiv.org/abs/2010.02666

## Limitations & Bias

- The model inherits social biases from both DistilBERT and MSMARCO.

- The model is only trained on relatively short passages of MSMARCO (average length of 60 words), so it might struggle with longer text.

## Citation

If you use our model checkpoint, please cite our work as:

```
@misc{hofstaetter2020_crossarchitecture_kd,
      title={Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation}, 
      author={Sebastian Hofst{\"a}tter and Sophia Althammer and Michael Schr{\"o}der and Mete Sertkan and Allan Hanbury},
      year={2020},
      eprint={2010.02666},
      archivePrefix={arXiv},
      primaryClass={cs.IR}
}
```