nreimers committed
Commit 036acda
1 Parent(s): 446026d
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "_name_or_path": "distilbert-base-uncased",
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertModel"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "initializer_range": 0.02,
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "transformers_version": "4.2.2",
+   "vocab_size": 30522
+ }
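
This config is unchanged from stock distilbert-base-uncased (6 layers, 12 heads, dim 768, 30522-token vocab). As a minimal sketch, a local checkout of this repo can be loaded with transformers; the path below is a placeholder:

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./path-to-local-checkout")  # placeholder path
    model = AutoModel.from_pretrained("./path-to-local-checkout")          # instantiates DistilBertModel
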
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ef058424878c5079e25b9e1264964b40a154a94721c5ec77e43fa3506b1e3c5
+ size 265491187
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "distilbert-base-uncased"}
train_script.py ADDED
@@ -0,0 +1,387 @@
+ from torch.utils.data import DataLoader
+ from sentence_transformers import losses, util, models
+ from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, evaluation
+ from sentence_transformers.readers import InputExample
+ import logging
+ import os
+ from shutil import copyfile
+ import sys
+ import math
+ import gzip
+ import random
+ import tqdm
+ from transformers import AutoTokenizer, AutoModel, BertModel
+ import transformers
+ import torch
+ from SPARTA import SPARTA
+ import json
+ import numpy as np
+ from torch.cuda.amp import autocast
+ import datetime
+ from collections import defaultdict
+ from scipy.sparse import csc_matrix, csr_matrix
+
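+ # Note: SPARTA is a local helper module that ships next to this script (it is not
+ # part of this commit); it presumably wraps the HF tokenizer/model and exposes the
+ # query/passage scoring used below (the "Eq. 5"/"Eq. 6" comments refer to the
+ # SPARTA paper, Zhao et al.).
+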
+ random.seed(42)
+
+ scaler = torch.cuda.amp.GradScaler()
+
+ #### Just some code to print debug information to stdout
+ logging.basicConfig(format='%(asctime)s - %(message)s',
+                     datefmt='%Y-%m-%d %H:%M:%S',
+                     level=logging.INFO,
+                     handlers=[LoggingHandler()])
+ #### /print debug information to stdout
+
+ # Fill GPU
+ fill_gpu = torch.eye(85000, dtype=torch.float, device='cuda')
+ del fill_gpu
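+ # (An 85000x85000 fp32 tensor is ~29 GB; allocating it once and deleting it makes
+ # PyTorch's caching allocator reserve that GPU memory up front, presumably so the
+ # job fails fast on a busy card and keeps the memory reserved for training.)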
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ model_name = sys.argv[1]
+ model = SPARTA(model_name, device)
+
+ model_save_path = "output/msmarco-{}-{}".format(model_name.rstrip("/").split("/")[-1], datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+ model.tokenizer.save_pretrained(model_save_path)
+
+
+ ## Distil setting
+ if 'distil' in model_name:
+     batch_size, num_negatives = 4, 35
+ else:
+     batch_size, num_negatives = 3, 20
+
+ logging.info(f"batch_size: {batch_size}")
+ logging.info(f"num_neg: {num_negatives}")
+
+
+ # Write self to path
+ os.makedirs(model_save_path, exist_ok=True)
+
+ train_script_path = os.path.join(model_save_path, 'train_script.py')
+ copyfile(__file__, train_script_path)
+ with open(train_script_path, 'a') as fOut:
+     fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+
+ ########################
+ corpus = {}
+ train_queries = {}
+
+
+ #### Read dev file
+ logging.info("Create dev dataset")
+ dev_corpus_max_size = 100*1000
+
+ dev_queries_file = '../data/queries.dev.small.tsv'
+ needed_pids = set()
+ needed_qids = set()
+ dev_qids = set()
+
+ dev_queries = {}
+ dev_corpus = {}
+ dev_rel_docs = {}
+
+ with open(dev_queries_file) as fIn:
+     for line in fIn:
+         qid, query = line.strip().split("\t")
+         dev_qids.add(qid)
+
+ with open('../data/qrels.dev.tsv') as fIn:
+     for line in fIn:
+         qid, _, pid, _ = line.strip().split('\t')
+
+         if qid not in dev_qids:
+             continue
+
+         if qid not in dev_rel_docs:
+             dev_rel_docs[qid] = set()
+         dev_rel_docs[qid].add(pid)
+
+         needed_pids.add(pid)
+         needed_qids.add(qid)
+
+ with open(dev_queries_file) as fIn:
+     for line in fIn:
+         qid, query = line.strip().split("\t")
+         if qid in needed_qids:
+             dev_queries[qid] = query
+
+ with gzip.open('../data/collection-rnd.tsv.gz', 'rt') as fIn:
+     for line in fIn:
+         pid, passage = line.strip().split("\t")
+         if pid in needed_pids or dev_corpus_max_size <= 0 or len(dev_corpus) <= dev_corpus_max_size:
+             dev_corpus[pid] = passage
+
+ dev_corpus_pids = list(dev_corpus.keys())
+ dev_corpus = [dev_corpus[pid] for pid in dev_corpus_pids]
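+ # (The dev index thus holds every relevant passage plus roughly the first 100k
+ # passages of the collection file; dev_corpus_max_size <= 0 would keep everything.)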
+
+ ########### Eval functions
+
+ # Sparse SPARTA passage embeddings: score every vocab term against the passage's
+ # contextualized token embeddings, keep only the top-scoring terms as a sparse vector.
+ def compute_passage_emb(passages):
+     sparse_embeddings = []
+     bert_input_emb = model.bert_model.embeddings.word_embeddings(torch.tensor(list(range(0, len(model.tokenizer))), device=device))
+     sparse_vec_size = 2000
+
+     # Set special tokens [CLS], [MASK] etc. to zero
+     for special_id in model.tokenizer.all_special_ids:
+         bert_input_emb[special_id] = 0 * bert_input_emb[special_id]
+
+     with torch.no_grad():
+         tokens = model.tokenizer(passages, padding=True, truncation=True, return_tensors='pt', max_length=500).to(device)
+         passage_embeddings = model.bert_model(**tokens).last_hidden_state
+         for passage_emb in passage_embeddings:
+             scores = torch.matmul(bert_input_emb, passage_emb.transpose(0, 1))  # vocab_size x seq_len
+             max_scores = torch.max(scores, dim=-1).values
+             relu_scores = torch.relu(max_scores)  # Eq. 5
+             final_scores = torch.log(relu_scores + 1)  # Eq. 6, final score
+
+             top_results = torch.topk(final_scores, k=sparse_vec_size, sorted=True)
+             passage_emb = defaultdict(float)
+             for score, idx in zip(top_results[0].cpu().tolist(), top_results[1].cpu().tolist()):
+                 if score > 0:
+                     passage_emb[idx] = score
+                 else:
+                     break
+
+             sparse_embeddings.append(passage_emb)
+
+     return sparse_embeddings
+
+ def evaluate_msmarco():
+     passage_embs_sorted = []
+     batch_size = 32
+
+     # Sort passages by length so each batch has similar padding
+     length_sorted_idx = np.argsort([-len(pas) for pas in dev_corpus])
+     dev_corpus_sorted = [dev_corpus[idx] for idx in length_sorted_idx]
+
+     for start_idx in tqdm.trange(0, len(dev_corpus_sorted), batch_size, desc='encode corpus'):
+         passage_embs_sorted.extend(compute_passage_emb(dev_corpus_sorted[start_idx:start_idx + batch_size]))
+
+     passage_embs = [passage_embs_sorted[idx] for idx in np.argsort(length_sorted_idx)]
+
+     logging.info("Create sparse matrix")
+     row = []
+     col = []
+     values = []
+     for pid, emb in enumerate(passage_embs):
+         for tid, score in emb.items():
+             row.append(tid)
+             col.append(pid)
+             values.append(score)
+
+     sparse = csr_matrix((values, (row, col)), shape=(len(model.tokenizer), len(passage_embs)), dtype=float)
+     logging.info("Scores: {}".format(sparse.shape))
+
+     mrr = []
+     k = 10
+     for qid, question in tqdm.tqdm(dev_queries.items(), desc="score"):
+         token_ids = model.tokenizer(question, add_special_tokens=False)['input_ids']
+
+         # Get the candidate passages
+         scores = np.asarray(sparse[token_ids, :].sum(axis=0)).squeeze(0)
+         top_k_ind = np.argpartition(scores, -k)[-k:]
+         hits = sorted([(dev_corpus_pids[pid], scores[pid]) for pid in top_k_ind], key=lambda x: x[1], reverse=True)
+
+         mrr_score = 0
+         for rank, hit in enumerate(hits[0:10]):
+             pid = hit[0]
+             if pid in dev_rel_docs[qid]:
+                 mrr_score = 1 / (rank + 1)
+                 break
+         mrr.append(mrr_score)
+
+     assert len(mrr) == len(dev_queries)
+     mrr = np.mean(mrr)
+     logging.info("MRR@10: {:.4f}".format(mrr))
+     return mrr
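+ # Query scoring above is a pure sparse lookup: the CSR matrix maps term id ->
+ # per-passage scores, so a query's candidate scores are the sum of the rows for its
+ # token ids, and MRR@10 is computed from the top-10 ranked passages per dev query.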
+
+
+ best_score = 0  # evaluate_msmarco()
+
+ #################
+
+
+ #### Read train file
+
+ with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
+     for line in fIn:
+         pid, passage = line.strip().split("\t")
+         corpus[pid] = passage
+
+
+ with open('../data/queries.train.tsv', 'r') as fIn:
+     for line in fIn:
+         qid, query = line.strip().split("\t")
+         train_queries[qid] = {'query': query,
+                               'pos': set(),
+                               'soft-pos': set(),
+                               'neg': set()}
+
+
+ # Read qrels file for relevant positives per query
+ with open('../data/qrels.train.tsv') as fIn:
+     for line in fIn:
+         qid, _, pid, _ = line.strip().split()
+         train_queries[qid]['pos'].add(pid)
+
+
+ logging.info("Clean train queries")
+ deleted_queries = 0
+ for qid in list(train_queries.keys()):
+     if len(train_queries[qid]['pos']) == 0:
+         deleted_queries += 1
+         del train_queries[qid]
+         continue
+
+ logging.info("Deleted queries pos-empty: {}".format(deleted_queries))
+
+ for hard_neg_file in ['../data/hard-negatives-all.jsonl.gz']:  # previously also: '../data/hard-negatives-ann-roberta.jsonl.gz', '../data/hard-negatives-ann-msmarco-distilbert-base-v2.jsonl.gz', '../data/hard-negatives-ann.jsonl.gz', '../data/hard-negatives-ann-no_idnt.jsonl.gz'
+     logging.info("Read hard negatives: " + hard_neg_file)
+     with gzip.open(hard_neg_file, 'rt') as fIn:
+         try:
+             for line in fIn:
+                 try:
+                     data = json.loads(line)
+                 except json.JSONDecodeError:  # skip malformed lines
+                     continue
+                 qid = data['qid']
+
+                 if qid in train_queries:
+                     neg_added = 0
+                     max_neg_added = 100
+
+                     hits = sorted(data['hits'], key=lambda x: x['score'] if 'score' in x else x['bm25-score'], reverse=True)
+                     for hit in hits:
+                         pid = hit['corpus_id'] if 'corpus_id' in hit else hit['pid']
+
+                         if pid in train_queries[qid]['pos']:  # Skip entries we have as positives
+                             continue
+
+                         if hit['bert-score'] < 0.1 and neg_added < max_neg_added:
+                             train_queries[qid]['neg'].add(pid)
+                             neg_added += 1
+                         elif hit['bert-score'] > 0.9:
+                             train_queries[qid]['soft-pos'].add(pid)
+         except Exception:  # tolerate truncated/corrupt files
+             pass
+
+
+ logging.info("Clean train queries with empty neg set")
+ deleted_queries = 0
+ for qid in list(train_queries.keys()):
+     if len(train_queries[qid]['neg']) == 0:
+         deleted_queries += 1
+         del train_queries[qid]
+         continue
+
+ logging.info("Deleted queries neg empty: {}".format(deleted_queries))
+
+ train_queries = list(train_queries.values())
+ for idx in range(len(train_queries)):
+     train_queries[idx]['pos'] = list(train_queries[idx]['pos'])
+     train_queries[idx]['neg'] = list(train_queries[idx]['neg'])
+     train_queries[idx]['soft-pos'] = list(train_queries[idx]['soft-pos'])
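+ # (Note: the 'soft-pos' sets (hits with bert-score > 0.9) are collected but never
+ # used by the training loop below; only 'pos' and 'neg' enter the loss.)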
+
+
+
+ ###########################################
+
+
+ ####
+ # Prepare optimizers
+ param_optimizer = list(model.named_parameters())
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+ optimizer_grouped_parameters = [
+     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
+     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+ ]
+
+
+ grad_acc_steps, lr = 1, 2e-5
+ # grad_acc_steps, lr = 16, 2e-5
+
+
+ num_epochs = 1
+ optimizer = transformers.AdamW(model.parameters(), lr=lr, eps=1e-6)  # optimizer_grouped_parameters
+ t_total = math.ceil(len(train_queries)/batch_size*num_epochs)
+ num_warmup_steps = int(t_total/grad_acc_steps * 0.1)  # 10% for warm up
+ scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
+ loss_fct = torch.nn.CrossEntropyLoss()
+ max_grad_norm = 1
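+ # (optimizer_grouped_parameters is prepared but not passed to AdamW above, so no
+ # weight decay is applied anywhere; this matches the "no weight decay" remark in
+ # the invocation note at the end of the file.)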
+
+
+ for epoch in tqdm.trange(num_epochs, desc='Epochs'):
+     random.shuffle(train_queries)
+     idx = 0
+     for start_idx in tqdm.trange(0, len(train_queries), batch_size):
+         idx += 1
+         if idx % 5000 == 0:
+             score = evaluate_msmarco()
+             if score > best_score:
+                 best_score = score
+                 model.bert_model.save_pretrained(model_save_path)
+                 logging.info(f"Save to {model_save_path}")
+
+         batch = train_queries[start_idx:start_idx+batch_size]
+         queries = [b['query'] for b in batch]
+
+         # First the positives
+         passages = [corpus[random.choice(b['pos'])] for b in batch]
+
+         # Then the negatives
+         for b in batch:
+             for pid in random.sample(b['neg'], k=min(len(b['neg']), num_negatives)):
+                 passages.append(corpus[pid])
+
+         label = torch.tensor(list(range(len(batch))), device=device)
+
+         ## FP16
+         with autocast():
+             final_scores = model(queries, passages)
+             final_scores = 5*final_scores
+             loss_value = loss_fct(final_scores, label) / grad_acc_steps
+
+         scaler.scale(loss_value).backward()
+         if (idx + 1) % grad_acc_steps == 0:
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+             scaler.step(optimizer)
+             scaler.update()
+             model.zero_grad()
+             scheduler.step()
+
+
+         """
+         # Normal FP32 with grad acc (kept for reference)
+         final_scores = model(queries, passages)
+         # Compute loss
+         loss_value = loss_fct(final_scores, label)
+         if grad_acc_steps > 1:
+             loss_value /= grad_acc_steps
+         loss_value.backward()
+
+         if (idx+1) % grad_acc_steps == 0:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+             optimizer.step()
+             model.zero_grad()
+             scheduler.step()
+         """
+
+
+ logging.info("Final eval:")
+ evaluate_msmarco()
+
+ # Script was called via:
+ # python train_sparta_msmarco.py distilbert-base-uncased  (no weight decay, 5* score scaling)
vocab.txt ADDED
(diff too large to render)