nreimers commited on
Commit
67d6b31
1 Parent(s): d0a17bf
CERerankingEvaluator_results.csv ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,steps,MRR@10
2
+ 0,5000,0.5650095238095237
3
+ 0,10000,0.5849968253968254
4
+ 0,15000,0.6097650793650794
5
+ 0,20000,0.6246285714285715
6
+ 0,25000,0.6100253968253967
7
+ 0,30000,0.6270730158730159
8
+ 0,35000,0.6138888888888888
9
+ 0,40000,0.6240317460317462
10
+ 0,45000,0.6327619047619049
11
+ 0,50000,0.619631746031746
12
+ 0,55000,0.5871142857142856
13
+ 0,60000,0.6175809523809525
14
+ 0,65000,0.6081968253968254
15
+ 0,70000,0.6151301587301587
16
+ 0,75000,0.6093269841269842
17
+ 0,80000,0.6032571428571428
18
+ 0,85000,0.6138063492063491
19
+ 0,90000,0.6156952380952381
20
+ 0,95000,0.6303523809523809
21
+ 0,100000,0.6061523809523809
22
+ 0,105000,0.6133174603174603
23
+ 0,110000,0.6226063492063493
24
+ 0,115000,0.6176349206349206
25
+ 0,120000,0.6104761904761905
26
+ 0,125000,0.6332253968253967
27
+ 0,130000,0.6289523809523808
28
+ 0,135000,0.6181809523809524
29
+ 0,140000,0.6399841269841271
30
+ 0,145000,0.623073015873016
31
+ 0,150000,0.5963587301587302
32
+ 0,155000,0.6157301587301588
33
+ 0,160000,0.613120634920635
34
+ 0,165000,0.6089936507936508
35
+ 0,170000,0.6203301587301587
36
+ 0,175000,0.6171269841269841
37
+ 0,180000,0.5939269841269841
38
+ 0,185000,0.6417873015873015
39
+ 0,190000,0.6164476190476191
40
+ 0,195000,0.6215841269841269
41
+ 0,200000,0.6298984126984126
42
+ 0,205000,0.6030507936507936
43
+ 0,210000,0.6084730158730158
44
+ 0,215000,0.6092730158730159
45
+ 0,220000,0.5939650793650793
46
+ 0,225000,0.6124190476190475
47
+ 0,230000,0.6039269841269841
48
+ 0,235000,0.6253301587301587
49
+ 0,240000,0.634904761904762
50
+ 0,245000,0.6317015873015873
51
+ 0,250000,0.6196603174603175
52
+ 0,255000,0.6287396825396825
53
+ 0,260000,0.6095746031746031
54
+ 0,265000,0.6263492063492063
55
+ 0,270000,0.6171079365079365
56
+ 0,275000,0.6289523809523809
57
+ 0,280000,0.6202634920634921
58
+ 0,285000,0.6255301587301587
59
+ 0,290000,0.5993841269841268
60
+ 0,295000,0.6191841269841271
61
+ 0,300000,0.6203396825396825
62
+ 0,305000,0.6128412698412699
63
+ 0,310000,0.6090825396825398
64
+ 0,315000,0.5950539682539682
65
+ 0,320000,0.5990444444444444
66
+ 0,325000,0.6042412698412698
67
+ 0,330000,0.5960190476190476
68
+ 0,335000,0.6106222222222223
69
+ 0,340000,0.6055968253968255
70
+ 0,345000,0.5984095238095238
71
+ 0,350000,0.6142984126984128
72
+ 0,355000,0.6137746031746032
73
+ 0,360000,0.6018412698412698
74
+ 0,365000,0.6123079365079365
75
+ 0,370000,0.6130285714285715
76
+ 0,375000,0.6008412698412698
77
+ 0,380000,0.6020698412698412
78
+ 0,385000,0.6100222222222222
79
+ 0,390000,0.5971650793650793
80
+ 0,395000,0.5941968253968255
81
+ 0,400000,0.5871428571428571
82
+ 0,405000,0.6100190476190476
83
+ 0,410000,0.5903174603174602
84
+ 0,415000,0.5988317460317459
85
+ 0,420000,0.6132380952380952
86
+ 0,425000,0.6144412698412698
87
+ 0,430000,0.5980888888888888
88
+ 0,435000,0.5973746031746032
89
+ 0,440000,0.595384126984127
90
+ 0,445000,0.5871714285714286
91
+ 0,450000,0.6012412698412699
92
+ 0,455000,0.5873047619047618
93
+ 0,460000,0.595584126984127
94
+ 0,465000,0.5804285714285713
95
+ 0,470000,0.5887619047619047
96
+ 0,475000,0.5872761904761904
97
+ 0,480000,0.5871396825396825
98
+ 0,485000,0.5907174603174602
99
+ 0,490000,0.5880412698412699
100
+ 0,495000,0.5807968253968254
101
+ 0,500000,0.5909746031746032
102
+ 0,505000,0.5912984126984128
103
+ 0,510000,0.5942761904761905
104
+ 0,515000,0.5840222222222223
105
+ 0,520000,0.5852380952380952
106
+ 0,525000,0.582784126984127
107
+ 0,530000,0.5916190476190476
108
+ 0,535000,0.5777269841269841
109
+ 0,540000,0.582120634920635
110
+ 0,545000,0.5746634920634921
111
+ 0,550000,0.5746444444444445
112
+ 0,555000,0.5632444444444444
113
+ 0,560000,0.5799650793650795
114
+ 0,565000,0.5932507936507936
115
+ 0,570000,0.5816190476190476
116
+ 0,575000,0.5838857142857143
117
+ 0,580000,0.5859650793650794
118
+ 0,585000,0.5843968253968255
119
+ 0,590000,0.5840634920634921
120
+ 0,595000,0.5958285714285714
121
+ 0,600000,0.5842857142857142
122
+ 0,605000,0.5892507936507937
123
+ 0,610000,0.5914507936507937
124
+ 0,615000,0.5953968253968254
125
+ 0,620000,0.5925174603174603
126
+ 0,625000,0.5890857142857143
127
+ 0,-1,0.5890857142857143
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cross-Encoder for MS Marco
2
+
3
+ This model uses [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT), a tiny BERT model with only 6 layers. The base model is: General_TinyBERT_v2(6layer-768dim)
4
+
5
+ It was trained on [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
6
+
7
+ The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Information Retrieval](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/information-retrieval) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
8
+
9
+ ## Usage and Performance
10
+
11
+ Pre-trained models can be used like this:
12
+ ```
13
+ from sentence_transformers import CrossEncoder
14
+ model = CrossEncoder('model_name', max_length=512)
15
+ scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
16
+ ```
17
+
18
+ In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
19
+
20
+
21
+ | Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec (BertTokenizerFast) | Docs / Sec |
22
+ | ------------- |:-------------| -----| --- | --- |
23
+ | cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000 | 780
24
+ | cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900 | 760
25
+ | cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680 | 660
26
+ | cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340 | 340
27
+ | *Other models* | | | |
28
+ | nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900 | 760
29
+ | nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340 | 340|
30
+ | nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100 | 100 |
31
+ | Capreolus/electra-base-msmarco | 71.23 | | 340 | 340 |
32
+ | amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | | 330 | 330
33
+
34
+ Note: Runtime was computed on a V100 GPU. A bottleneck for smaller models is the standard Python tokenizer from Huggingface v3. Replacing it with the fast tokenizer based on Rust, the throughput is significantly improved:
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nreimers/TinyBERT_L-6_H-768_v2",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "LABEL_0"
13
+ },
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "label2id": {
17
+ "LABEL_0": 0
18
+ },
19
+ "layer_norm_eps": 1e-12,
20
+ "max_position_embeddings": 512,
21
+ "model_type": "bert",
22
+ "num_attention_heads": 12,
23
+ "num_hidden_layers": 6,
24
+ "pad_token_id": 0,
25
+ "type_vocab_size": 2,
26
+ "vocab_size": 30522
27
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cf70691b32e33dc1a57845ceeeb81bc70cfec62a0d584d063f13576403f2759
3
+ size 267871721
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": "/home/ukp-reimers/.cache/torch/transformers/68bd39c5d90b3f2d8da930bf3efe6ad55d21ae39cfbf97d37817cce8149bf2f3.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d", "tokenizer_file": null, "name_or_path": "nreimers/TinyBERT_L-6_H-768_v2"}
train_script.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from sentence_transformers import LoggingHandler
3
+ from sentence_transformers.cross_encoder import CrossEncoder
4
+ from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
5
+ from sentence_transformers import InputExample
6
+ import logging
7
+ from datetime import datetime
8
+ import gzip
9
+ import sys
10
+ import numpy as np
11
+ import os
12
+ from shutil import copyfile
13
+ import csv
14
+ import tqdm
15
+
16
+ #### Just some code to print debug information to stdout
17
+ logging.basicConfig(format='%(asctime)s - %(message)s',
18
+ datefmt='%Y-%m-%d %H:%M:%S',
19
+ level=logging.INFO,
20
+ handlers=[LoggingHandler()])
21
+ #### /print debug information to stdout
22
+
23
+
24
+ #Define our Cross-Encoder
25
+ model_name = sys.argv[1] #'google/electra-small-discriminator'
26
+ train_batch_size = 32
27
+ num_epochs = 1
28
+ model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
29
+
30
+ #We set num_labels=1, which predicts a continous score between 0 and 1
31
+ model = CrossEncoder(model_name, num_labels=1, max_length=512)
32
+
33
+
34
+ # Write self to path
35
+ os.makedirs(model_save_path, exist_ok=True)
36
+
37
+ train_script_path = os.path.join(model_save_path, 'train_script.py')
38
+ copyfile(__file__, train_script_path)
39
+ with open(train_script_path, 'a') as fOut:
40
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
41
+
42
+
43
+ corpus = {}
44
+ queries = {}
45
+
46
+ #### Read train file
47
+ with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
48
+ for line in fIn:
49
+ pid, passage = line.strip().split("\t")
50
+ corpus[pid] = passage
51
+
52
+ with open('../data/queries.train.tsv', 'r') as fIn:
53
+ for line in fIn:
54
+ qid, query = line.strip().split("\t")
55
+ queries[qid] = query
56
+
57
+
58
+
59
+ pos_neg_ration = (4+1)
60
+ cnt = 0
61
+ train_samples = []
62
+ dev_samples = {}
63
+
64
+ num_dev_queries = 125
65
+ num_max_dev_negatives = 200
66
+
67
+ with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
68
+ for line in fIn:
69
+ qid, pos_id, neg_id = line.strip().split()
70
+
71
+ if qid not in dev_samples and len(dev_samples) < num_dev_queries:
72
+ dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
73
+
74
+ if qid in dev_samples:
75
+ dev_samples[qid]['positive'].add(corpus[pos_id])
76
+
77
+ if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
78
+ dev_samples[qid]['negative'].add(corpus[neg_id])
79
+
80
+ with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
81
+ for line in tqdm.tqdm(fIn, unit_scale=True):
82
+ cnt += 1
83
+ qid, pos_id, neg_id = line.strip().split()
84
+ query = queries[qid]
85
+ if (cnt % pos_neg_ration) == 0:
86
+ passage = corpus[pos_id]
87
+ label = 1
88
+ else:
89
+ passage = corpus[neg_id]
90
+ label = 0
91
+
92
+ train_samples.append(InputExample(texts=[query, passage], label=label))
93
+
94
+ if len(train_samples) >= 2e7:
95
+ break
96
+
97
+
98
+
99
+ train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
100
+
101
+ # We add an evaluator, which evaluates the performance during training
102
+
103
+ class CERerankingEvaluator:
104
+ def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
105
+ self.samples = samples
106
+ self.name = name
107
+ self.mrr_at_k = mrr_at_k
108
+
109
+ if isinstance(self.samples, dict):
110
+ self.samples = list(self.samples.values())
111
+
112
+ self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
113
+ self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
114
+
115
+ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
116
+ if epoch != -1:
117
+ if steps == -1:
118
+ out_txt = " after epoch {}:".format(epoch)
119
+ else:
120
+ out_txt = " in epoch {} after {} steps:".format(epoch, steps)
121
+ else:
122
+ out_txt = ":"
123
+
124
+ logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
125
+
126
+ all_mrr_scores = []
127
+ num_queries = 0
128
+ num_positives = []
129
+ num_negatives = []
130
+ for instance in self.samples:
131
+ query = instance['query']
132
+ positive = list(instance['positive'])
133
+ negative = list(instance['negative'])
134
+ docs = positive + negative
135
+ is_relevant = [True]*len(positive) + [False]*len(negative)
136
+
137
+ if len(positive) == 0 or len(negative) == 0:
138
+ continue
139
+
140
+ num_queries += 1
141
+ num_positives.append(len(positive))
142
+ num_negatives.append(len(negative))
143
+
144
+ model_input = [[query, doc] for doc in docs]
145
+ pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
146
+ pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
147
+
148
+ mrr_score = 0
149
+ for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
150
+ if is_relevant[index]:
151
+ mrr_score = 1 / (rank+1)
152
+
153
+ all_mrr_scores.append(mrr_score)
154
+
155
+ mean_mrr = np.mean(all_mrr_scores)
156
+ logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
157
+ logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
158
+
159
+ if output_path is not None:
160
+ csv_path = os.path.join(output_path, self.csv_file)
161
+ output_file_exists = os.path.isfile(csv_path)
162
+ with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
163
+ writer = csv.writer(f)
164
+ if not output_file_exists:
165
+ writer.writerow(self.csv_headers)
166
+
167
+ writer.writerow([epoch, steps, mean_mrr])
168
+
169
+ return mean_mrr
170
+
171
+
172
+ evaluator = CERerankingEvaluator(dev_samples)
173
+
174
+ # Configure the training
175
+ warmup_steps = 5000
176
+ logging.info("Warmup-steps: {}".format(warmup_steps))
177
+
178
+
179
+ # Train the model
180
+ model.fit(train_dataloader=train_dataloader,
181
+ evaluator=evaluator,
182
+ epochs=num_epochs,
183
+ evaluation_steps=5000,
184
+ warmup_steps=warmup_steps,
185
+ output_path=model_save_path,
186
+ use_amp=True)
187
+
188
+ #Save latest model
189
+ model.save(model_save_path+'-latest')
190
+
191
+
192
+ # Script was called via:
193
+ #python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2
vocab.txt ADDED
The diff for this file is too large to render. See raw diff