sebastian-hofstaetter commited on
Commit
5ddf7ff
1 Parent(s): 160bf93

Add model, tokenizer, & initial model card

Browse files
README.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: "en"
3
+ tags:
4
+ - dpr
5
+ - dense-passage-retrieval
6
+ - knowledge-distillation
7
+ datasets:
8
+ - ms_marco
9
+ ---
10
+
11
+ # Margin-MSE Trained DistilBERT-Cat (vanilla/mono/concatenated DistilBERT re-ranker)
12
+
13
+ We provide a retrieval trained DistilBERT-Cat model (https://arxiv.org/pdf/2004.12832.pdf). Our model is trained with Margin-MSE using a 3 teacher BERT_Cat (concatenated BERT scoring) ensemble on MSMARCO-Passage.
14
+
15
+ This instance can be used to **re-rank a candidate set**. The architecure is a 6-layer DistilBERT, with an additional single linear layer at the end.
16
+
17
+ If you want to know more about our simple, yet effective knowledge distillation method for efficient information retrieval models for a variety of student architectures that is used for this model instance check out our paper: https://arxiv.org/abs/2010.02666 🎉
18
+
19
+ For more information, training data, source code, and a minimal usage example please visit: https://github.com/sebastian-hofstaetter/neural-ranking-kd
20
+
21
+ ## Configuration
22
+
23
+ - fp16 trained, so fp16 inference shouldn't be a problem
24
+
25
+ ## Model Code
26
+
27
+ ````python
28
+ from transformers import AutoTokenizer,AutoModel, PreTrainedModel,PretrainedConfig
29
+ from typing import Dict
30
+ import torch
31
+
32
+ class BERT_Cat_Config(PretrainedConfig):
33
+ model_type = "BERT_Cat"
34
+ bert_model: str
35
+ trainable: bool = True
36
+
37
+ class BERT_Cat(PreTrainedModel):
38
+ """
39
+ The vanilla/mono BERT concatenated (we lovingly refer to as BERT_Cat) architecture
40
+ -> requires input concatenation before model, so that batched input is possible
41
+ """
42
+ config_class = BERT_Cat_Config
43
+ base_model_prefix = "bert_model"
44
+
45
+ def __init__(self,
46
+ cfg) -> None:
47
+ super().__init__(cfg)
48
+
49
+ self.bert_model = AutoModel.from_pretrained(cfg.bert_model)
50
+
51
+ for p in self.bert_model.parameters():
52
+ p.requires_grad = cfg.trainable
53
+
54
+ self._classification_layer = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
55
+
56
+ def forward(self,
57
+ query_n_doc_sequence):
58
+
59
+ vecs = self.bert_model(**query_n_doc_sequence)[0][:,0,:] # assuming a distilbert model here
60
+ score = self._classification_layer(vecs)
61
+ return score
62
+
63
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") # honestly not sure if that is the best way to go, but it works :)
64
+ model = BERT_Cat.from_pretrained("sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco")
65
+ ````
66
+
67
+ ## Effectiveness on MSMARCO Passage & TREC Deep Learning '19
68
+
69
+ We trained our model on the MSMARCO standard ("small"-400K query) training triples with knowledge distillation with a batch size of 32 on a single consumer-grade GPU (11GB memory).
70
+
71
+ For re-ranking we used the top-1000 BM25 results.
72
+
73
+ ### MSMARCO-DEV
74
+
75
+ Here, we use the larger 49K query DEV set (same range as the smaller 7K DEV set, minimal changes possible)
76
+
77
+ | | MRR@10 | NDCG@10 |
78
+ |----------------------------------|--------|---------|
79
+ | BM25 | .194 | .241 |
80
+ | **Margin-MSE DistilBERT_Cat** (Re-ranking) | .391 | .451 |
81
+
82
+ ### TREC-DL'19
83
+
84
+ For MRR we use the recommended binarization point of the graded relevance of 2. This might skew the results when compared to other binarization point numbers.
85
+
86
+ | | MRR@10 | NDCG@10 |
87
+ |----------------------------------|--------|---------|
88
+ | BM25 | .689 | .501 |
89
+ | **Margin-MSE DistilBERT_Cat** (Re-ranking) | .891 | .747 |
90
+
91
+ For more metrics, baselines, info and analysis, please see the paper: https://arxiv.org/abs/2010.02666
92
+
93
+ ## Limitations & Bias
94
+
95
+ - The model inherits social biases from both DistilBERT and MSMARCO.
96
+
97
+ - The model is only trained on relatively short passages of MSMARCO (avg. 60 words length), so it might struggle with longer text.
98
+
99
+
100
+ ## Citation
101
+
102
+ If you use our model checkpoint please cite our work as:
103
+
104
+ ```
105
+ @misc{hofstaetter2020_crossarchitecture_kd,
106
+ title={Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation},
107
+ author={Sebastian Hofst{\"a}tter and Sophia Althammer and Michael Schr{\"o}der and Mete Sertkan and Allan Hanbury},
108
+ year={2020},
109
+ eprint={2010.02666},
110
+ archivePrefix={arXiv},
111
+ primaryClass={cs.IR}
112
+ }
113
+ ```
config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BERT_Cat"
4
+ ],
5
+ "bert_model": "distilbert-base-uncased",
6
+ "model_type": "BERT_Cat",
7
+ "trainable": true
8
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:390e15c0bbed9cb3c909ffeb6f89b910db673e12f23e9c2366d9c9c4e267ed2d
3
+ size 265477723
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "distilbert-base-uncased"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff