---
language: "en"
tags:
- document-retrieval
- knowledge-distillation
datasets:
- ms_marco
---

# Intra-Document Cascading (IDCM)

We provide a retrieval-trained IDCM model. Our model is trained on MSMARCO-Document with up to 2000 tokens.

This instance can be used to **re-rank a candidate set** of long documents. The base BERT architecture is a 6-layer DistilBERT.

If you want to know more about our intra-document cascading model & training procedure using knowledge distillation, check out our paper: https://arxiv.org/abs/2105.09816 🎉

For more information, training data, source code, and a minimal usage example, please visit: https://github.com/sebastian-hofstaetter/intra-document-cascade

## Configuration

- Trained with fp16 mixed precision
- We select the top 4 windows of size (50 + 2*7 overlap) words with our fast CK model and score them with BERT (see the configuration sketch after this list)
- The published code here is only usable for inference (we removed the training code)
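The model code in the next section references an `IDCM_Config` (`config_class`) that is not included in this snippet; it ships with the released checkpoint and the linked repository. As a rough, non-authoritative sketch, the fields below are inferred from the `cfg.*` attributes the model accesses; the default values are only illustrative, the real values are loaded from the checkpoint's `config.json`:

````python
from transformers import PretrainedConfig

class IDCM_Config(PretrainedConfig):
    model_type = "IDCM"

    def __init__(self,
                 bert_model: str = "distilbert-base-uncased",  # base encoder name or instance (illustrative default)
                 sample_n: int = 4,            # number of windows the fast CK module passes on to BERT
                 sample_context: str = "ck",   # type of the fast selection module ("ck" or "ck-small")
                 chunk_size: int = 50,         # window size in tokens
                 overlap: int = 7,             # overlap added on both sides of each window
                 padding_idx: int = 0,
                 top_k_chunks: int = 3,        # how many BERT window scores are aggregated (illustrative)
                 **kwargs):
        super().__init__(**kwargs)
        self.bert_model = bert_model
        self.sample_n = sample_n
        self.sample_context = sample_context
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.padding_idx = padding_idx
        self.top_k_chunks = top_k_chunks
````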
## Model Code

````python
from typing import Dict

import torch
from torch import nn as nn
from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig


class IDCM_InferenceOnly(PreTrainedModel):
    '''
    IDCM is a neural re-ranking model for long documents; it creates an intra-document cascade
    between a fast module (CK) and a slow module (BERT_Cat).
    This code is only usable for inference (we removed the training mechanism for simplicity).
    '''

    config_class = IDCM_Config  # see the configuration sketch above
    base_model_prefix = "bert_model"

    def __init__(self, cfg) -> None:
        super().__init__(cfg)

        #
        # bert - scoring
        #
        if isinstance(cfg.bert_model, str):
            self.bert_model = AutoModel.from_pretrained(cfg.bert_model)
        else:
            self.bert_model = cfg.bert_model

        #
        # final scoring (combination of bert scores)
        #
        self._classification_layer = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
        self.top_k_chunks = cfg.top_k_chunks
        self.top_k_scoring = nn.Parameter(torch.full([1, self.top_k_chunks], 1, dtype=torch.float32, requires_grad=True))

        #
        # local self attention
        #
        self.padding_idx = cfg.padding_idx
        self.chunk_size = cfg.chunk_size
        self.overlap = cfg.overlap
        self.extended_chunk_size = self.chunk_size + 2 * self.overlap

        #
        # sampling stuff
        #
        self.sample_n = cfg.sample_n
        self.sample_context = cfg.sample_context

        if self.sample_context == "ck":
            i = 3
            self.sample_cnn3 = nn.Sequential(
                nn.ConstantPad1d((0, i - 1), 0),
                nn.Conv1d(kernel_size=i, in_channels=self.bert_model.config.dim, out_channels=self.bert_model.config.dim),
                nn.ReLU()
            )
        elif self.sample_context == "ck-small":
            i = 3
            self.sample_projector = nn.Linear(self.bert_model.config.dim, 384)
            self.sample_cnn3 = nn.Sequential(
                nn.ConstantPad1d((0, i - 1), 0),
                nn.Conv1d(kernel_size=i, in_channels=384, out_channels=128),
                nn.ReLU()
            )

        self.sampling_binweights = nn.Linear(11, 1, bias=True)
        torch.nn.init.uniform_(self.sampling_binweights.weight, -0.01, 0.01)
        self.kernel_alpha_scaler = nn.Parameter(torch.full([1, 1, 11], 1, dtype=torch.float32, requires_grad=True))

        self.register_buffer("mu", nn.Parameter(torch.tensor([1.0, 0.9, 0.7, 0.5, 0.3, 0.1, -0.1, -0.3, -0.5, -0.7, -0.9]), requires_grad=False).view(1, 1, 1, -1))
        self.register_buffer("sigma", nn.Parameter(torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), requires_grad=False).view(1, 1, 1, -1))

    def forward(self,
                query: Dict[str, torch.LongTensor],
                document: Dict[str, torch.LongTensor],
                use_fp16: bool = True,
                output_secondary_output: bool = False):

        #
        # patch up documents - local self attention
        #
        document_ids = document["input_ids"][:, 1:]  # drop the leading [CLS] token of the document
        if document_ids.shape[1] > self.overlap:
            needed_padding = self.extended_chunk_size - (((document_ids.shape[1]) % self.chunk_size) - self.overlap)
        else:
            needed_padding = self.extended_chunk_size - self.overlap - document_ids.shape[1]
        orig_doc_len = document_ids.shape[1]

        document_ids = nn.functional.pad(document_ids, (self.overlap, needed_padding), value=self.padding_idx)
        # overlapping windows of extended_chunk_size tokens, stride chunk_size
        chunked_ids = document_ids.unfold(1, self.extended_chunk_size, self.chunk_size)

        batch_size = chunked_ids.shape[0]
        chunk_pieces = chunked_ids.shape[1]

        chunked_ids_unrolled = chunked_ids.reshape(-1, self.extended_chunk_size)
        # keep only windows that contain at least one non-padding token
        packed_indices = (chunked_ids_unrolled[:, self.overlap:-self.overlap] != self.padding_idx).any(-1)
        orig_packed_indices = packed_indices.clone()
        ids_packed = chunked_ids_unrolled[packed_indices]
        mask_packed = (ids_packed != self.padding_idx)

        total_chunks = chunked_ids_unrolled.shape[0]

        packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["input_ids"].shape[1])[packed_indices]
        packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["attention_mask"].shape[1])[packed_indices]

        #
        # sampling
        #
        if self.sample_n > -1:

            #
            # ck learned matches
            #
            if self.sample_context == "ck-small":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(packed_query_ids).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(ids_packed).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
            elif self.sample_context == "ck":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(packed_query_ids).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(ids_packed).detach()).transpose(1, 2)).transpose(1, 2), p=2, dim=-1)
            else:
                qe = self.tk_projector(self.bert_model.embeddings(packed_query_ids).detach())
                de = self.tk_projector(self.bert_model.embeddings(ids_packed).detach())
                query_ctx = self.tk_contextualizer(qe.transpose(1, 0), src_key_padding_mask=~packed_query_mask.bool()).transpose(1, 0)
                document_ctx = self.tk_contextualizer(de.transpose(1, 0), src_key_padding_mask=~mask_packed.bool()).transpose(1, 0)

                query_ctx = torch.nn.functional.normalize(query_ctx, p=2, dim=-1)
                document_ctx = torch.nn.functional.normalize(document_ctx, p=2, dim=-1)

            # kernel-pooling over the query x window cosine matrix -> one fast score per window
            cosine_matrix = torch.bmm(query_ctx, document_ctx.transpose(-1, -2)).unsqueeze(-1)

            kernel_activations = torch.exp(- torch.pow(cosine_matrix - self.mu, 2) / (2 * torch.pow(self.sigma, 2))) * mask_packed.unsqueeze(-1).unsqueeze(1)
            kernel_res = torch.log(torch.clamp(torch.sum(kernel_activations, 2) * self.kernel_alpha_scaler, min=1e-4)) * packed_query_mask.unsqueeze(-1)
            packed_patch_scores = self.sampling_binweights(torch.sum(kernel_res, 1))

            sampling_scores_per_doc = torch.zeros((total_chunks, 1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
            sampling_scores_per_doc[packed_indices] = packed_patch_scores
            sampling_scores_per_doc = sampling_scores_per_doc.reshape(batch_size, -1,)
            sampling_scores_per_doc_orig = sampling_scores_per_doc.clone()
            sampling_scores_per_doc[sampling_scores_per_doc == 0] = -9000

            sampling_sorted = sampling_scores_per_doc.sort(descending=True)
            sampled_indices = sampling_sorted.indices + torch.arange(0, sampling_scores_per_doc.shape[0] * sampling_scores_per_doc.shape[1], sampling_scores_per_doc.shape[1], device=sampling_scores_per_doc.device).unsqueeze(-1)

            # select the top sample_n windows per document for the expensive BERT scoring
            sampled_indices = sampled_indices[:, :self.sample_n]
            sampled_indices_mask = torch.zeros_like(packed_indices).scatter(0, sampled_indices.reshape(-1), 1)

            # pack indices
            packed_indices = sampled_indices_mask * packed_indices

            packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["input_ids"].shape[1])[packed_indices]
            packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1, chunk_pieces, -1).reshape(-1, query["attention_mask"].shape[1])[packed_indices]

            ids_packed = chunked_ids_unrolled[packed_indices]
            mask_packed = (ids_packed != self.padding_idx)

        #
        # expensive bert scores
        #
        bert_vecs = self.forward_representation(torch.cat([packed_query_ids, ids_packed], dim=1), torch.cat([packed_query_mask, mask_packed], dim=1))
        packed_patch_scores = self._classification_layer(bert_vecs)

        scores_per_doc = torch.zeros((total_chunks, 1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
        scores_per_doc[packed_indices] = packed_patch_scores
        scores_per_doc = scores_per_doc.reshape(batch_size, -1,)
        scores_per_doc_orig = scores_per_doc.clone()
        scores_per_doc_orig_sorter = scores_per_doc.clone()

        if self.sample_n > -1:
            scores_per_doc = scores_per_doc * sampled_indices_mask.view(batch_size, -1)

        #
        # aggregate bert scores
        #
        if scores_per_doc.shape[1] < self.top_k_chunks:
            scores_per_doc = nn.functional.pad(scores_per_doc, (0, self.top_k_chunks - scores_per_doc.shape[1]))

        scores_per_doc[scores_per_doc == 0] = -9000
        scores_per_doc_orig_sorter[scores_per_doc_orig_sorter == 0] = -9000
        score = torch.sort(scores_per_doc, descending=True, dim=-1).values
        score[score <= -8900] = 0

        score = (score[:, :self.top_k_chunks] * self.top_k_scoring).sum(dim=1)

        if self.sample_n == -1:
            if output_secondary_output:
                return score, {
                    "packed_indices": orig_packed_indices.view(batch_size, -1),
                    "bert_scores": scores_per_doc_orig
                }
            else:
                return score, scores_per_doc_orig
        else:
            if output_secondary_output:
                return score, scores_per_doc_orig, {
                    "score": score,
                    "packed_indices": orig_packed_indices.view(batch_size, -1),
                    "sampling_scores": sampling_scores_per_doc_orig,
                    "bert_scores": scores_per_doc_orig
                }
            return score

    def forward_representation(self, ids, mask, type_ids=None) -> torch.Tensor:

        if self.bert_model.base_model_prefix == 'distilbert':  # diff input / output
            pooled = self.bert_model(input_ids=ids,
                                     attention_mask=mask)[0][:, 0, :]
        elif self.bert_model.base_model_prefix == 'longformer':
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask.long(),
                                        global_attention_mask=((1 - ids) * mask).long())
        elif self.bert_model.base_model_prefix == 'roberta':  # no token type ids
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask)
        else:
            _, pooled = self.bert_model(input_ids=ids,
                                        token_type_ids=type_ids,
                                        attention_mask=mask)

        return pooled


tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # honestly not sure if that is the best way to go, but it works :)
model = IDCM_InferenceOnly.from_pretrained("sebastian-hofstaetter/idcm-distilbert-msmarco_doc")
````
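As a minimal, hedged usage sketch (the query and document strings below are made up; the official minimal example lives in the linked repository): tokenize the query and the long document separately and pass both encodings to the model. Depending on the configuration, the forward pass returns either a single score tensor or a tuple that additionally contains per-window scores.

````python
# Hypothetical query/document pair; replace with your own BM25 candidate set.
query_text = "what is an intra-document cascade"
doc_text = "A long candidate document retrieved by BM25 ..."

query_enc = tokenizer(query_text, return_tensors="pt")
doc_enc = tokenizer(doc_text, return_tensors="pt", truncation=True, max_length=2000)

model.eval()
with torch.no_grad():
    output = model(query_enc, doc_enc, use_fp16=False)

# With the released sampling configuration the forward pass returns one score per
# document; higher means more relevant. The tuple check covers the sample_n == -1 case.
score = output[0] if isinstance(output, tuple) else output
print(score)
````

In a re-ranking setup you would run this over the top-100 BM25 candidates of a query and sort them by the returned score.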
## Effectiveness on MSMARCO-Document & TREC Deep Learning '19

We trained our model on the MSMARCO-Document collection. We trained the selection module CK with knowledge distillation from the stronger BERT module. For re-ranking we used the top-100 BM25 results.

The throughput of IDCM should be ~600 documents (with up to 2000 tokens each) per second.

### MSMARCO-Document-DEV

|                                  | MRR@10 | NDCG@10 |
|----------------------------------|--------|---------|
| BM25                             | .252   | .311    |
| **IDCM**                         | .380   | .446    |

### TREC-DL'19 (Document Task)

For MRR we use the recommended binarization point of 2 on the graded relevance labels. This might skew the results when compared to MRR values computed with other binarization points.

|                                  | MRR@10 | NDCG@10 |
|----------------------------------|--------|---------|
| BM25                             | .661   | .488    |
| **IDCM**                         | .916   | .688    |
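To make the binarization convention above concrete, here is a tiny illustrative helper (not the official evaluation code; for reported numbers use the standard TREC evaluation tools):

````python
from typing import Dict, List

def mrr_at_10(ranked_doc_ids: List[str], graded_relevance: Dict[str, int], binarization_point: int = 2) -> float:
    """Reciprocal rank of the first document whose graded label >= binarization_point, cut off at rank 10."""
    for rank, doc_id in enumerate(ranked_doc_ids[:10], start=1):
        if graded_relevance.get(doc_id, 0) >= binarization_point:
            return 1.0 / rank
    return 0.0

# Illustrative query with graded labels 0-3: only grades >= 2 count as relevant.
print(mrr_at_10(["d7", "d2", "d9"], {"d2": 3, "d9": 1}))  # 0.5
````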
For more metrics, baselines, information and analysis, please see the paper: https://arxiv.org/abs/2105.09816

## Limitations & Bias

- The model inherits social biases from both DistilBERT and MSMARCO.

- The model is only trained on the longer documents of MSMARCO, so it might struggle with especially short document text; for short text we recommend one of our MSMARCO-Passage trained models.

## Citation

If you use our model checkpoint, please cite our work as:

```
@inproceedings{Hofstaetter2021_idcm,
 author = {Sebastian Hofst{\"a}tter and Bhaskar Mitra and Hamed Zamani and Nick Craswell and Allan Hanbury},
 title = {{Intra-Document Cascading: Learning to Select Passages for Neural Document Ranking}},
 booktitle = {Proc. of SIGIR},
 year = {2021},
}
```