Listwise MonoBERT trained on Baidu-ULTR using the Dual Learning Algorithm (DLA)
A Flax-based MonoBERT cross-encoder trained on the Baidu-ULTR dataset with a listwise DLA objective on clicks. Following Ai et al., the dual learning algorithm jointly infers item relevance (using a BERT model) and position bias (in our case, a single embedding parameter per rank), each by optimizing a listwise softmax cross-entropy loss. For more information, read our paper and find the code for this model here.
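The joint objective can be sketched in plain Python for a single query. This is a minimal sketch that assumes the formulation of Ai et al.: the relevance model is trained on clicks reweighted by inverse propensity weights P(e=1 | rank 1) / P(e=1 | rank i), and the position-bias model on clicks reweighted by the symmetric inverse relevance weights, with both probabilities estimated by a softmax over the list. The function names, example scores, and per-rank bias values are hypothetical, and details the authors' implementation may add (such as weight clipping or normalization) are omitted.

```python
import math


def softmax(logits):
    # Numerically stable softmax over one query's list of documents
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def inverse_weights(logits):
    # Ratio P(item at index 0) / P(item at index i), each estimated by a softmax.
    # Assumes the list is sorted by SERP position, so index 0 is rank 1.
    probs = softmax(logits)
    return [probs[0] / p for p in probs]


def weighted_listwise_loss(logits, clicks, weights):
    # Listwise softmax cross-entropy on clicked items, each scaled by its weight
    probs = softmax(logits)
    return -sum(w * c * math.log(p) for p, c, w in zip(probs, clicks, weights))


def dla_losses(relevance_logits, bias_logits, clicks):
    # Relevance model: clicks reweighted by inverse propensity (from the bias model).
    # Bias model: clicks reweighted by inverse relevance (from the relevance model).
    # In a JAX training loop, both weight vectors would typically be wrapped in
    # jax.lax.stop_gradient so each model is updated only through its own loss.
    propensity_weights = inverse_weights(bias_logits)
    relevance_weights = inverse_weights(relevance_logits)
    relevance_loss = weighted_listwise_loss(relevance_logits, clicks, propensity_weights)
    bias_loss = weighted_listwise_loss(bias_logits, clicks, relevance_weights)
    return relevance_loss, bias_loss


# Hypothetical learned position bias: one scalar parameter per rank
position_bias_table = [0.0, -0.4, -0.8, -1.1]
positions = [1, 2, 3, 4]
bias_logits = [position_bias_table[p - 1] for p in positions]

relevance_logits = [1.2, 0.3, 0.8, -0.5]  # hypothetical BERT relevance scores
clicks = [0, 0, 1, 0]
relevance_loss, bias_loss = dla_losses(relevance_logits, bias_logits, clicks)
print(relevance_loss, bias_loss)
```

Using the same softmax-based ratio for both weight vectors is what makes the algorithm "dual": each model supplies the correction factor for the other's loss.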
Test Results on Baidu-ULTR
Ranking performance is measured in DCG, nDCG, and MRR on expert annotations (6,985 queries). Click prediction performance is measured in log-likelihood on one test partition of user clicks (≈297k queries).
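For reference, the ranking metrics can be sketched for a single query as follows. This sketch assumes the common exponential gain (2^rel - 1) with a log2 rank discount, and treats a document as relevant for MRR if its label reaches a hypothetical `threshold` of 1; the exact gain, discount, and relevance cutoff behind the table below may differ. All function names and example values are illustrative.

```python
import math


def ranked_labels(scores, labels):
    # Reorder the expert labels by descending model score
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [labels[i] for i in order]


def dcg_at_k(labels, k):
    # Exponential gain (2^rel - 1) with a log2 rank discount (assumed convention)
    return sum((2 ** rel - 1) / math.log2(rank + 2) for rank, rel in enumerate(labels[:k]))


def ndcg_at_k(labels, k):
    # Normalize by the DCG of the ideal (label-sorted) ranking of the same documents
    ideal = dcg_at_k(sorted(labels, reverse=True), k)
    return dcg_at_k(labels, k) / ideal if ideal > 0 else 0.0


def mrr_at_k(labels, k, threshold=1):
    # Reciprocal rank of the first document whose label reaches the (hypothetical) threshold
    for rank, rel in enumerate(labels[:k]):
        if rel >= threshold:
            return 1.0 / (rank + 1)
    return 0.0


# Hypothetical model scores and graded expert labels for one query
scores = [0.1, 2.0, 0.5]
labels = [3, 0, 1]
ranked = ranked_labels(scores, labels)
print(ranked, dcg_at_k(ranked, 3), ndcg_at_k(ranked, 10), mrr_at_k(ranked, 10))
```

Per-query values like these would then be averaged over all annotated queries.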
| Model | Log-likelihood | DCG@1 | DCG@3 | DCG@5 | DCG@10 | nDCG@10 | MRR@10 |
|---|---|---|---|---|---|---|---|
| Pointwise Naive | 0.227 | 1.641 | 3.462 | 4.752 | 7.251 | 0.357 | 0.609 |
| Pointwise Two-Tower | 0.218 | 1.629 | 3.471 | 4.822 | 7.456 | 0.367 | 0.607 |
| Pointwise IPS | 0.222 | 1.295 | 2.811 | 3.977 | 6.296 | 0.307 | 0.534 |
| Listwise Naive | - | 1.947 | 4.108 | 5.614 | 8.478 | 0.405 | 0.639 |
| Listwise IPS | - | 1.671 | 3.530 | 4.873 | 7.450 | 0.361 | 0.603 |
| Listwise DLA | - | 1.796 | 3.730 | 5.125 | 7.802 | 0.377 | 0.615 |
Usage
Here is an example of downloading the model and calling it for inference on a mock batch of input data. For more details on how to use the model on the Baidu-ULTR dataset, take a look at our training and evaluation scripts in our code repository.
```python
import jax.numpy as jnp
# DLACrossEncoder is defined in the src.model module of our code repository
from src.model import DLACrossEncoder

model = DLACrossEncoder.from_pretrained(
    "philipphager/baidu-ultr_uva-bert_dla",
)

# Mock batch following Baidu-ULTR with 4 documents, each with 8 tokens
batch = {
    # Query ID for each document
    "query_id": jnp.array([1, 1, 1, 1]),
    # Document position in the SERP
    "positions": jnp.array([1, 2, 3, 4]),
    # Token IDs for: [CLS] Query [SEP] Document
    "tokens": jnp.array([
        [2, 21448, 21874, 21436, 1, 20206, 4012, 2860],
        [2, 21448, 21874, 21436, 1, 16794, 4522, 2082],
        [2, 21448, 21874, 21436, 1, 20206, 10082, 9773],
        [2, 21448, 21874, 21436, 1, 2618, 8520, 2860],
    ]),
    # Specifies whether a token ID belongs to the query (0) or the document (1)
    "token_types": jnp.array([
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
    ]),
    # Marks whether a token should be attended to (True) or ignored, e.g., padding tokens (False)
    "attention_mask": jnp.array([
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
    ]),
}

outputs = model(batch, train=False)
print(outputs)
```
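The comments in the mock batch describe the expected input layout. The helper below is a hypothetical sketch for assembling such a batch from already-tokenized queries and documents of varying length. The special token IDs `cls_id=2` and `sep_id=1` are read off the mock batch, while `pad_id=0`, the token type given to padding, and the truncation strategy are assumptions; the preprocessing in our code repository may differ.

```python
def build_example(query_tokens, doc_tokens, max_len, cls_id=2, sep_id=1, pad_id=0):
    # Layout from the mock batch: [CLS] query [SEP] document
    tokens = [cls_id] + list(query_tokens) + [sep_id] + list(doc_tokens)
    # As in the mock batch: [CLS] + query are type 0, [SEP] + document are type 1
    token_types = [0] * (1 + len(query_tokens)) + [1] * (1 + len(doc_tokens))
    # Truncate from the end if the pair is too long (assumed strategy)
    tokens, token_types = tokens[:max_len], token_types[:max_len]
    attention_mask = [True] * len(tokens)
    # Pad to max_len; padded positions are masked out of attention
    padding = max_len - len(tokens)
    tokens += [pad_id] * padding
    token_types += [0] * padding  # assumed token type for padding
    attention_mask += [False] * padding
    return tokens, token_types, attention_mask


def build_batch(query_id, query_tokens, docs, max_len):
    # docs: token lists for one query, given in SERP order (position 1 first)
    rows = [build_example(query_tokens, doc, max_len) for doc in docs]
    return {
        "query_id": [query_id] * len(docs),
        "positions": list(range(1, len(docs) + 1)),
        "tokens": [row[0] for row in rows],
        "token_types": [row[1] for row in rows],
        "attention_mask": [row[2] for row in rows],
    }


batch = build_batch(1, [21448, 21874, 21436], [[20206, 4012, 2860], [2618, 8520]], max_len=8)
print(batch["tokens"])
```

To pass the result to the model, each field would be converted to an array first, e.g. `{key: jnp.array(value) for key, value in batch.items()}`.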
Reference
```bibtex
@inproceedings{Hager2024BaiduULTR,
  author = {Philipp Hager and Romain Deffayet and Jean-Michel Renders and Onno Zoeter and Maarten de Rijke},
  title = {Unbiased Learning to Rank Meets Reality: Lessons from Baidu’s Large-Scale Search Dataset},
  booktitle = {Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR '24)},
  organization = {ACM},
  year = {2024},
}
```