RubenAMtz commited on
Commit
231542c
1 Parent(s): c77b67b

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -2,27 +2,20 @@
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -30,5 +23,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
31
  *.xz filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
7
  *.h5 filter=lfs diff=lfs merge=lfs -text
8
  *.joblib filter=lfs diff=lfs merge=lfs -text
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
12
  *.onnx filter=lfs diff=lfs merge=lfs -text
13
  *.ot filter=lfs diff=lfs merge=lfs -text
14
  *.parquet filter=lfs diff=lfs merge=lfs -text
15
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
16
  *.pt filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
18
  *.rar filter=lfs diff=lfs merge=lfs -text
 
19
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
  *.tar.* filter=lfs diff=lfs merge=lfs -text
21
  *.tflite filter=lfs diff=lfs merge=lfs -text
 
23
  *.wasm filter=lfs diff=lfs merge=lfs -text
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - ar
6
+ - zh
7
+ - nl
8
+ - fr
9
+ - de
10
+ - hi
11
+ - in
12
+ - it
13
+ - ja
14
+ - pt
15
+ - ru
16
+ - es
17
+ - vi
18
+ - multilingual
19
+ datasets:
20
+ - unicamp-dl/mmarco
21
+ ---
22
+ # Cross-Encoder for multilingual MS Marco
23
+
24
+ This model was trained on the [MMARCO](https://hf.co/unicamp-dl/mmarco) dataset. It is a machine translated version of MS MARCO using Google Translate. It was translated to 14 languages. In our experiments, we observed that it performs also well for other languages.
25
+
26
+ As a base model, we used the [multilingual MiniLMv2](https://huggingface.co/nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large) model.
27
+
28
+ The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
29
+
30
+ ## Usage with SentenceTransformers
31
+
32
+ The usage becomes easy when you have [SentenceTransformers](https://www.sbert.net/) installed. Then, you can use the pre-trained models like this:
33
+ ```python
34
+ from sentence_transformers import CrossEncoder
35
+ model = CrossEncoder('model_name')
36
+ scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
37
+ ```
38
+
39
+
40
+
41
+
42
+ ## Usage with Transformers
43
+
44
+ ```python
45
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
46
+ import torch
47
+
48
+ model = AutoModelForSequenceClassification.from_pretrained('model_name')
49
+ tokenizer = AutoTokenizer.from_pretrained('model_name')
50
+
51
+ features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
52
+
53
+ model.eval()
54
+ with torch.no_grad():
55
+ scores = model(**features).logits
56
+ print(scores)
57
+ ```
58
+
59
+
60
+
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large",
3
+ "architectures": [
4
+ "XLMRobertaForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 384,
13
+ "id2label": {
14
+ "0": "LABEL_0"
15
+ },
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 1536,
18
+ "label2id": {
19
+ "LABEL_0": 0
20
+ },
21
+ "layer_norm_eps": 1e-05,
22
+ "max_position_embeddings": 514,
23
+ "model_type": "xlm-roberta",
24
+ "num_attention_heads": 12,
25
+ "num_hidden_layers": 12,
26
+ "pad_token_id": 1,
27
+ "position_embedding_type": "absolute",
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.18.0",
30
+ "type_vocab_size": 1,
31
+ "use_cache": true,
32
+ "vocab_size": 250002,
33
+ "sbert_ce_default_activation_function": "torch.nn.modules.linear.Identity"
34
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1abc209e54d70bbcb08c1b5111a924fb99c0428f51cab1659310ccdcab69dc03
3
+ size 470633197
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a54190d2b9256881ed34ab5428786629f929dd5a579350a6ef4735b86a9208
3
+ size 132
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62c24cdc13d4c9952d63718d6c9fa4c287974249e16b7ade6d5a85e7bbb75626
3
+ size 17082660
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large", "tokenizer_class": "XLMRobertaTokenizer"}
train_script.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from codecs import EncodedFile
2
+ from datetime import datetime
3
+ from typing import Optional
4
+
5
+ import datasets
6
+ import torch
7
+ from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
8
+ from torch.utils.data import DataLoader
9
+ from transformers import (
10
+ AutoConfig,
11
+ AutoModelForSequenceClassification,
12
+ AutoTokenizer,
13
+ get_linear_schedule_with_warmup,
14
+ get_scheduler,
15
+ )
16
+ import torch
17
+ import sys
18
+ import os
19
+ from argparse import ArgumentParser
20
+ from datasets import load_dataset
21
+ import tqdm
22
+ import json
23
+ import gzip
24
+ import random
25
+ from pytorch_lightning.callbacks import ModelCheckpoint
26
+ import numpy as np
27
+ from shutil import copyfile
28
+ from pytorch_lightning.loggers import WandbLogger
29
+ import transformers
30
+
31
+
32
+ class MSMARCOData(LightningDataModule):
33
+ def __init__(
34
+ self,
35
+ model_name: str,
36
+ triplets_path: str,
37
+ langs,
38
+ max_seq_length: int = 250,
39
+ train_batch_size: int = 32,
40
+ eval_batch_size: int = 32,
41
+ num_negs: int = 3,
42
+ cross_lingual_chance: float = 0.0,
43
+ **kwargs,
44
+ ):
45
+ super().__init__()
46
+ self.model_name = model_name
47
+ self.triplets_path = triplets_path
48
+ self.max_seq_length = max_seq_length
49
+ self.train_batch_size = train_batch_size
50
+ self.eval_batch_size = eval_batch_size
51
+ self.langs = langs
52
+ self.num_negs = num_negs
53
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
54
+ self.cross_lingual_chance = cross_lingual_chance #Probability for cross-lingual batches
55
+
56
+ #def setup(self, stage: str):
57
+ print(f"!!!!!!!!!!!!!!!!!! SETUP {os.getpid()} !!!!!!!!!!!!!!!")
58
+
59
+ #Get the queries
60
+ self.queries = {lang: {} for lang in self.langs}
61
+
62
+ for lang in self.langs:
63
+ for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{lang}')['train'], desc=lang):
64
+ self.queries[lang][row['id']] = row['text']
65
+
66
+ #Get the passages
67
+ self.collections = {lang: load_dataset('unicamp-dl/mmarco', f'collection-{lang}')['collection'] for lang in self.langs}
68
+
69
+ #Get the triplets
70
+ with gzip.open(self.triplets_path, 'rt') as fIn:
71
+ self.triplets = [json.loads(line) for line in tqdm.tqdm(fIn, desc="triplets", total=502938)]
72
+ """
73
+ self.triplets = []
74
+ for line in tqdm.tqdm(fIn):
75
+ self.triplets.append(json.loads(line))
76
+ if len(self.triplets) >= 1000:
77
+ break
78
+ """
79
+
80
+ def collate_fn(self, batch):
81
+ cross_lingual_batch = random.random() < self.cross_lingual_chance
82
+
83
+ #Create data for list-rank-loss
84
+ query_doc_pairs = [[] for _ in range(1+self.num_negs)]
85
+
86
+ for row in batch:
87
+ qid = row['qid']
88
+ pos_id = random.choice(row['pos'])
89
+
90
+ query_lang = random.choice(self.langs)
91
+ query_text = self.queries[query_lang][qid]
92
+
93
+ doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
94
+ query_doc_pairs[0].append((query_text, self.collections[doc_lang][pos_id]['text']))
95
+
96
+ dense_bm25_neg = list(set(row['dense_neg'] + row['bm25_neg']))
97
+ neg_ids = random.sample(dense_bm25_neg, self.num_negs)
98
+
99
+ for num_neg, neg_id in enumerate(neg_ids):
100
+ doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
101
+ query_doc_pairs[1+num_neg].append((query_text, self.collections[doc_lang][neg_id]['text']))
102
+
103
+ #Now tokenize the data
104
+ features = [self.tokenizer(qd_pair, max_length=self.max_seq_length, padding=True, truncation='only_second', return_tensors="pt") for qd_pair in query_doc_pairs]
105
+
106
+ return features
107
+
108
+ def train_dataloader(self):
109
+ return DataLoader(self.triplets, shuffle=True, batch_size=self.train_batch_size, num_workers=1, pin_memory=True, collate_fn=self.collate_fn)
110
+
111
+
112
+
113
+
114
+
115
+ class ListRankLoss(LightningModule):
116
+ def __init__(
117
+ self,
118
+ model_name: str,
119
+ learning_rate: float = 2e-5,
120
+ warmup_steps: int = 1000,
121
+ weight_decay: float = 0.01,
122
+ train_batch_size: int = 32,
123
+ eval_batch_size: int = 32,
124
+ **kwargs,
125
+ ):
126
+ super().__init__()
127
+
128
+ self.save_hyperparameters()
129
+ print(self.hparams)
130
+
131
+ self.config = AutoConfig.from_pretrained(model_name, num_labels=1)
132
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
133
+ self.loss_fct = torch.nn.CrossEntropyLoss()
134
+ self.global_train_step = 0
135
+
136
+
137
+ def forward(self, **inputs):
138
+ return self.model(**inputs)
139
+
140
+ def training_step(self, batch, batch_idx):
141
+ pred_scores = []
142
+ scores = torch.tensor([0] * len(batch[0]['input_ids']), device=self.model.device)
143
+
144
+ for feature in batch:
145
+ pred_scores.append(self(**feature).logits.squeeze())
146
+
147
+ pred_scores = torch.stack(pred_scores, 1)
148
+ loss_value = self.loss_fct(pred_scores, scores)
149
+ self.global_train_step += 1
150
+ self.log('global_train_step', self.global_train_step)
151
+ self.log("train/loss", loss_value)
152
+
153
+ return loss_value
154
+
155
+
156
+ def setup(self, stage=None) -> None:
157
+ if stage != "fit":
158
+ return
159
+ # Get dataloader by calling it - train_dataloader() is called after setup() by default
160
+ train_loader = self.trainer.datamodule.train_dataloader()
161
+
162
+ # Calculate total steps
163
+ tb_size = self.hparams.train_batch_size * max(1, self.trainer.gpus)
164
+ ab_size = self.trainer.accumulate_grad_batches
165
+ self.total_steps = (len(train_loader) // ab_size) * self.trainer.max_epochs
166
+
167
+ print(f"{tb_size=}")
168
+ print(f"{ab_size=}")
169
+ print(f"{len(train_loader)=}")
170
+ print(f"{self.total_steps=}")
171
+
172
+
173
+ def configure_optimizers(self):
174
+ """Prepare optimizer and schedule (linear warmup and decay)"""
175
+ model = self.model
176
+ no_decay = ["bias", "LayerNorm.weight"]
177
+ optimizer_grouped_parameters = [
178
+ {
179
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
180
+ "weight_decay": self.hparams.weight_decay,
181
+ },
182
+ {
183
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
184
+ "weight_decay": 0.0,
185
+ },
186
+ ]
187
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate)
188
+
189
+ lr_scheduler = get_scheduler(
190
+ name="linear",
191
+ optimizer=optimizer,
192
+ num_warmup_steps=self.hparams.warmup_steps,
193
+ num_training_steps=self.total_steps,
194
+ )
195
+
196
+ scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
197
+ return [optimizer], [scheduler]
198
+
199
+
200
+
201
+ def main(args):
202
+ dm = MSMARCOData(
203
+ model_name=args.model,
204
+ langs=args.langs,
205
+ triplets_path='data/msmarco-hard-triplets.jsonl.gz',
206
+ train_batch_size=args.batch_size,
207
+ cross_lingual_chance=args.cross_lingual_chance,
208
+ num_negs=args.num_negs
209
+ )
210
+ output_dir = f"output/{args.model.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
211
+ print("Output_dir:", output_dir)
212
+
213
+ os.makedirs(output_dir, exist_ok=True)
214
+
215
+ wandb_logger = WandbLogger(project="multilingual-cross-encoder", name=output_dir.split("/")[-1])
216
+
217
+ train_script_path = os.path.join(output_dir, 'train_script.py')
218
+ copyfile(__file__, train_script_path)
219
+ with open(train_script_path, 'a') as fOut:
220
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
221
+
222
+
223
+ # saves top-K checkpoints based on "val_loss" metric
224
+ checkpoint_callback = ModelCheckpoint(
225
+ every_n_train_steps=25000,
226
+ save_top_k=5,
227
+ monitor="global_train_step",
228
+ mode="max",
229
+ dirpath=output_dir,
230
+ filename="ckpt-{global_train_step}",
231
+ )
232
+
233
+
234
+ model = ListRankLoss(model_name=args.model)
235
+
236
+ trainer = Trainer(max_epochs=args.epochs,
237
+ accelerator="gpu",
238
+ devices=args.num_gpus,
239
+ precision=args.precision,
240
+ strategy=args.strategy,
241
+ default_root_dir=output_dir,
242
+ callbacks=[checkpoint_callback],
243
+ logger=wandb_logger
244
+ )
245
+
246
+ trainer.fit(model, datamodule=dm)
247
+
248
+ #Save final HF model
249
+ final_path = os.path.join(output_dir, "final")
250
+ dm.tokenizer.save_pretrained(final_path)
251
+ model.model.save_pretrained(final_path)
252
+
253
+
254
+ def eval(args):
255
+ import ir_datasets
256
+
257
+
258
+ model = ListRankLoss.load_from_checkpoint(args.ckpt)
259
+ hf_model = model.model.cuda()
260
+ tokenizer = AutoTokenizer.from_pretrained(model.hparams.model_name)
261
+
262
+ dev_qids = set()
263
+
264
+ dev_queries = {}
265
+ dev_rel_docs = {}
266
+ needed_pids = set()
267
+ needed_qids = set()
268
+
269
+ corpus = {}
270
+ retrieved_docs = {}
271
+
272
+ dataset = ir_datasets.load("msmarco-passage/dev/small")
273
+ for query in dataset.queries_iter():
274
+ dev_qids.add(query.query_id)
275
+
276
+
277
+ with open('data/qrels.dev.tsv') as fIn:
278
+ for line in fIn:
279
+ qid, _, pid, _ = line.strip().split('\t')
280
+
281
+ if qid not in dev_qids:
282
+ continue
283
+
284
+ if qid not in dev_rel_docs:
285
+ dev_rel_docs[qid] = set()
286
+ dev_rel_docs[qid].add(pid)
287
+
288
+ retrieved_docs[qid] = set()
289
+ needed_qids.add(qid)
290
+ needed_pids.add(pid)
291
+
292
+ for query in dataset.queries_iter():
293
+ qid = query.query_id
294
+ if qid in needed_qids:
295
+ dev_queries[qid] = query.text
296
+
297
+ with open('data/top1000.dev', 'rt') as fIn:
298
+ for line in fIn:
299
+ qid, pid, query, passage = line.strip().split("\t")
300
+ corpus[pid] = passage
301
+ retrieved_docs[qid].add(pid)
302
+
303
+
304
+ ## Run evaluator
305
+ print("Queries: {}".format(len(dev_queries)))
306
+
307
+ mrr_scores = []
308
+ hf_model.eval()
309
+
310
+ with torch.no_grad():
311
+ for qid in tqdm.tqdm(dev_queries, total=len(dev_queries)):
312
+ query = dev_queries[qid]
313
+ top_pids = list(retrieved_docs[qid])
314
+ cross_inp = [[query, corpus[pid]] for pid in top_pids]
315
+
316
+ encoded = tokenizer(cross_inp, padding=True, truncation=True, return_tensors="pt").to('cuda')
317
+ output = model(**encoded)
318
+ bert_score = output.logits.detach().cpu().numpy()
319
+ bert_score = np.squeeze(bert_score)
320
+
321
+ argsort = np.argsort(-bert_score)
322
+
323
+ rank_score = 0
324
+ for rank, idx in enumerate(argsort[0:10]):
325
+ pid = top_pids[idx]
326
+ if pid in dev_rel_docs[qid]:
327
+ rank_score = 1/(rank+1)
328
+ break
329
+
330
+ mrr_scores.append(rank_score)
331
+
332
+ if len(mrr_scores) % 10 == 0:
333
+ print("{} MRR@10: {:.2f}".format(len(mrr_scores), 100*np.mean(mrr_scores)))
334
+
335
+ print("MRR@10: {:.2f}".format(np.mean(mrr_scores)*100))
336
+
337
+
338
+ if __name__ == '__main__':
339
+ parser = ArgumentParser()
340
+ parser.add_argument("--num_gpus", type=int, default=1)
341
+ parser.add_argument("--batch_size", type=int, default=32)
342
+ parser.add_argument("--epochs", type=int, default=10)
343
+ parser.add_argument("--strategy", default=None)
344
+ parser.add_argument("--model", default='microsoft/mdeberta-v3-base')
345
+ parser.add_argument("--eval", action="store_true")
346
+ parser.add_argument("--ckpt")
347
+ parser.add_argument("--cross_lingual_chance", type=float, default=0.33)
348
+ parser.add_argument("--precision", type=int, default=16)
349
+ parser.add_argument("--num_negs", type=int, default=3)
350
+ parser.add_argument("--langs", nargs="+", default=['english', 'chinese', 'french', 'german', 'indonesian', 'italian', 'portuguese', 'russian', 'spanish', 'arabic', 'dutch', 'hindi', 'japanese', 'vietnamese'])
351
+
352
+
353
+ args = parser.parse_args()
354
+
355
+ if args.eval:
356
+ eval(args)
357
+ else:
358
+ main(args)
359
+
360
+
361
+ # Script was called via:
362
+ #python cross_mutlilingual.py --model nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large