Text Ranking
sentence-transformers
PyTorch
JAX
ONNX
Safetensors
OpenVINO
Transformers
English
bert
text-classification
text-embeddings-inference
Instructions to use cross-encoder/ms-marco-TinyBERT-L6 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use cross-encoder/ms-marco-TinyBERT-L6 with sentence-transformers:
from sentence_transformers import CrossEncoder model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L6") query = "Which planet is known as the Red Planet?" passages = [ "Venus is often called Earth's twin because of its similar size and proximity.", "Mars, known for its reddish appearance, is often referred to as the Red Planet.", "Jupiter, the largest planet in our solar system, has a prominent red spot.", "Saturn, famous for its rings, is sometimes mistaken for the Red Planet." ] scores = model.predict([(query, passage) for passage in passages]) print(scores) - Transformers
How to use cross-encoder/ms-marco-TinyBERT-L6 with Transformers:
# Load model directly from transformers import AutoTokenizer, AutoModelForSequenceClassification tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-TinyBERT-L6") model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-TinyBERT-L6") - Notebooks
- Google Colab
- Kaggle
nreimers commited on
Commit ·
67d6b31
1
Parent(s): d0a17bf
upload
Browse files- CERerankingEvaluator_results.csv +127 -0
- README.md +34 -0
- config.json +27 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer_config.json +1 -0
- train_script.py +193 -0
- vocab.txt +0 -0
CERerankingEvaluator_results.csv
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
epoch,steps,MRR@10
|
| 2 |
+
0,5000,0.5650095238095237
|
| 3 |
+
0,10000,0.5849968253968254
|
| 4 |
+
0,15000,0.6097650793650794
|
| 5 |
+
0,20000,0.6246285714285715
|
| 6 |
+
0,25000,0.6100253968253967
|
| 7 |
+
0,30000,0.6270730158730159
|
| 8 |
+
0,35000,0.6138888888888888
|
| 9 |
+
0,40000,0.6240317460317462
|
| 10 |
+
0,45000,0.6327619047619049
|
| 11 |
+
0,50000,0.619631746031746
|
| 12 |
+
0,55000,0.5871142857142856
|
| 13 |
+
0,60000,0.6175809523809525
|
| 14 |
+
0,65000,0.6081968253968254
|
| 15 |
+
0,70000,0.6151301587301587
|
| 16 |
+
0,75000,0.6093269841269842
|
| 17 |
+
0,80000,0.6032571428571428
|
| 18 |
+
0,85000,0.6138063492063491
|
| 19 |
+
0,90000,0.6156952380952381
|
| 20 |
+
0,95000,0.6303523809523809
|
| 21 |
+
0,100000,0.6061523809523809
|
| 22 |
+
0,105000,0.6133174603174603
|
| 23 |
+
0,110000,0.6226063492063493
|
| 24 |
+
0,115000,0.6176349206349206
|
| 25 |
+
0,120000,0.6104761904761905
|
| 26 |
+
0,125000,0.6332253968253967
|
| 27 |
+
0,130000,0.6289523809523808
|
| 28 |
+
0,135000,0.6181809523809524
|
| 29 |
+
0,140000,0.6399841269841271
|
| 30 |
+
0,145000,0.623073015873016
|
| 31 |
+
0,150000,0.5963587301587302
|
| 32 |
+
0,155000,0.6157301587301588
|
| 33 |
+
0,160000,0.613120634920635
|
| 34 |
+
0,165000,0.6089936507936508
|
| 35 |
+
0,170000,0.6203301587301587
|
| 36 |
+
0,175000,0.6171269841269841
|
| 37 |
+
0,180000,0.5939269841269841
|
| 38 |
+
0,185000,0.6417873015873015
|
| 39 |
+
0,190000,0.6164476190476191
|
| 40 |
+
0,195000,0.6215841269841269
|
| 41 |
+
0,200000,0.6298984126984126
|
| 42 |
+
0,205000,0.6030507936507936
|
| 43 |
+
0,210000,0.6084730158730158
|
| 44 |
+
0,215000,0.6092730158730159
|
| 45 |
+
0,220000,0.5939650793650793
|
| 46 |
+
0,225000,0.6124190476190475
|
| 47 |
+
0,230000,0.6039269841269841
|
| 48 |
+
0,235000,0.6253301587301587
|
| 49 |
+
0,240000,0.634904761904762
|
| 50 |
+
0,245000,0.6317015873015873
|
| 51 |
+
0,250000,0.6196603174603175
|
| 52 |
+
0,255000,0.6287396825396825
|
| 53 |
+
0,260000,0.6095746031746031
|
| 54 |
+
0,265000,0.6263492063492063
|
| 55 |
+
0,270000,0.6171079365079365
|
| 56 |
+
0,275000,0.6289523809523809
|
| 57 |
+
0,280000,0.6202634920634921
|
| 58 |
+
0,285000,0.6255301587301587
|
| 59 |
+
0,290000,0.5993841269841268
|
| 60 |
+
0,295000,0.6191841269841271
|
| 61 |
+
0,300000,0.6203396825396825
|
| 62 |
+
0,305000,0.6128412698412699
|
| 63 |
+
0,310000,0.6090825396825398
|
| 64 |
+
0,315000,0.5950539682539682
|
| 65 |
+
0,320000,0.5990444444444444
|
| 66 |
+
0,325000,0.6042412698412698
|
| 67 |
+
0,330000,0.5960190476190476
|
| 68 |
+
0,335000,0.6106222222222223
|
| 69 |
+
0,340000,0.6055968253968255
|
| 70 |
+
0,345000,0.5984095238095238
|
| 71 |
+
0,350000,0.6142984126984128
|
| 72 |
+
0,355000,0.6137746031746032
|
| 73 |
+
0,360000,0.6018412698412698
|
| 74 |
+
0,365000,0.6123079365079365
|
| 75 |
+
0,370000,0.6130285714285715
|
| 76 |
+
0,375000,0.6008412698412698
|
| 77 |
+
0,380000,0.6020698412698412
|
| 78 |
+
0,385000,0.6100222222222222
|
| 79 |
+
0,390000,0.5971650793650793
|
| 80 |
+
0,395000,0.5941968253968255
|
| 81 |
+
0,400000,0.5871428571428571
|
| 82 |
+
0,405000,0.6100190476190476
|
| 83 |
+
0,410000,0.5903174603174602
|
| 84 |
+
0,415000,0.5988317460317459
|
| 85 |
+
0,420000,0.6132380952380952
|
| 86 |
+
0,425000,0.6144412698412698
|
| 87 |
+
0,430000,0.5980888888888888
|
| 88 |
+
0,435000,0.5973746031746032
|
| 89 |
+
0,440000,0.595384126984127
|
| 90 |
+
0,445000,0.5871714285714286
|
| 91 |
+
0,450000,0.6012412698412699
|
| 92 |
+
0,455000,0.5873047619047618
|
| 93 |
+
0,460000,0.595584126984127
|
| 94 |
+
0,465000,0.5804285714285713
|
| 95 |
+
0,470000,0.5887619047619047
|
| 96 |
+
0,475000,0.5872761904761904
|
| 97 |
+
0,480000,0.5871396825396825
|
| 98 |
+
0,485000,0.5907174603174602
|
| 99 |
+
0,490000,0.5880412698412699
|
| 100 |
+
0,495000,0.5807968253968254
|
| 101 |
+
0,500000,0.5909746031746032
|
| 102 |
+
0,505000,0.5912984126984128
|
| 103 |
+
0,510000,0.5942761904761905
|
| 104 |
+
0,515000,0.5840222222222223
|
| 105 |
+
0,520000,0.5852380952380952
|
| 106 |
+
0,525000,0.582784126984127
|
| 107 |
+
0,530000,0.5916190476190476
|
| 108 |
+
0,535000,0.5777269841269841
|
| 109 |
+
0,540000,0.582120634920635
|
| 110 |
+
0,545000,0.5746634920634921
|
| 111 |
+
0,550000,0.5746444444444445
|
| 112 |
+
0,555000,0.5632444444444444
|
| 113 |
+
0,560000,0.5799650793650795
|
| 114 |
+
0,565000,0.5932507936507936
|
| 115 |
+
0,570000,0.5816190476190476
|
| 116 |
+
0,575000,0.5838857142857143
|
| 117 |
+
0,580000,0.5859650793650794
|
| 118 |
+
0,585000,0.5843968253968255
|
| 119 |
+
0,590000,0.5840634920634921
|
| 120 |
+
0,595000,0.5958285714285714
|
| 121 |
+
0,600000,0.5842857142857142
|
| 122 |
+
0,605000,0.5892507936507937
|
| 123 |
+
0,610000,0.5914507936507937
|
| 124 |
+
0,615000,0.5953968253968254
|
| 125 |
+
0,620000,0.5925174603174603
|
| 126 |
+
0,625000,0.5890857142857143
|
| 127 |
+
0,-1,0.5890857142857143
|
README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Cross-Encoder for MS Marco
|
| 2 |
+
|
| 3 |
+
This model uses [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT), a tiny BERT model with only 6 layers. The base model is: General_TinyBERT_v2(6layer-768dim)
|
| 4 |
+
|
| 5 |
+
It was trained on [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
|
| 6 |
+
|
| 7 |
+
The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Information Retrieval](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/information-retrieval) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
|
| 8 |
+
|
| 9 |
+
## Usage and Performance
|
| 10 |
+
|
| 11 |
+
Pre-trained models can be used like this:
|
| 12 |
+
```
|
| 13 |
+
from sentence_transformers import CrossEncoder
|
| 14 |
+
model = CrossEncoder('model_name', max_length=512)
|
| 15 |
+
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec (BertTokenizerFast) | Docs / Sec |
|
| 22 |
+
| ------------- |:-------------| -----| --- | --- |
|
| 23 |
+
| cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000 | 780
|
| 24 |
+
| cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900 | 760
|
| 25 |
+
| cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680 | 660
|
| 26 |
+
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340 | 340
|
| 27 |
+
| *Other models* | | | |
|
| 28 |
+
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900 | 760
|
| 29 |
+
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340 | 340|
|
| 30 |
+
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100 | 100 |
|
| 31 |
+
| Capreolus/electra-base-msmarco | 71.23 | | 340 | 340 |
|
| 32 |
+
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | | 330 | 330
|
| 33 |
+
|
| 34 |
+
Note: Runtime was computed on a V100 GPU. A bottleneck for smaller models is the standard Python tokenizer from Huggingface v3. Replacing it with the fast tokenizer based on Rust, the throughput is significantly improved:
|
config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "nreimers/TinyBERT_L-6_H-768_v2",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"gradient_checkpointing": false,
|
| 8 |
+
"hidden_act": "gelu",
|
| 9 |
+
"hidden_dropout_prob": 0.1,
|
| 10 |
+
"hidden_size": 768,
|
| 11 |
+
"id2label": {
|
| 12 |
+
"0": "LABEL_0"
|
| 13 |
+
},
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 3072,
|
| 16 |
+
"label2id": {
|
| 17 |
+
"LABEL_0": 0
|
| 18 |
+
},
|
| 19 |
+
"layer_norm_eps": 1e-12,
|
| 20 |
+
"max_position_embeddings": 512,
|
| 21 |
+
"model_type": "bert",
|
| 22 |
+
"num_attention_heads": 12,
|
| 23 |
+
"num_hidden_layers": 6,
|
| 24 |
+
"pad_token_id": 0,
|
| 25 |
+
"type_vocab_size": 2,
|
| 26 |
+
"vocab_size": 30522
|
| 27 |
+
}
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cf70691b32e33dc1a57845ceeeb81bc70cfec62a0d584d063f13576403f2759
|
| 3 |
+
size 267871721
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": "/home/ukp-reimers/.cache/torch/transformers/68bd39c5d90b3f2d8da930bf3efe6ad55d21ae39cfbf97d37817cce8149bf2f3.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d", "tokenizer_file": null, "name_or_path": "nreimers/TinyBERT_L-6_H-768_v2"}
|
train_script.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
from sentence_transformers import LoggingHandler
|
| 3 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
| 4 |
+
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
|
| 5 |
+
from sentence_transformers import InputExample
|
| 6 |
+
import logging
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import gzip
|
| 9 |
+
import sys
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os
|
| 12 |
+
from shutil import copyfile
|
| 13 |
+
import csv
|
| 14 |
+
import tqdm
|
| 15 |
+
|
| 16 |
+
#### Just some code to print debug information to stdout
|
| 17 |
+
logging.basicConfig(format='%(asctime)s - %(message)s',
|
| 18 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 19 |
+
level=logging.INFO,
|
| 20 |
+
handlers=[LoggingHandler()])
|
| 21 |
+
#### /print debug information to stdout
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
#Define our Cross-Encoder
|
| 25 |
+
model_name = sys.argv[1] #'google/electra-small-discriminator'
|
| 26 |
+
train_batch_size = 32
|
| 27 |
+
num_epochs = 1
|
| 28 |
+
model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 29 |
+
|
| 30 |
+
#We set num_labels=1, which predicts a continous score between 0 and 1
|
| 31 |
+
model = CrossEncoder(model_name, num_labels=1, max_length=512)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Write self to path
|
| 35 |
+
os.makedirs(model_save_path, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
train_script_path = os.path.join(model_save_path, 'train_script.py')
|
| 38 |
+
copyfile(__file__, train_script_path)
|
| 39 |
+
with open(train_script_path, 'a') as fOut:
|
| 40 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
corpus = {}
|
| 44 |
+
queries = {}
|
| 45 |
+
|
| 46 |
+
#### Read train file
|
| 47 |
+
with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
|
| 48 |
+
for line in fIn:
|
| 49 |
+
pid, passage = line.strip().split("\t")
|
| 50 |
+
corpus[pid] = passage
|
| 51 |
+
|
| 52 |
+
with open('../data/queries.train.tsv', 'r') as fIn:
|
| 53 |
+
for line in fIn:
|
| 54 |
+
qid, query = line.strip().split("\t")
|
| 55 |
+
queries[qid] = query
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
pos_neg_ration = (4+1)
|
| 60 |
+
cnt = 0
|
| 61 |
+
train_samples = []
|
| 62 |
+
dev_samples = {}
|
| 63 |
+
|
| 64 |
+
num_dev_queries = 125
|
| 65 |
+
num_max_dev_negatives = 200
|
| 66 |
+
|
| 67 |
+
with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
|
| 68 |
+
for line in fIn:
|
| 69 |
+
qid, pos_id, neg_id = line.strip().split()
|
| 70 |
+
|
| 71 |
+
if qid not in dev_samples and len(dev_samples) < num_dev_queries:
|
| 72 |
+
dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
|
| 73 |
+
|
| 74 |
+
if qid in dev_samples:
|
| 75 |
+
dev_samples[qid]['positive'].add(corpus[pos_id])
|
| 76 |
+
|
| 77 |
+
if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
|
| 78 |
+
dev_samples[qid]['negative'].add(corpus[neg_id])
|
| 79 |
+
|
| 80 |
+
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
|
| 81 |
+
for line in tqdm.tqdm(fIn, unit_scale=True):
|
| 82 |
+
cnt += 1
|
| 83 |
+
qid, pos_id, neg_id = line.strip().split()
|
| 84 |
+
query = queries[qid]
|
| 85 |
+
if (cnt % pos_neg_ration) == 0:
|
| 86 |
+
passage = corpus[pos_id]
|
| 87 |
+
label = 1
|
| 88 |
+
else:
|
| 89 |
+
passage = corpus[neg_id]
|
| 90 |
+
label = 0
|
| 91 |
+
|
| 92 |
+
train_samples.append(InputExample(texts=[query, passage], label=label))
|
| 93 |
+
|
| 94 |
+
if len(train_samples) >= 2e7:
|
| 95 |
+
break
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
|
| 100 |
+
|
| 101 |
+
# We add an evaluator, which evaluates the performance during training
|
| 102 |
+
|
| 103 |
+
class CERerankingEvaluator:
|
| 104 |
+
def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
|
| 105 |
+
self.samples = samples
|
| 106 |
+
self.name = name
|
| 107 |
+
self.mrr_at_k = mrr_at_k
|
| 108 |
+
|
| 109 |
+
if isinstance(self.samples, dict):
|
| 110 |
+
self.samples = list(self.samples.values())
|
| 111 |
+
|
| 112 |
+
self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
|
| 113 |
+
self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
|
| 114 |
+
|
| 115 |
+
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
|
| 116 |
+
if epoch != -1:
|
| 117 |
+
if steps == -1:
|
| 118 |
+
out_txt = " after epoch {}:".format(epoch)
|
| 119 |
+
else:
|
| 120 |
+
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
|
| 121 |
+
else:
|
| 122 |
+
out_txt = ":"
|
| 123 |
+
|
| 124 |
+
logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
|
| 125 |
+
|
| 126 |
+
all_mrr_scores = []
|
| 127 |
+
num_queries = 0
|
| 128 |
+
num_positives = []
|
| 129 |
+
num_negatives = []
|
| 130 |
+
for instance in self.samples:
|
| 131 |
+
query = instance['query']
|
| 132 |
+
positive = list(instance['positive'])
|
| 133 |
+
negative = list(instance['negative'])
|
| 134 |
+
docs = positive + negative
|
| 135 |
+
is_relevant = [True]*len(positive) + [False]*len(negative)
|
| 136 |
+
|
| 137 |
+
if len(positive) == 0 or len(negative) == 0:
|
| 138 |
+
continue
|
| 139 |
+
|
| 140 |
+
num_queries += 1
|
| 141 |
+
num_positives.append(len(positive))
|
| 142 |
+
num_negatives.append(len(negative))
|
| 143 |
+
|
| 144 |
+
model_input = [[query, doc] for doc in docs]
|
| 145 |
+
pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
|
| 146 |
+
pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
|
| 147 |
+
|
| 148 |
+
mrr_score = 0
|
| 149 |
+
for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
|
| 150 |
+
if is_relevant[index]:
|
| 151 |
+
mrr_score = 1 / (rank+1)
|
| 152 |
+
|
| 153 |
+
all_mrr_scores.append(mrr_score)
|
| 154 |
+
|
| 155 |
+
mean_mrr = np.mean(all_mrr_scores)
|
| 156 |
+
logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
|
| 157 |
+
logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
|
| 158 |
+
|
| 159 |
+
if output_path is not None:
|
| 160 |
+
csv_path = os.path.join(output_path, self.csv_file)
|
| 161 |
+
output_file_exists = os.path.isfile(csv_path)
|
| 162 |
+
with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
|
| 163 |
+
writer = csv.writer(f)
|
| 164 |
+
if not output_file_exists:
|
| 165 |
+
writer.writerow(self.csv_headers)
|
| 166 |
+
|
| 167 |
+
writer.writerow([epoch, steps, mean_mrr])
|
| 168 |
+
|
| 169 |
+
return mean_mrr
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
evaluator = CERerankingEvaluator(dev_samples)
|
| 173 |
+
|
| 174 |
+
# Configure the training
|
| 175 |
+
warmup_steps = 5000
|
| 176 |
+
logging.info("Warmup-steps: {}".format(warmup_steps))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Train the model
|
| 180 |
+
model.fit(train_dataloader=train_dataloader,
|
| 181 |
+
evaluator=evaluator,
|
| 182 |
+
epochs=num_epochs,
|
| 183 |
+
evaluation_steps=5000,
|
| 184 |
+
warmup_steps=warmup_steps,
|
| 185 |
+
output_path=model_save_path,
|
| 186 |
+
use_amp=True)
|
| 187 |
+
|
| 188 |
+
#Save latest model
|
| 189 |
+
model.save(model_save_path+'-latest')
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Script was called via:
|
| 193 |
+
#python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2
|
vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|