---

language: "en"
tags:
- document-retrieval
- knowledge-distillation
datasets:
- ms_marco
---


# Intra-Document Cascading (IDCM)

We provide a retrieval-trained IDCM model. It is trained on MSMARCO-Document with up to 2,000 tokens per document.

This instance can be used to **re-rank a candidate set** of long documents. The base BERT architecture is a 6-layer DistilBERT.

If you want to know more about our intra-document cascading model and the training procedure using knowledge distillation, check out our paper: https://arxiv.org/abs/2105.09816 🎉

For more information, training data, source code, and a minimal usage example please visit: https://github.com/sebastian-hofstaetter/intra-document-cascade

## Configuration

- Trained with fp16 mixed precision
- We select the top 4 windows of 50 words each (plus 7 overlap words on each side, i.e. 64 tokens per window) with our fast CK module and score them with BERT; a small chunking sketch follows this list
- The published code here is only usable for inference (we removed the training code)
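To make the windowing concrete, here is a small standalone sketch of how a document is split into overlapping windows before selection. The chunk size (50) and overlap (7) follow the configuration above and the model code below; the document length and token ids are made up purely for illustration.

```python
import torch
from torch import nn

# Windowing sketch: chunk_size and overlap follow the configuration above,
# the document itself is random ids purely for illustration.
chunk_size, overlap, pad_id = 50, 7, 0
extended = chunk_size + 2 * overlap                    # 64 tokens per window

doc_ids = torch.randint(1, 30000, (1, 230))            # one document with 230 tokens
if doc_ids.shape[1] > overlap:
    needed_padding = extended - (doc_ids.shape[1] % chunk_size - overlap)
else:
    needed_padding = extended - overlap - doc_ids.shape[1]

padded = nn.functional.pad(doc_ids, (overlap, needed_padding), value=pad_id)
windows = padded.unfold(1, extended, chunk_size)       # stride = chunk_size -> overlapping windows
print(windows.shape)                                   # torch.Size([1, 5, 64]) for this example
```

Of these windows, the CK module keeps only the top 4 per document for the expensive BERT scoring.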



## Model Code



````python
from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig
from typing import Dict

import torch
from torch import nn as nn

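# Note: this snippet references IDCM_Config but does not include its definition.
# The following is a minimal, assumed sketch (field names inferred from their use in
# IDCM_InferenceOnly below); the published checkpoint ships its own config values,
# so from_pretrained will overwrite these illustrative defaults.
class IDCM_Config(PretrainedConfig):
    def __init__(self,
                 bert_model: str = "distilbert-base-uncased",
                 top_k_chunks: int = 3,       # how many top BERT window scores are aggregated (assumed default)
                 chunk_size: int = 50,        # window size in tokens (from the configuration above)
                 overlap: int = 7,            # overlap added on each side of a window
                 padding_idx: int = 0,
                 sample_n: int = 4,           # windows selected by the CK module (from the configuration above)
                 sample_context: str = "ck",  # which selection module variant to use (assumed)
                 **kwargs):
        super().__init__(**kwargs)
        self.bert_model = bert_model
        self.top_k_chunks = top_k_chunks
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.padding_idx = padding_idx
        self.sample_n = sample_n
        self.sample_context = sample_context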



class IDCM_InferenceOnly(PreTrainedModel):
    '''
    IDCM is a neural re-ranking model for long documents; it creates an intra-document cascade between a fast module (CK) and a slow module (BERT_Cat).
    This code is only usable for inference (we removed the training mechanism for simplicity).
    '''

    config_class = IDCM_Config
    base_model_prefix = "bert_model"

    def __init__(self,
                 cfg) -> None:
        super().__init__(cfg)

        #
        # bert - scoring
        #
        if isinstance(cfg.bert_model, str):
            self.bert_model = AutoModel.from_pretrained(cfg.bert_model)
        else:
            self.bert_model = cfg.bert_model

        #
        # final scoring (combination of bert scores)
        #
        self._classification_layer = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
        self.top_k_chunks = cfg.top_k_chunks
        self.top_k_scoring = nn.Parameter(torch.full([1, self.top_k_chunks], 1, dtype=torch.float32, requires_grad=True))

        #
        # local self attention
        #
        self.padding_idx = cfg.padding_idx
        self.chunk_size = cfg.chunk_size
        self.overlap = cfg.overlap
        self.extended_chunk_size = self.chunk_size + 2 * self.overlap

        #
        # sampling stuff
        #
        self.sample_n = cfg.sample_n
        self.sample_context = cfg.sample_context

        if self.sample_context == "ck":
            i = 3
            self.sample_cnn3 = nn.Sequential(
                nn.ConstantPad1d((0, i - 1), 0),
                nn.Conv1d(kernel_size=i, in_channels=self.bert_model.config.dim, out_channels=self.bert_model.config.dim),
                nn.ReLU()
            )
        elif self.sample_context == "ck-small":
            i = 3
            self.sample_projector = nn.Linear(self.bert_model.config.dim, 384)
            self.sample_cnn3 = nn.Sequential(
                nn.ConstantPad1d((0, i - 1), 0),
                nn.Conv1d(kernel_size=i, in_channels=384, out_channels=128),
                nn.ReLU()
            )

        self.sampling_binweights = nn.Linear(11, 1, bias=True)
        torch.nn.init.uniform_(self.sampling_binweights.weight, -0.01, 0.01)
        self.kernel_alpha_scaler = nn.Parameter(torch.full([1, 1, 11], 1, dtype=torch.float32, requires_grad=True))

        self.register_buffer("mu", nn.Parameter(torch.tensor([1.0, 0.9, 0.7, 0.5, 0.3, 0.1, -0.1, -0.3, -0.5, -0.7, -0.9]), requires_grad=False).view(1, 1, 1, -1))
        self.register_buffer("sigma", nn.Parameter(torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), requires_grad=False).view(1, 1, 1, -1))

    def forward(self,
                query: Dict[str, torch.LongTensor],
                document: Dict[str, torch.LongTensor],
                use_fp16: bool = True,
                output_secondary_output: bool = False):

        #
        # patch up documents - local self attention
        #
        document_ids = document["input_ids"][:, 1:]
        if document_ids.shape[1] > self.overlap:
            needed_padding = self.extended_chunk_size - (((document_ids.shape[1]) % self.chunk_size) - self.overlap)
        else:
            needed_padding = self.extended_chunk_size - self.overlap - document_ids.shape[1]
        orig_doc_len = document_ids.shape[1]

        document_ids = nn.functional.pad(document_ids, (self.overlap, needed_padding), value=self.padding_idx)
        # overlapping windows: size = chunk_size + 2*overlap, stride = chunk_size
        chunked_ids = document_ids.unfold(1, self.extended_chunk_size, self.chunk_size)

        batch_size = chunked_ids.shape[0]
        chunk_pieces = chunked_ids.shape[1]

        chunked_ids_unrolled = chunked_ids.reshape(-1, self.extended_chunk_size)
        packed_indices = (chunked_ids_unrolled[:, self.overlap:-self.overlap] != self.padding_idx).any(-1)
        orig_packed_indices = packed_indices.clone()
        ids_packed = chunked_ids_unrolled[packed_indices]
        mask_packed = (ids_packed != self.padding_idx)

        total_chunks = chunked_ids_unrolled.shape[0]

        packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["input_ids"].shape[1])[packed_indices]
        packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["attention_mask"].shape[1])[packed_indices]

        #
        # sampling
        #
        if self.sample_n > -1:

            #
            # ck learned matches
            #
            if self.sample_context == "ck-small":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(packed_query_ids).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(ids_packed).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
            elif self.sample_context == "ck":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(packed_query_ids).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(ids_packed).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
            else:
                # note: the tk_* modules used by this fallback branch are not defined in this inference-only snippet
                qe = self.tk_projector(self.bert_model.embeddings(packed_query_ids).detach())
                de = self.tk_projector(self.bert_model.embeddings(ids_packed).detach())
                query_ctx = self.tk_contextualizer(qe.transpose(1, 0), src_key_padding_mask=~packed_query_mask.bool()).transpose(1, 0)
                document_ctx = self.tk_contextualizer(de.transpose(1, 0), src_key_padding_mask=~mask_packed.bool()).transpose(1, 0)

                query_ctx = torch.nn.functional.normalize(query_ctx, p=2, dim=-1)
                document_ctx = torch.nn.functional.normalize(document_ctx, p=2, dim=-1)

            # CK selection: kernel-pooled cosine matches between query and window embeddings
            cosine_matrix = torch.bmm(query_ctx, document_ctx.transpose(-1, -2)).unsqueeze(-1)

            kernel_activations = torch.exp(- torch.pow(cosine_matrix - self.mu, 2) / (2 * torch.pow(self.sigma, 2))) * mask_packed.unsqueeze(-1).unsqueeze(1)
            kernel_res = torch.log(torch.clamp(torch.sum(kernel_activations, 2) * self.kernel_alpha_scaler, min=1e-4)) * packed_query_mask.unsqueeze(-1)
            packed_patch_scores = self.sampling_binweights(torch.sum(kernel_res, 1))

            sampling_scores_per_doc = torch.zeros((total_chunks, 1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
            sampling_scores_per_doc[packed_indices] = packed_patch_scores
            sampling_scores_per_doc = sampling_scores_per_doc.reshape(batch_size, -1,)
            sampling_scores_per_doc_orig = sampling_scores_per_doc.clone()
            sampling_scores_per_doc[sampling_scores_per_doc == 0] = -9000

            sampling_sorted = sampling_scores_per_doc.sort(descending=True)
            sampled_indices = sampling_sorted.indices + torch.arange(0, sampling_scores_per_doc.shape[0] * sampling_scores_per_doc.shape[1], sampling_scores_per_doc.shape[1], device=sampling_scores_per_doc.device).unsqueeze(-1)

            sampled_indices = sampled_indices[:, :self.sample_n]
            sampled_indices_mask = torch.zeros_like(packed_indices).scatter(0, sampled_indices.reshape(-1), 1)

            # pack indices
            packed_indices = sampled_indices_mask * packed_indices

            packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["input_ids"].shape[1])[packed_indices]
            packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["attention_mask"].shape[1])[packed_indices]

            ids_packed = chunked_ids_unrolled[packed_indices]
            mask_packed = (ids_packed != self.padding_idx)

        #
        # expensive bert scores
        #
        bert_vecs = self.forward_representation(torch.cat([packed_query_ids, ids_packed], dim=1), torch.cat([packed_query_mask, mask_packed], dim=1))
        packed_patch_scores = self._classification_layer(bert_vecs)

        scores_per_doc = torch.zeros((total_chunks, 1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
        scores_per_doc[packed_indices] = packed_patch_scores
        scores_per_doc = scores_per_doc.reshape(batch_size, -1,)
        scores_per_doc_orig = scores_per_doc.clone()
        scores_per_doc_orig_sorter = scores_per_doc.clone()

        if self.sample_n > -1:
            scores_per_doc = scores_per_doc * sampled_indices_mask.view(batch_size, -1)

        #
        # aggregate bert scores
        #
        if scores_per_doc.shape[1] < self.top_k_chunks:
            scores_per_doc = nn.functional.pad(scores_per_doc, (0, self.top_k_chunks - scores_per_doc.shape[1]))

        scores_per_doc[scores_per_doc == 0] = -9000
        scores_per_doc_orig_sorter[scores_per_doc_orig_sorter == 0] = -9000
        score = torch.sort(scores_per_doc, descending=True, dim=-1).values
        score[score <= -8900] = 0

        # weighted sum of the top-k window scores per document
        score = (score[:, :self.top_k_chunks] * self.top_k_scoring).sum(dim=1)

        if self.sample_n == -1:
            if output_secondary_output:
                return score, {
                    "packed_indices": orig_packed_indices.view(batch_size, -1),
                    "bert_scores": scores_per_doc_orig
                }
            else:
                return score, scores_per_doc_orig
        else:
            if output_secondary_output:
                return score, scores_per_doc_orig, {
                    "score": score,
                    "packed_indices": orig_packed_indices.view(batch_size, -1),
                    "sampling_scores": sampling_scores_per_doc_orig,
                    "bert_scores": scores_per_doc_orig
                }

            return score


    def forward_representation(self, ids, mask, type_ids=None) -> Dict[str, torch.Tensor]:

        if self.bert_model.base_model_prefix == 'distilbert':  # diff input / output
            pooled = self.bert_model(input_ids=ids,
                                     attention_mask=mask)[0][:, 0, :]
        elif self.bert_model.base_model_prefix == 'longformer':
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask.long(),
                                        global_attention_mask=((1 - ids) * mask).long())
        elif self.bert_model.base_model_prefix == 'roberta':  # no token type ids
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask)
        else:
            _, pooled = self.bert_model(input_ids=ids,
                                        token_type_ids=type_ids,
                                        attention_mask=mask)

        return pooled


tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # honestly not sure if that is the best way to go, but it works :)
model = IDCM_InferenceOnly.from_pretrained("sebastian-hofstaetter/idcm-distilbert-msmarco_doc")
````
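As a rough orientation (the authors' own minimal usage example lives in the GitHub repository linked above), scoring a query against a long document with the classes defined above could look like the sketch below. The query/document strings and the max-length values are made up for illustration, and depending on the loaded configuration the model returns either just the aggregated score or a tuple that also contains per-window scores.

```python
# Hedged usage sketch: query/document text and max lengths are illustrative only.
query = tokenizer("what is the effect of caffeine on sleep",
                  return_tensors="pt", truncation=True, max_length=30)
document = tokenizer("Caffeine is a central nervous system stimulant ... (long document text)",
                     return_tensors="pt", truncation=True, max_length=2000)

model.eval()
with torch.no_grad():
    output = model(query, document, use_fp16=False)

# With window sampling enabled (sample_n > -1) the model returns the document score;
# otherwise it returns (score, per_window_bert_scores).
score = output if torch.is_tensor(output) else output[0]
print(score)
```

For re-ranking, one would run this over each of the top-100 BM25 candidates of a query and sort the candidates by the resulting scores.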



## Effectiveness on MSMARCO Document & TREC Deep Learning '19



We trained our model on the MSMARCO-Document collection. The selection module (CK) is trained with knowledge distillation from the stronger BERT scoring module.



For re-ranking we used the top-100 BM25 results. The throughput of IDCM should be around 600 documents per second (with up to 2,000 tokens per document).



### MSMARCO-Document-DEV



|                                  | MRR@10 | NDCG@10 |
|----------------------------------|--------|---------|
| BM25                             | .252   | .311    |
| **IDCM**                         | .380   | .446    |



### TREC-DL'19 (Document Task)



For MRR we use the recommended binarization point of 2 on the graded relevance labels; a small sketch of this computation follows the table below. Keep this in mind when comparing to results computed with other binarization points.



|                                  | MRR@10 | NDCG@10 |
|----------------------------------|--------|---------|
| BM25                             | .661   | .488    |
| **IDCM**                         | .916   | .688    |
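
The sketch below illustrates what the MRR@10 binarization mentioned above means in practice; it is illustrative only, and the official numbers are computed with the standard TREC evaluation tooling.

```python
# Illustrative MRR@10 with binarization point 2 on graded relevance labels
# (labels >= 2 count as relevant); official results use standard TREC tooling.
def mrr_at_10(graded_labels_in_rank_order, binarization_point=2):
    for rank, grade in enumerate(graded_labels_in_rank_order[:10], start=1):
        if grade >= binarization_point:
            return 1.0 / rank
    return 0.0

# One query whose top-ranked documents have graded labels 1, 0, 3, 2, ...
print(mrr_at_10([1, 0, 3, 2]))  # 0.333...; average over all queries for the reported MRR@10
```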



For more metrics, baselines, info and analysis, please see the paper: https://arxiv.org/abs/2105.09816



## Limitations & Bias



- The model inherits social biases from both DistilBERT and MSMARCO. 



- The model is trained only on the longer documents of MSMARCO, so it might struggle with especially short document text; for short text we recommend one of our MSMARCO-Passage trained models.





## Citation



If you use our model checkpoint, please cite our work as:



```
@inproceedings{Hofstaetter2021_idcm,
    author = {Sebastian Hofst{\"a}tter and Bhaskar Mitra and Hamed Zamani and Nick Craswell and Allan Hanbury},
    title = {{Intra-Document Cascading: Learning to Select Passages for Neural Document Ranking}},
    booktitle = {Proc. of SIGIR},
    year = {2021},
}
```