---
language: "en"
tags:
- document-retrieval
- knowledge-distillation
datasets:
- ms_marco
---
# Intra-Document Cascading (IDCM)
We provide an IDCM model trained for retrieval. The model is trained on MSMARCO-Document with documents of up to 2,000 tokens.

This instance can be used to **re-rank a candidate set** of long documents. The base BERT architecture is a 6-layer DistilBERT.

If you want to know more about our intra-document cascading model and the training procedure using knowledge distillation, check out our paper: https://arxiv.org/abs/2105.09816 🎉

For more information, training data, source code, and a minimal usage example, please visit: https://github.com/sebastian-hofstaetter/intra-document-cascade

## Configuration
- Trained with fp16 mixed precision
- We select the top 4 windows with our fast CK model and score them with BERT; each window spans 50 words plus 7 overlap words on each side (64 in total)
- The published code here is only usable for inference (we removed the training code)
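
The windowing and selection described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the model code: `fast_score` and `slow_score` are hypothetical stand-ins for the CK and BERT modules, the padding is simpler than in the `forward` method, and the final sum uses uniform weights unless `weights` is given (the model learns one weight per rank).

```python
from typing import Callable, List, Optional, Sequence

Window = List[int]


def split_into_windows(token_ids: Sequence[int], chunk_size: int = 50,
                       overlap: int = 7, pad_id: int = 0) -> List[Window]:
    """Cut a document into windows of `chunk_size` tokens, each extended by
    `overlap` tokens of context on both sides (padded at the document edges)."""
    extended_size = chunk_size + 2 * overlap
    num_windows = max(1, -(-len(token_ids) // chunk_size))  # ceiling division
    right_pad = num_windows * chunk_size - len(token_ids) + overlap
    padded = [pad_id] * overlap + list(token_ids) + [pad_id] * right_pad
    return [padded[i * chunk_size: i * chunk_size + extended_size]
            for i in range(num_windows)]


def cascade_score(windows: List[Window],
                  fast_score: Callable[[Window], float],
                  slow_score: Callable[[Window], float],
                  top_k: int = 4,
                  weights: Optional[Sequence[float]] = None) -> float:
    """Score a document with a two-stage cascade over its windows."""
    # Stage 1: cheap score for every window, keep only the top_k best.
    selected = sorted(windows, key=fast_score, reverse=True)[:top_k]
    # Stage 2: expensive score for the selected windows only.
    slow_scores = sorted((slow_score(w) for w in selected), reverse=True)
    # Aggregate: weighted sum over the sorted scores (uniform if no weights).
    if weights is None:
        weights = [1.0] * len(slow_scores)
    return sum(w * s for w, s in zip(weights, slow_scores))


if __name__ == "__main__":
    doc = list(range(1, 301))  # hypothetical token ids
    windows = split_into_windows(doc)
    print(len(windows), len(windows[0]))
    # Toy scorers for illustration only.
    print(cascade_score(windows, fast_score=lambda w: float(sum(w)),
                        slow_score=lambda w: float(len(w))))
```

The point of the design is that the expensive scorer runs on at most `top_k` windows per document, regardless of document length.
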
## Model Code
````python
from transformers import AutoTokenizer,AutoModel, PreTrainedModel,PretrainedConfig
from typing import Dict
import torch
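from torch import nn as nn


class IDCM_Config(PretrainedConfig):
    # NOTE: reconstructed from the attributes IDCM_InferenceOnly reads from `cfg`;
    # the values are expected to come from the checkpoint's config.json.
    bert_model: str       # name or path of the BERT model used for scoring
    sample_n: int         # number of windows passed on to BERT (-1 disables sampling)
    sample_context: str   # type of the fast selection module, e.g. "ck" or "ck-small"
    top_k_chunks: int     # number of BERT window scores combined into the final score
    chunk_size: int       # window size
    overlap: int          # overlap added to the left and right of each window
    padding_idx: int      # token id used for padding
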

class IDCM_InferenceOnly(PreTrainedModel):
    '''
    IDCM is a neural re-ranking model for long documents, it creates an intra-document cascade between a fast (CK) and a slow module (BERT_Cat)
    This code is only usable for inference (we removed the training mechanism for simplicity)
    '''

    config_class = IDCM_Config
    base_model_prefix = "bert_model"

    def __init__(self,
                 cfg) -> None:
        super().__init__(cfg)

        #
        # bert - scoring
        #
        if isinstance(cfg.bert_model, str):
            self.bert_model = AutoModel.from_pretrained(cfg.bert_model)
        else:
            self.bert_model = cfg.bert_model

        #
        # final scoring (combination of bert scores)
        #
        self._classification_layer = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
        self.top_k_chunks = cfg.top_k_chunks
        self.top_k_scoring = nn.Parameter(torch.full([1,self.top_k_chunks], 1, dtype=torch.float32, requires_grad=True))

        #
        # local self attention
        #
        self.padding_idx= cfg.padding_idx
        self.chunk_size = cfg.chunk_size
        self.overlap = cfg.overlap
        self.extended_chunk_size = self.chunk_size + 2 * self.overlap

        #
        # sampling stuff
        #
        self.sample_n = cfg.sample_n
        self.sample_context = cfg.sample_context

        if self.sample_context == "ck":
            i = 3
            self.sample_cnn3 = nn.Sequential(
                nn.ConstantPad1d((0,i - 1), 0),
                nn.Conv1d(kernel_size=i, in_channels=self.bert_model.config.dim, out_channels=self.bert_model.config.dim),
                nn.ReLU()
            )
        elif self.sample_context == "ck-small":
            i = 3
            self.sample_projector = nn.Linear(self.bert_model.config.dim,384)
            self.sample_cnn3 = nn.Sequential(
                nn.ConstantPad1d((0,i - 1), 0),
                nn.Conv1d(kernel_size=i, in_channels=384, out_channels=128),
                nn.ReLU()
            )

        # kernel-pooling parameters of the fast selection module (11 Gaussian kernels)
        self.sampling_binweights = nn.Linear(11, 1, bias=True)
        torch.nn.init.uniform_(self.sampling_binweights.weight, -0.01, 0.01)
        self.kernel_alpha_scaler = nn.Parameter(torch.full([1,1,11], 1, dtype=torch.float32, requires_grad=True))

        self.register_buffer("mu",nn.Parameter(torch.tensor([1.0, 0.9, 0.7, 0.5, 0.3, 0.1, -0.1, -0.3, -0.5, -0.7, -0.9]), requires_grad=False).view(1, 1, 1, -1))
        self.register_buffer("sigma", nn.Parameter(torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), requires_grad=False).view(1, 1, 1, -1))

    def forward(self,
                query: Dict[str, torch.LongTensor],
                document: Dict[str, torch.LongTensor],
                use_fp16:bool = True,
                output_secondary_output: bool = False):

        #
        # patch up documents - local self attention
        #
        document_ids = document["input_ids"][:,1:]
        if document_ids.shape[1] > self.overlap:
            needed_padding = self.extended_chunk_size - (((document_ids.shape[1]) % self.chunk_size) - self.overlap)
        else:
            needed_padding = self.extended_chunk_size - self.overlap - document_ids.shape[1]
        orig_doc_len = document_ids.shape[1]

        document_ids = nn.functional.pad(document_ids,(self.overlap, needed_padding),value=self.padding_idx)
        chunked_ids = document_ids.unfold(1,self.extended_chunk_size,self.chunk_size)

        batch_size = chunked_ids.shape[0]
        chunk_pieces = chunked_ids.shape[1]

        # keep only the windows whose core (without the overlap) contains at least one real token
        chunked_ids_unrolled=chunked_ids.reshape(-1,self.extended_chunk_size)
        packed_indices = (chunked_ids_unrolled[:,self.overlap:-self.overlap] != self.padding_idx).any(-1)
        orig_packed_indices = packed_indices.clone()
        ids_packed = chunked_ids_unrolled[packed_indices]
        mask_packed = (ids_packed != self.padding_idx)

        total_chunks=chunked_ids_unrolled.shape[0]

        packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["input_ids"].shape[1])[packed_indices]
        packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["attention_mask"].shape[1])[packed_indices]

        #
        # sampling
        #
        if self.sample_n > -1:

            #
            # ck learned matches
            #
            if self.sample_context == "ck-small":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(packed_query_ids).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(ids_packed).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
            elif self.sample_context == "ck":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(packed_query_ids).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(ids_packed).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
            else:
                # NOTE: this branch uses self.tk_projector / self.tk_contextualizer,
                # which this inference-only class does not define
                qe = self.tk_projector(self.bert_model.embeddings(packed_query_ids).detach())
                de = self.tk_projector(self.bert_model.embeddings(ids_packed).detach())
                query_ctx = self.tk_contextualizer(qe.transpose(1,0),src_key_padding_mask=~packed_query_mask.bool()).transpose(1,0)
                document_ctx = self.tk_contextualizer(de.transpose(1,0),src_key_padding_mask=~mask_packed.bool()).transpose(1,0)

                query_ctx = torch.nn.functional.normalize(query_ctx,p=2,dim=-1)
                document_ctx= torch.nn.functional.normalize(document_ctx,p=2,dim=-1)

            cosine_matrix = torch.bmm(query_ctx,document_ctx.transpose(-1, -2)).unsqueeze(-1)

            kernel_activations = torch.exp(- torch.pow(cosine_matrix - self.mu, 2) / (2 * torch.pow(self.sigma, 2))) * mask_packed.unsqueeze(-1).unsqueeze(1)
            kernel_res = torch.log(torch.clamp(torch.sum(kernel_activations, 2) * self.kernel_alpha_scaler, min=1e-4)) * packed_query_mask.unsqueeze(-1)
            packed_patch_scores = self.sampling_binweights(torch.sum(kernel_res, 1))

            sampling_scores_per_doc = torch.zeros((total_chunks,1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
            sampling_scores_per_doc[packed_indices] = packed_patch_scores
            sampling_scores_per_doc = sampling_scores_per_doc.reshape(batch_size,-1,)
            sampling_scores_per_doc_orig = sampling_scores_per_doc.clone()
            sampling_scores_per_doc[sampling_scores_per_doc == 0] = -9000

            sampling_sorted = sampling_scores_per_doc.sort(descending=True)
            sampled_indices = sampling_sorted.indices + torch.arange(0,sampling_scores_per_doc.shape[0]*sampling_scores_per_doc.shape[1],sampling_scores_per_doc.shape[1],device=sampling_scores_per_doc.device).unsqueeze(-1)

            sampled_indices = sampled_indices[:,:self.sample_n]
            sampled_indices_mask = torch.zeros_like(packed_indices).scatter(0, sampled_indices.reshape(-1), 1)

            # pack indices
            packed_indices = sampled_indices_mask * packed_indices

            packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["input_ids"].shape[1])[packed_indices]
            packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["attention_mask"].shape[1])[packed_indices]
            ids_packed = chunked_ids_unrolled[packed_indices]
            mask_packed = (ids_packed != self.padding_idx)

        #
        # expensive bert scores
        #
        bert_vecs = self.forward_representation(torch.cat([packed_query_ids,ids_packed],dim=1),torch.cat([packed_query_mask,mask_packed],dim=1))
        packed_patch_scores = self._classification_layer(bert_vecs)

        scores_per_doc = torch.zeros((total_chunks,1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
        scores_per_doc[packed_indices] = packed_patch_scores
        scores_per_doc = scores_per_doc.reshape(batch_size,-1,)
        scores_per_doc_orig = scores_per_doc.clone()
        scores_per_doc_orig_sorter = scores_per_doc.clone()

        if self.sample_n > -1:
            scores_per_doc = scores_per_doc * sampled_indices_mask.view(batch_size,-1)

        #
        # aggregate bert scores
        #
        if scores_per_doc.shape[1] < self.top_k_chunks:
            scores_per_doc = nn.functional.pad(scores_per_doc,(0, self.top_k_chunks - scores_per_doc.shape[1]))

        scores_per_doc[scores_per_doc == 0] = -9000
        scores_per_doc_orig_sorter[scores_per_doc_orig_sorter == 0] = -9000
        score = torch.sort(scores_per_doc,descending=True,dim=-1).values
        score[score <= -8900] = 0

        score = (score[:,:self.top_k_chunks] * self.top_k_scoring).sum(dim=1)

        if self.sample_n == -1:
            if output_secondary_output:
                return score,{
                    "packed_indices": orig_packed_indices.view(batch_size,-1),
                    "bert_scores":scores_per_doc_orig
                }
            else:
                return score,scores_per_doc_orig
        else:
            if output_secondary_output:
                return score,scores_per_doc_orig,{
                    "score": score,
                    "packed_indices": orig_packed_indices.view(batch_size,-1),
                    "sampling_scores":sampling_scores_per_doc_orig,
                    "bert_scores":scores_per_doc_orig
                }

            return score

    def forward_representation(self, ids,mask,type_ids=None) -> torch.Tensor:

        if self.bert_model.base_model_prefix == 'distilbert': # diff input / output
            pooled = self.bert_model(input_ids=ids,
                                     attention_mask=mask)[0][:,0,:]
        elif self.bert_model.base_model_prefix == 'longformer':
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask.long(),
                                        global_attention_mask = ((1-ids)*mask).long())
        elif self.bert_model.base_model_prefix == 'roberta': # no token type ids
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask)
        else:
            _, pooled = self.bert_model(input_ids=ids,
                                        token_type_ids=type_ids,
                                        attention_mask=mask)

        return pooled

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") # honestly not sure if that is the best way to go, but it works :)
model = IDCM_InferenceOnly.from_pretrained("sebastian-hofstaetter/idcm-distilbert-msmarco_doc")
````
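
The fast CK selection score in the `forward` method is a kernel pooling over the query-document cosine similarities. The plain-Python sketch below illustrates that scoring step for a single window, using the same kernel centers and width as the `mu` and `sigma` buffers. It is a simplification, not the model code: padding masks and the learned `kernel_alpha_scaler` are left out, and `bin_weights` and `bias` are hypothetical values standing in for the learned `sampling_binweights` layer.

```python
import math
from typing import List, Sequence

# Kernel centers and width, copied from the `mu` and `sigma` buffers above.
MU = [1.0, 0.9, 0.7, 0.5, 0.3, 0.1, -0.1, -0.3, -0.5, -0.7, -0.9]
SIGMA = 0.1


def kernel_features(cosine_matrix: Sequence[Sequence[float]]) -> List[float]:
    """cosine_matrix[i][j] is the cosine similarity between query token i and
    document token j. Returns one pooled feature per Gaussian kernel."""
    features = []
    for mu in MU:
        pooled = 0.0
        for row in cosine_matrix:
            # Soft count of document tokens whose similarity is close to mu.
            soft_count = sum(math.exp(-((c - mu) ** 2) / (2 * SIGMA ** 2))
                             for c in row)
            # Log-normalize per query token (clamped for stability), then
            # sum over the query tokens.
            pooled += math.log(max(soft_count, 1e-4))
        features.append(pooled)
    return features


def window_score(cosine_matrix: Sequence[Sequence[float]],
                 bin_weights: Sequence[float], bias: float = 0.0) -> float:
    """Linear combination of the kernel features (hypothetical weights)."""
    feats = kernel_features(cosine_matrix)
    return sum(w * f for w, f in zip(bin_weights, feats)) + bias


if __name__ == "__main__":
    # Two query tokens, three document tokens (made-up similarities).
    sims = [[1.0, 0.2, -0.1], [0.4, 0.9, 0.0]]
    print(window_score(sims, bin_weights=[0.01] * 11))
```

Because this score needs only embeddings, a small CNN, and a similarity matrix, it can be computed for every window before the slower BERT module sees the selected ones.
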
## Effectiveness on MSMARCO-Document & TREC Deep Learning '19
We trained our model on the MSMARCO-Document collection. The selection module CK is trained with knowledge distillation from the stronger BERT model.

For re-ranking, we used the top-100 BM25 results. IDCM should reach a throughput of roughly 600 documents per second, each with up to 2,000 tokens.
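
The re-ranking setup can be sketched as follows. This is a minimal illustration, assuming a first-stage retriever has already produced the candidates: `score_fn` is a hypothetical stand-in for a call to the IDCM model, and `term_overlap_score` is a toy scorer included only to make the example runnable.

```python
from typing import Callable, List, Tuple


def rerank(query: str,
           candidates: List[Tuple[str, str]],
           score_fn: Callable[[str, str], float],
           depth: int = 100) -> List[Tuple[str, float]]:
    """Re-score the first `depth` candidates and sort them by the new score.

    `candidates` holds (doc_id, text) pairs in first-stage (e.g. BM25) order.
    """
    scored = [(doc_id, score_fn(query, text))
              for doc_id, text in candidates[:depth]]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def term_overlap_score(query: str, text: str) -> float:
    """Toy scorer: number of distinct query terms that occur in the text."""
    return float(len(set(query.lower().split()) & set(text.lower().split())))


if __name__ == "__main__":
    bm25_top = [("d1", "cats sleep"),
                ("d2", "neural document ranking models"),
                ("d3", "document ranking")]
    print(rerank("neural document ranking", bm25_top, term_overlap_score))
```

If the throughput quoted above holds, re-ranking 100 candidates would take roughly 100 / 600 ≈ 0.17 seconds per query.
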
### MSMARCO-Document-DEV
| | MRR@10 | NDCG@10 |
|----------------------------------|--------|---------|
| BM25 | .252 | .311 |
| **IDCM** | .380 | .446 |
### TREC-DL'19 (Document Task)
For MRR, we binarize the graded relevance labels at the recommended point of 2, so that grades of 2 and above count as relevant. Results may therefore not be directly comparable with numbers computed using a different binarization point.

| | MRR@10 | NDCG@10 |
|----------------------------------|--------|---------|
| BM25 | .661 | .488 |
| **IDCM** | .916 | .688 |
For more metrics, baselines, info and analysis, please see the paper: https://arxiv.org/abs/2105.09816
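
To make the binarization used for MRR concrete, here is a minimal sketch of MRR@10 over graded relevance labels. The function and parameter names are illustrative assumptions and this is not the evaluation code behind the tables; it only shows how the binarization point decides which documents count as relevant.

```python
from typing import Dict, List


def reciprocal_rank_at_k(ranked_doc_ids: List[str],
                         graded_relevance: Dict[str, int],
                         k: int = 10,
                         min_relevant_grade: int = 2) -> float:
    """Reciprocal rank of the first document within the top k whose grade
    is at least `min_relevant_grade`; 0.0 if there is none."""
    for rank, doc_id in enumerate(ranked_doc_ids[:k], start=1):
        if graded_relevance.get(doc_id, 0) >= min_relevant_grade:
            return 1.0 / rank
    return 0.0


def mrr_at_k(runs: Dict[str, List[str]],
             qrels: Dict[str, Dict[str, int]],
             k: int = 10,
             min_relevant_grade: int = 2) -> float:
    """Mean of the per-query reciprocal ranks."""
    if not runs:
        return 0.0
    total = sum(reciprocal_rank_at_k(ranking, qrels.get(qid, {}), k,
                                     min_relevant_grade)
                for qid, ranking in runs.items())
    return total / len(runs)


if __name__ == "__main__":
    qrels = {"q1": {"a": 1, "b": 2, "c": 3}}  # made-up labels
    runs = {"q1": ["a", "b", "c"]}
    # The same ranking yields different values for different binarization points.
    print(mrr_at_k(runs, qrels, min_relevant_grade=2))
    print(mrr_at_k(runs, qrels, min_relevant_grade=1))
```
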
## Limitations & Bias
- The model inherits social biases from both DistilBERT and MSMARCO.
- The model is only trained on the longer documents of MSMARCO, so it might struggle with especially short documents. For short text, we recommend one of our MSMARCO-Passage trained models.
## Citation
If you use our model checkpoint, please cite our work as:
```
@inproceedings{Hofstaetter2021_idcm,
author = {Sebastian Hofst{\"a}tter and Bhaskar Mitra and Hamed Zamani and Nick Craswell and Allan Hanbury},
title = {{Intra-Document Cascading: Learning to Select Passages for Neural Document Ranking}},
booktitle = {Proc. of SIGIR},
year = {2021},
}
```