nreimers committed on
Commit 4ceabf2
1 Parent(s): b8feffc
config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "_name_or_path": "nreimers/mMiniLMv2-L6-H384-distilled-from-XLMR-Large",
+   "architectures": [
+     "XLMRobertaForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 384,
+   "id2label": {
+     "0": "LABEL_0"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 1536,
+   "label2id": {
+     "LABEL_0": 0
+   },
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "xlm-roberta",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 6,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.18.0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 250002
+ }
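The config above describes the 6-layer, 384-dimensional multilingual MiniLMv2 student (distilled from XLM-R Large) with a single output label, i.e. a cross-encoder head that emits one relevance score per (query, passage) pair. A minimal sketch of loading it from a local clone of this repository; the path "." is a placeholder, not part of the commit:

from transformers import AutoConfig, AutoModelForSequenceClassification

# Load the committed config.json; num_labels resolves to 1 from id2label/label2id.
config = AutoConfig.from_pretrained(".")
print(config.model_type, config.num_hidden_layers, config.hidden_size)  # xlm-roberta 6 384

# The weights from pytorch_model.bin are loaded into this architecture.
model = AutoModelForSequenceClassification.from_pretrained(".", config=config)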
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:795f7056ac9d129f5a7fe4bb77cda35cbefa5e5e98b507cf8ee0b2c6db122796
+ size 428014765
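The file above is a Git LFS pointer rather than the ~428 MB weight file itself; the oid line records the SHA-256 of the real binary. After fetching the weights (for example with git lfs pull), a download can be checked against that digest. A minimal sketch, assuming pytorch_model.bin sits in the working directory:

import hashlib

# Recompute the SHA-256 that the LFS pointer records as "oid sha256:...".
expected = "795f7056ac9d129f5a7fe4bb77cda35cbefa5e5e98b507cf8ee0b2c6db122796"
sha = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha.update(chunk)
print(sha.hexdigest() == expected)  # True if the fetched file matches the pointer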
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
+ size 5069051
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62c24cdc13d4c9952d63718d6c9fa4c287974249e16b7ade6d5a85e7bbb75626
+ size 17082660
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "nreimers/mMiniLMv2-L6-H384-distilled-from-XLMR-Large", "tokenizer_class": "XLMRobertaTokenizer"}
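The tokenizer files above configure an XLMRobertaTokenizer with model_max_length 512 and the usual <s>/</s>/<pad>/<unk>/<mask> special tokens. A minimal sketch of scoring one (query, passage) pair with these artifacts, mirroring how the training script below builds its inputs; the path "." and the example texts are placeholders:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained(".")
model = AutoModelForSequenceClassification.from_pretrained(".")
model.eval()

# truncation='only_second' keeps the query intact and truncates only the passage,
# as in train_script.py below.
features = tokenizer(
    [("How many people live in Berlin?", "Berlin is the capital of Germany.")],
    padding=True, truncation='only_second', max_length=512, return_tensors="pt",
)
with torch.no_grad():
    score = model(**features).logits.squeeze(-1)  # one relevance logit per pair
print(score)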
train_script.py ADDED
@@ -0,0 +1,370 @@
+ from codecs import EncodedFile
+ from datetime import datetime
+ from typing import Optional
+
+ import datasets
+ import torch
+ from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
+ from torch.utils.data import DataLoader
+ from transformers import (
+     AutoConfig,
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     get_linear_schedule_with_warmup,
+     get_scheduler,
+ )
+ import torch
+ import sys
+ import os
+ from argparse import ArgumentParser
+ from datasets import load_dataset
+ import tqdm
+ import json
+ import gzip
+ import random
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ import numpy as np
+ from shutil import copyfile
+ from pytorch_lightning.loggers import WandbLogger
+ import transformers
+
+
+ class MSMARCOData(LightningDataModule):
+     def __init__(
+         self,
+         model_name: str,
+         triplets_path: str,
+         langs,
+         max_seq_length: int = 250,
+         train_batch_size: int = 32,
+         eval_batch_size: int = 32,
+         num_negs: int = 3,
+         cross_lingual_chance: float = 0.0,
+         **kwargs,
+     ):
+         super().__init__()
+         self.model_name = model_name
+         self.triplets_path = triplets_path
+         self.max_seq_length = max_seq_length
+         self.train_batch_size = train_batch_size
+         self.eval_batch_size = eval_batch_size
+         self.langs = langs
+         self.num_negs = num_negs
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+         self.cross_lingual_chance = cross_lingual_chance #Probability for cross-lingual batches
+
+     #def setup(self, stage: str):
+         print(f"!!!!!!!!!!!!!!!!!! SETUP {os.getpid()} !!!!!!!!!!!!!!!")
+
+         #Get the queries
+         self.queries = {lang: {} for lang in self.langs}
+
+         for lang in self.langs:
+             for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{lang}')['train'], desc=lang):
+                 self.queries[lang][row['id']] = row['text']
+
+         #Get the passages
+         self.collections = {lang: load_dataset('unicamp-dl/mmarco', f'collection-{lang}')['collection'] for lang in self.langs}
+
+         #Get the triplets
+         with gzip.open(self.triplets_path, 'rt') as fIn:
+             self.triplets = [json.loads(line) for line in tqdm.tqdm(fIn, desc="triplets", total=502938)]
+             """
+             self.triplets = []
+             for line in tqdm.tqdm(fIn):
+                 self.triplets.append(json.loads(line))
+                 if len(self.triplets) >= 1000:
+                     break
+             """
+
+     def collate_fn(self, batch):
+         cross_lingual_batch = random.random() < self.cross_lingual_chance
+
+         #Create data for list-rank-loss
+         query_doc_pairs = [[] for _ in range(1+self.num_negs)]
+
+         for row in batch:
+             qid = row['qid']
+             pos_id = random.choice(row['pos'])
+
+             query_lang = random.choice(self.langs)
+             query_text = self.queries[query_lang][qid]
+
+             doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
+             query_doc_pairs[0].append((query_text, self.collections[doc_lang][pos_id]['text']))
+
+             dense_bm25_neg = list(set(row['dense_neg'] + row['bm25_neg']))
+             neg_ids = random.sample(dense_bm25_neg, self.num_negs)
+
+             for num_neg, neg_id in enumerate(neg_ids):
+                 doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
+                 query_doc_pairs[1+num_neg].append((query_text, self.collections[doc_lang][neg_id]['text']))
+
+         #Now tokenize the data
+         features = [self.tokenizer(qd_pair, max_length=self.max_seq_length, padding=True, truncation='only_second', return_tensors="pt") for qd_pair in query_doc_pairs]
+
+         return features
+
+     def train_dataloader(self):
+         return DataLoader(self.triplets, shuffle=True, batch_size=self.train_batch_size, num_workers=1, pin_memory=True, collate_fn=self.collate_fn)
+
+
+
+
+
+ class ListRankLoss(LightningModule):
+     def __init__(
+         self,
+         model_name: str,
+         learning_rate: float = 2e-5,
+         warmup_steps: int = 1000,
+         weight_decay: float = 0.01,
+         train_batch_size: int = 32,
+         eval_batch_size: int = 32,
+         **kwargs,
+     ):
+         super().__init__()
+
+         self.save_hyperparameters()
+         print(self.hparams)
+
+         self.config = AutoConfig.from_pretrained(model_name, num_labels=1)
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
+         self.loss_fct = torch.nn.CrossEntropyLoss()
+         self.global_train_step = 0
+
+
+     def forward(self, **inputs):
+         return self.model(**inputs)
+
+     def training_step(self, batch, batch_idx):
+         pred_scores = []
+         scores = torch.tensor([0] * len(batch[0]['input_ids']), device=self.model.device)
+
+         for feature in batch:
+             pred_scores.append(self(**feature).logits.squeeze())
+
+         pred_scores = torch.stack(pred_scores, 1)
+         loss_value = self.loss_fct(pred_scores, scores)
+         self.global_train_step += 1
+         self.log('global_train_step', self.global_train_step)
+         self.log("train/loss", loss_value)
+
+         return loss_value
+
+
+     def setup(self, stage=None) -> None:
+         if stage != "fit":
+             return
+         # Get dataloader by calling it - train_dataloader() is called after setup() by default
+         train_loader = self.trainer.datamodule.train_dataloader()
+
+         # Calculate total steps
+         tb_size = self.hparams.train_batch_size * max(1, self.trainer.gpus)
+         ab_size = self.trainer.accumulate_grad_batches
+         self.total_steps = (len(train_loader) // ab_size) * self.trainer.max_epochs
+
+         print(f"{tb_size=}")
+         print(f"{ab_size=}")
+         print(f"{len(train_loader)=}")
+         print(f"{self.total_steps=}")
+
+
+     def configure_optimizers(self):
+         """Prepare optimizer and schedule (linear warmup and decay)"""
+         model = self.model
+         no_decay = ["bias", "LayerNorm.weight"]
+         optimizer_grouped_parameters = [
+             {
+                 "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                 "weight_decay": self.hparams.weight_decay,
+             },
+             {
+                 "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                 "weight_decay": 0.0,
+             },
+         ]
+         optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate)
+
+         """
+         lr_scheduler = get_scheduler(
+             name="linear",
+             optimizer=optimizer,
+             num_warmup_steps=self.hparams.warmup_steps,
+             num_training_steps=self.total_steps,
+         )
+         """
+         lr_scheduler = get_linear_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=self.hparams.warmup_steps,
+             num_training_steps=self.total_steps,
+         )
+
+         scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
+         return [optimizer], [scheduler]
+
+
+
+ def main(args):
+     dm = MSMARCOData(
+         model_name=args.model,
+         langs=args.langs,
+         triplets_path='data/msmarco-hard-triplets.jsonl.gz',
+         train_batch_size=args.batch_size,
+         cross_lingual_chance=args.cross_lingual_chance,
+         num_negs=args.num_negs
+     )
+     output_dir = f"output/{args.model.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
+     print("Output_dir:", output_dir)
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     wandb_logger = WandbLogger(project="multilingual-cross-encoder", name=output_dir.split("/")[-1])
+
+     train_script_path = os.path.join(output_dir, 'train_script.py')
+     copyfile(__file__, train_script_path)
+     with open(train_script_path, 'a') as fOut:
+         fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+
+     # saves top-K checkpoints based on "val_loss" metric
+     checkpoint_callback = ModelCheckpoint(
+         every_n_train_steps=25000,
+         save_top_k=5,
+         monitor="global_train_step",
+         mode="max",
+         dirpath=output_dir,
+         filename="ckpt-{global_train_step}",
+     )
+
+
+     model = ListRankLoss(model_name=args.model, learning_rate=args.lr)
+
+     trainer = Trainer(max_epochs=args.epochs,
+                       accelerator="gpu",
+                       devices=args.num_gpus,
+                       precision=args.precision,
+                       strategy=args.strategy,
+                       default_root_dir=output_dir,
+                       callbacks=[checkpoint_callback],
+                       logger=wandb_logger
+                       )
+
+     trainer.fit(model, datamodule=dm)
+
+     #Save final HF model
+     final_path = os.path.join(output_dir, "final")
+     dm.tokenizer.save_pretrained(final_path)
+     model.model.save_pretrained(final_path)
+
+
+ def eval(args):
+     import ir_datasets
+
+
+     model = ListRankLoss.load_from_checkpoint(args.ckpt)
+     hf_model = model.model.cuda()
+     tokenizer = AutoTokenizer.from_pretrained(model.hparams.model_name)
+
+     dev_qids = set()
+
+     dev_queries = {}
+     dev_rel_docs = {}
+     needed_pids = set()
+     needed_qids = set()
+
+     corpus = {}
+     retrieved_docs = {}
+
+     dataset = ir_datasets.load("msmarco-passage/dev/small")
+     for query in dataset.queries_iter():
+         dev_qids.add(query.query_id)
+
+
+     with open('data/qrels.dev.tsv') as fIn:
+         for line in fIn:
+             qid, _, pid, _ = line.strip().split('\t')
+
+             if qid not in dev_qids:
+                 continue
+
+             if qid not in dev_rel_docs:
+                 dev_rel_docs[qid] = set()
+             dev_rel_docs[qid].add(pid)
+
+             retrieved_docs[qid] = set()
+             needed_qids.add(qid)
+             needed_pids.add(pid)
+
+     for query in dataset.queries_iter():
+         qid = query.query_id
+         if qid in needed_qids:
+             dev_queries[qid] = query.text
+
+     with open('data/top1000.dev', 'rt') as fIn:
+         for line in fIn:
+             qid, pid, query, passage = line.strip().split("\t")
+             corpus[pid] = passage
+             retrieved_docs[qid].add(pid)
+
+
+     ## Run evaluator
+     print("Queries: {}".format(len(dev_queries)))
+
+     mrr_scores = []
+     hf_model.eval()
+
+     with torch.no_grad():
+         for qid in tqdm.tqdm(dev_queries, total=len(dev_queries)):
+             query = dev_queries[qid]
+             top_pids = list(retrieved_docs[qid])
+             cross_inp = [[query, corpus[pid]] for pid in top_pids]
+
+             encoded = tokenizer(cross_inp, padding=True, truncation=True, return_tensors="pt").to('cuda')
+             output = model(**encoded)
+             bert_score = output.logits.detach().cpu().numpy()
+             bert_score = np.squeeze(bert_score)
+
+             argsort = np.argsort(-bert_score)
+
+             rank_score = 0
+             for rank, idx in enumerate(argsort[0:10]):
+                 pid = top_pids[idx]
+                 if pid in dev_rel_docs[qid]:
+                     rank_score = 1/(rank+1)
+                     break
+
+             mrr_scores.append(rank_score)
+
+             if len(mrr_scores) % 10 == 0:
+                 print("{} MRR@10: {:.2f}".format(len(mrr_scores), 100*np.mean(mrr_scores)))
+
+     print("MRR@10: {:.2f}".format(np.mean(mrr_scores)*100))
+
+
+ if __name__ == '__main__':
+     parser = ArgumentParser()
+     parser.add_argument("--num_gpus", type=int, default=1)
+     parser.add_argument("--batch_size", type=int, default=32)
+     parser.add_argument("--epochs", type=int, default=10)
+     parser.add_argument("--strategy", default=None)
+     parser.add_argument("--model", default='microsoft/mdeberta-v3-base')
+     parser.add_argument("--eval", action="store_true")
+     parser.add_argument("--ckpt")
+     parser.add_argument("--cross_lingual_chance", type=float, default=0.33)
+     parser.add_argument("--precision", default=16)
+     parser.add_argument("--num_negs", type=int, default=3)
+     parser.add_argument("--lr", type=float, default=2e-5)
+     parser.add_argument("--langs", nargs="+", default=['english', 'chinese', 'french', 'german', 'indonesian', 'italian', 'portuguese', 'russian', 'spanish', 'arabic', 'dutch', 'hindi', 'japanese', 'vietnamese'])
+
+
+     args = parser.parse_args()
+
+     if args.eval:
+         eval(args)
+     else:
+         main(args)
+
+
+ # Script was called via:
+ #python cross_mutlilingual.py --model nreimers/mMiniLMv2-L6-H384-distilled-from-XLMR-Large
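For reference, the listwise objective in training_step above reduces to a cross-entropy over the stacked pair scores, with the positive passage always in column 0 and the sampled hard negatives in the remaining columns. A standalone sketch with made-up scores (two queries, one positive plus three negatives each):

import torch
import torch.nn.functional as F

# Rows = queries in the batch; column 0 = positive pair, columns 1..3 = negatives.
pred_scores = torch.tensor([[2.1, 0.3, -0.5, 0.0],
                            [1.4, 1.0, 0.2, -1.2]])
targets = torch.zeros(pred_scores.size(0), dtype=torch.long)  # the "correct class" is index 0

loss = F.cross_entropy(pred_scores, targets)  # softmax over 1 + num_negs candidates per query
print(loss)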