
SimLM: Pre-training with Representation Bottleneck for Dense Passage Retrieval

Paper available at https://arxiv.org/pdf/2207.02578

Code available at https://github.com/microsoft/unilm/tree/master/simlm

Paper abstract

In this paper, we propose SimLM (Similarity matching with Language Model pre-training), a simple yet effective pre-training method for dense passage retrieval. It employs a simple bottleneck architecture that learns to compress the passage information into a dense vector through self-supervised pre-training. We use a replaced language modeling objective, inspired by ELECTRA, to improve sample efficiency and reduce the mismatch in input distribution between pre-training and fine-tuning. SimLM requires access only to an unlabeled corpus, and is therefore more broadly applicable when no labeled data or queries are available. We conduct experiments on several large-scale passage retrieval datasets and show substantial improvements over strong baselines under various settings. Remarkably, SimLM even outperforms multi-vector approaches such as ColBERTv2, which incurs significantly higher storage cost.
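The replaced language modeling objective starts from an ELECTRA-style corrupted input. The toy sketch below shows only that corruption step and is not the authors' implementation: a uniform random sampler over a tiny vocabulary stands in for ELECTRA's learned generator, the 30% replacement rate is an illustrative value, and the names `corrupt` and `VOCAB` are hypothetical. In this sketch the label at every position is the original token, so the training target is to reconstruct the uncorrupted text.

```python
import random
from typing import List, Optional, Tuple

# Hypothetical toy vocabulary; a real setup would sample from the tokenizer's vocabulary.
VOCAB = ["the", "a", "game", "show", "quality", "process", "long", "hours"]

def corrupt(tokens: List[str], replace_prob: float = 0.3,
            rng: Optional[random.Random] = None) -> Tuple[List[str], List[str]]:
    """Return (corrupted_input, labels) for a replaced-LM style objective.

    Assumption: a uniform sampler over VOCAB stands in for ELECTRA's generator.
    Labels are the original tokens at every position.
    """
    rng = rng or random.Random()
    corrupted = []
    for tok in tokens:
        if rng.random() < replace_prob:
            # Replace this position with a sampled token; the sample may
            # coincide with the original token, as with a learned generator.
            corrupted.append(rng.choice(VOCAB))
        else:
            corrupted.append(tok)
    labels = list(tokens)  # targets: reconstruct the uncorrupted sequence
    return corrupted, labels
```

For example, `corrupt("the game itself takes about three hours".split(), rng=random.Random(0))` returns a corrupted sequence of the same length together with the unchanged token list as labels; which positions are replaced depends on the seed.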

Results on MS-MARCO passage ranking task

| Model | dev MRR@10 | dev R@50 | dev R@1k | TREC DL 2019 nDCG@10 | TREC DL 2020 nDCG@10 |
|---|---|---|---|---|---|
| RocketQAv2 | 38.8 | 86.2 | 98.1 | - | - |
| coCondenser | 38.2 | 86.5 | 98.4 | 71.7 | 68.4 |
| ColBERTv2 | 39.7 | 86.8 | 98.4 | - | - |
| SimLM (this model) | 41.1 | 87.8 | 98.7 | 71.4 | 69.7 |
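The dev-set columns are ranking metrics, reported in the table multiplied by 100 (41.1 means 0.411). As a reference, here is a minimal sketch of MRR@10 and Recall@k under their common definitions; the official MS MARCO evaluation scripts may differ in details such as tie or duplicate handling. The function names and the toy run/qrels data are hypothetical.

```python
from typing import Dict, List, Set

def mrr_at_k(run: Dict[str, List[str]], qrels: Dict[str, Set[str]], k: int = 10) -> float:
    """Mean reciprocal rank of the first relevant passage within the top k (0 if none)."""
    total = 0.0
    for qid, relevant in qrels.items():
        ranked = run.get(qid, [])[:k]
        for rank, pid in enumerate(ranked, start=1):
            if pid in relevant:
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(qrels)

def recall_at_k(run: Dict[str, List[str]], qrels: Dict[str, Set[str]], k: int) -> float:
    """Fraction of each query's relevant passages found in the top k, averaged over queries."""
    total = 0.0
    for qid, relevant in qrels.items():
        top_k = set(run.get(qid, [])[:k])
        total += len(top_k & relevant) / len(relevant)
    return total / len(qrels)

# Toy example: two queries with one relevant passage each (hypothetical IDs).
qrels = {"q1": {"p3"}, "q2": {"p9"}}
run = {"q1": ["p1", "p3", "p5"], "q2": ["p2", "p4", "p6"]}
print(mrr_at_k(run, qrels, k=10))    # (1/2 + 0) / 2 = 0.25
print(recall_at_k(run, qrels, k=2))  # (1 + 0) / 2 = 0.5
```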

Usage

Get embeddings from our fine-tuned model:

```python
import torch
from transformers import AutoModel, AutoTokenizer, BatchEncoding, PreTrainedTokenizerFast
from transformers.modeling_outputs import BaseModelOutput

def l2_normalize(x: torch.Tensor):
    # Scale to unit length, so the dot product below equals cosine similarity
    return torch.nn.functional.normalize(x, p=2, dim=-1)

def encode_query(tokenizer: PreTrainedTokenizerFast, query: str) -> BatchEncoding:
    # Queries are truncated to at most 32 tokens
    return tokenizer(query,
                     max_length=32,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

def encode_passage(tokenizer: PreTrainedTokenizerFast, passage: str, title: str = '-') -> BatchEncoding:
    # Passages are encoded as a (title, passage) pair truncated to 144 tokens;
    # '-' serves as a placeholder title when none is given
    return tokenizer(title,
                     text_pair=passage,
                     max_length=144,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

tokenizer = AutoTokenizer.from_pretrained('intfloat/simlm-base-msmarco-finetuned')
model = AutoModel.from_pretrained('intfloat/simlm-base-msmarco-finetuned')
model.eval()  # inference mode (disables dropout)

with torch.no_grad():
    query_batch_dict = encode_query(tokenizer, 'what is qa')
    outputs: BaseModelOutput = model(**query_batch_dict, return_dict=True)
    # The embedding is the hidden state of the first ([CLS]) token of the only sequence in the batch
    query_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])

    psg1 = 'Quality assurance (QA) is a process-centered approach to ensuring that a company or organization is providing the best possible products or services. It is related to quality control, which focuses on the end result, such as testing a sample of items from a batch after production.'
    psg1_batch_dict = encode_passage(tokenizer, psg1)
    outputs: BaseModelOutput = model(**psg1_batch_dict, return_dict=True)
    psg1_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])

    psg2 = 'The Super Bowl is typically four hours long. The game itself takes about three and a half hours, with a 30 minute halftime show built in.'
    psg2_batch_dict = encode_passage(tokenizer, psg2)
    outputs: BaseModelOutput = model(**psg2_batch_dict, return_dict=True)
    psg2_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])

    # Higher cosine similarity means the passage is more relevant to the query
    print(query_embedding.dot(psg1_embedding), query_embedding.dot(psg2_embedding))
```
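The example above runs one forward pass per text. To score many passages at once, one option (a sketch, not the authors' code) is to take the first-token hidden state of every row in a padded batch, L2-normalize it, and rank passages with a single matrix product. The helper names `cls_embeddings` and `rank_passages` are hypothetical, and random tensors stand in for model outputs so the snippet runs without downloading the model; 768 is assumed as the hidden size of the base model.

```python
import torch

def cls_embeddings(last_hidden_state: torch.Tensor) -> torch.Tensor:
    """(batch, seq_len, dim) -> (batch, dim): first-token vectors, L2-normalized."""
    return torch.nn.functional.normalize(last_hidden_state[:, 0, :], p=2, dim=-1)

def rank_passages(query_emb: torch.Tensor, passage_embs: torch.Tensor, top_k: int = 3):
    """Return (scores, indices) of the top_k passages by dot product.

    Because the inputs are unit-length, the dot product equals cosine similarity.
    """
    scores = passage_embs @ query_emb  # shape: (num_passages,)
    k = min(top_k, scores.shape[0])
    return torch.topk(scores, k=k)

# Stand-ins for model outputs: random hidden states instead of real forward passes.
torch.manual_seed(0)
query_emb = cls_embeddings(torch.randn(1, 8, 768))[0]
passage_embs = cls_embeddings(torch.randn(5, 16, 768))
scores, indices = rank_passages(query_emb, passage_embs, top_k=3)
print(indices.tolist(), scores.tolist())
```

With the real model, the random passage tensor would be replaced by the `last_hidden_state` of one batched call such as `model(**tokenizer(['-'] * len(passages), text_pair=passages, max_length=144, padding=True, truncation=True, return_tensors='pt'))`, where `passages` is a list of strings. Padding does not affect the pooled vector, since only position 0 is read.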

Citation

```bibtex
@article{Wang2022SimLMPW,
  title={SimLM: Pre-training with Representation Bottleneck for Dense Passage Retrieval},
  author={Liang Wang and Nan Yang and Xiaolong Huang and Binxing Jiao and Linjun Yang and Daxin Jiang and Rangan Majumder and Furu Wei},
  journal={ArXiv},
  year={2022},
  volume={abs/2207.02578}
}
```