zyznull commited on
Commit
b5e6d9a
1 Parent(s): a0c84dc

Upload eval_mteb.py

Browse files
Files changed (1) hide show
  1. scripts/eval_mteb.py +668 -0
scripts/eval_mteb.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from collections import defaultdict
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import sys
8
+ import queue
9
+ from typing import Dict, List, Optional, Union
10
+
11
+ from tqdm.autonotebook import trange
12
+ import datasets
13
+ import numpy as np
14
+ import torch
15
+ import torch.multiprocessing as mp
16
+ from transformers import AutoModel, AutoTokenizer
17
+ from transformers import AutoModelForCausalLM
18
+ from mteb import MTEB, CrosslingualTask, MultilingualTask
19
+
20
+ TASK_LIST_CLASSIFICATION = [
21
+ "AmazonCounterfactualClassification",
22
+ "AmazonPolarityClassification",
23
+ "AmazonReviewsClassification",
24
+ "Banking77Classification",
25
+ "EmotionClassification",
26
+ "ImdbClassification",
27
+ "MassiveIntentClassification",
28
+ "MassiveScenarioClassification",
29
+ "MTOPDomainClassification",
30
+ "MTOPIntentClassification",
31
+ "ToxicConversationsClassification",
32
+ "TweetSentimentExtractionClassification",
33
+ ]
34
+
35
+ TASK_LIST_CLUSTERING = [
36
+ "ArxivClusteringP2P",
37
+ "ArxivClusteringS2S",
38
+ "BiorxivClusteringP2P",
39
+ "BiorxivClusteringS2S",
40
+ "MedrxivClusteringP2P",
41
+ "MedrxivClusteringS2S",
42
+ "RedditClustering",
43
+ "RedditClusteringP2P",
44
+ "StackExchangeClustering",
45
+ "StackExchangeClusteringP2P",
46
+ "TwentyNewsgroupsClustering",
47
+ ]
48
+
49
+ TASK_LIST_PAIR_CLASSIFICATION = [
50
+ "SprintDuplicateQuestions",
51
+ "TwitterSemEval2015",
52
+ "TwitterURLCorpus",
53
+ ]
54
+
55
+ TASK_LIST_RERANKING = [
56
+ "AskUbuntuDupQuestions",
57
+ "MindSmallReranking",
58
+ "SciDocsRR",
59
+ "StackOverflowDupQuestions",
60
+ ]
61
+
62
+ TASK_LIST_RETRIEVAL = [
63
+ "ArguAna",
64
+ "ClimateFEVER",
65
+ "CQADupstackAndroidRetrieval",
66
+ "CQADupstackEnglishRetrieval",
67
+ "CQADupstackGamingRetrieval",
68
+ "CQADupstackGisRetrieval",
69
+ "CQADupstackMathematicaRetrieval",
70
+ "CQADupstackPhysicsRetrieval",
71
+ "CQADupstackProgrammersRetrieval",
72
+ "CQADupstackStatsRetrieval",
73
+ "CQADupstackTexRetrieval",
74
+ "CQADupstackUnixRetrieval",
75
+ "CQADupstackWebmastersRetrieval",
76
+ "CQADupstackWordpressRetrieval",
77
+ "DBPedia",
78
+ "FEVER",
79
+ "FiQA2018",
80
+ "HotpotQA",
81
+ "MSMARCO",
82
+ "NFCorpus",
83
+ "NQ",
84
+ "QuoraRetrieval",
85
+ "SCIDOCS",
86
+ "SciFact",
87
+ "Touche2020",
88
+ "TRECCOVID",
89
+ ]
90
+
91
+ TASK_LIST_STS = [
92
+ "BIOSSES",
93
+ "SICK-R",
94
+ "STS12",
95
+ "STS13",
96
+ "STS14",
97
+ "STS15",
98
+ "STS16",
99
+ "STS17",
100
+ "STS22",
101
+ "STSBenchmark",
102
+ "SummEval",
103
+ ]
104
+
105
+ MTEB_TASK_LIST = (
106
+ TASK_LIST_CLASSIFICATION
107
+ + TASK_LIST_CLUSTERING
108
+ + TASK_LIST_PAIR_CLASSIFICATION
109
+ + TASK_LIST_RERANKING
110
+ + TASK_LIST_RETRIEVAL
111
+ + TASK_LIST_STS
112
+ )
113
+
114
+
115
+ CMTEB_TASK_LIST = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai','AmazonReviewsClassification', 'MassiveIntentClassification', 'MassiveScenarioClassification', 'MultilingualSentiment',
116
+ 'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
117
+ 'Ocnli', 'Cmnli',
118
+ 'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
119
+ 'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
120
+ 'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
121
+
122
+
123
+
124
+ logging.basicConfig(
125
+ level=logging.INFO,
126
+ format='%(asctime)s - %(levelname)s - %(name)s : %(message)s'
127
+ )
128
+
129
+ logger = logging.getLogger('eval_mteb_qwen.py')
130
+
131
+ def get_detailed_instruct(task_description: str) -> str:
132
+ if not task_description:
133
+ return ''
134
+
135
+ return 'Instruct: {}\nQuery: '.format(task_description)
136
+
137
+ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_instruct='Given a web search query, retrieve relevant passages that answer the query') -> str:
138
+ if task_type in ['STS']:
139
+ # return "Given a premise, retrieve a hypothesis that is entailed by the premise."
140
+ return "Retrieve semantically similar text"
141
+
142
+ if task_type in ['Summarization']:
143
+ return "Given a news summary, retrieve other semantically similar summaries"
144
+
145
+ if task_type in ['BitextMining']:
146
+ return "Retrieve parallel sentences"
147
+
148
+ if task_type in ['Classification']:
149
+ task_name_to_instruct: Dict[str, str] = {
150
+ 'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual',
151
+ 'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment',
152
+ 'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category',
153
+ 'Banking77Classification': 'Given a online banking query, find the corresponding intents',
154
+ 'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise',
155
+ 'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset',
156
+ 'MassiveIntentClassification': 'Given a user utterance as query, find the user intents',
157
+ 'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios',
158
+ 'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation',
159
+ 'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation',
160
+ 'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic',
161
+ 'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral',
162
+ # C-MTEB eval instructions
163
+ 'TNews': 'Classify the fine-grained category of the given news title',
164
+ 'IFlyTek': 'Given an App description text, find the appropriate fine-grained category',
165
+ 'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative',
166
+ 'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative',
167
+ 'OnlineShopping': 'Classify the customer review for online shopping into positive or negative',
168
+ 'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative',
169
+ }
170
+ return task_name_to_instruct[task_name]
171
+
172
+ if task_type in ['Clustering']:
173
+ task_name_to_instruct: Dict[str, str] = {
174
+ 'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts',
175
+ 'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles',
176
+ 'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts',
177
+ 'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles',
178
+ 'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts',
179
+ 'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles',
180
+ 'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles',
181
+ 'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts',
182
+ 'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles',
183
+ 'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs',
184
+ 'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles',
185
+ # C-MTEB eval instructions
186
+ 'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles',
187
+ 'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts',
188
+ 'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles',
189
+ 'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents',
190
+ }
191
+ return task_name_to_instruct[task_name]
192
+
193
+ if task_type in ['Reranking', 'PairClassification']:
194
+ task_name_to_instruct: Dict[str, str] = {
195
+ 'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum',
196
+ 'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history',
197
+ 'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers',
198
+ 'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum',
199
+ 'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum',
200
+ 'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet',
201
+ 'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet',
202
+ # C-MTEB eval instructions
203
+ 'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question',
204
+ 'MmarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question',
205
+ 'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question',
206
+ 'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question',
207
+ 'Ocnli': 'Retrieve semantically similar text.',
208
+ 'Cmnli': 'Retrieve semantically similar text.',
209
+ }
210
+ return task_name_to_instruct[task_name]
211
+
212
+ if task_type in ['Retrieval']:
213
+ if task_name.lower().startswith('cqadupstack'):
214
+ return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question'
215
+
216
+ task_name_to_instruct: Dict[str, str] = {
217
+ 'ArguAna': 'Given a claim, find documents that refute the claim',
218
+ 'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim',
219
+ 'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia',
220
+ 'FEVER': 'Given a claim, retrieve documents that support or refute the claim',
221
+ 'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question',
222
+ 'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question',
223
+ 'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query',
224
+ 'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question',
225
+ 'NQ': 'Given a question, retrieve Wikipedia passages that answer the question',
226
+ 'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question',
227
+ 'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
228
+ 'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim',
229
+ 'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question',
230
+ 'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
231
+ # C-MTEB eval instructions
232
+ 'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
233
+ 'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query',
234
+ 'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
235
+ 'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question',
236
+ 'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question',
237
+ 'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products',
238
+ 'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question',
239
+ 'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos',
240
+ }
241
+
242
+ # add lower case keys to match some beir names
243
+ task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
244
+ # other cases where lower case match still doesn't work
245
+ task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
246
+ task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
247
+ task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
248
+ task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
249
+ task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
250
+ task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']
251
+
252
+ # for miracl evaluation
253
+ task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question'
254
+
255
+ return task_name_to_instruct[task_name]
256
+ logging.warning(f"No instruction config for task {task_name} with type {task_type}, use default instruction.")
257
+ return default_instruct
258
+
259
+ class Encoder(torch.nn.Module):
260
+ def __init__(self, name_or_path:str, pooling: str):
261
+ super().__init__()
262
+ self.model = AutoModel.from_pretrained(name_or_path, trust_remote_code=True)
263
+ self.model = self.model.half()
264
+ self.model.eval()
265
+ self.pooling = pooling
266
+
267
+ def forward(self, **features) -> torch.Tensor:
268
+ output = self.model(**features, output_hidden_states=True, return_dict=True)
269
+ hidden_state = output.hidden_states[-1]
270
+ embeddings = self.pooler(hidden_state, **features)
271
+ return embeddings
272
+
273
+ def pooler(
274
+ self,
275
+ hidden_state: torch.Tensor,
276
+ attention_mask: torch.Tensor,
277
+ **kwargs
278
+ ) -> torch.Tensor:
279
+ if attention_mask.ndim == 2:
280
+ mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size())
281
+ elif attention_mask.ndim == 3:
282
+ mask_expanded = attention_mask
283
+ else:
284
+ raise RuntimeError(f"Unexpected {attention_mask.ndim=}")
285
+
286
+ hidden_state = hidden_state * mask_expanded
287
+
288
+ if self.pooling == 'first':
289
+ pooled_output = hidden_state[:, 0]
290
+
291
+ elif self.pooling == 'last':
292
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
293
+ if left_padding:
294
+ return hidden_state[:, -1]
295
+ else:
296
+ sequence_lengths = attention_mask.sum(dim=1) - 1
297
+ batch_size = hidden_state.shape[0]
298
+ return hidden_state[torch.arange(batch_size, device=hidden_state.device), sequence_lengths]
299
+ elif self.pooling == 'mean':
300
+ # TODO: weight
301
+ lengths = mask_expanded.sum(1).clamp(min=1e-9)
302
+ pooled_output = hidden_state.sum(dim=1) / lengths
303
+
304
+ elif self.pooling == 'weightedmean':
305
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
306
+ # hidden_state shape: bs, seq, hidden_dim
307
+ weights = (
308
+ torch.arange(start=1, end=hidden_state.shape[1] + 1)
309
+ .unsqueeze(0)
310
+ .unsqueeze(-1)
311
+ .expand(hidden_state.size())
312
+ .float().to(hidden_state.device)
313
+ )
314
+ assert weights.shape == hidden_state.shape == input_mask_expanded.shape
315
+ input_mask_expanded = input_mask_expanded * weights
316
+
317
+ sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
318
+ sum_mask = input_mask_expanded.sum(1)
319
+ sum_mask = torch.clamp(sum_mask, min=1e-9)
320
+ pooled_output = sum_embeddings / sum_mask
321
+
322
+ else:
323
+ raise ValueError(f"Wrong pooler mode : {self.pooling}")
324
+ return pooled_output
325
+
326
+
327
+ class Wrapper:
328
+ def __init__(
329
+ self,
330
+ tokenizer,
331
+ encoder: Encoder,
332
+ batch_size: int,
333
+ max_seq_len: int = 512,
334
+ normalize_embeddings: bool = False,
335
+ default_query: bool = False,
336
+ force_default: bool = False,
337
+ sep: str = " ",
338
+ mp_tensor_to_cuda: bool = False,
339
+ instruction: str = None,
340
+ attn_type: str = None
341
+ ):
342
+ self.tokenizer = tokenizer
343
+ self.model = encoder
344
+ self.batch_size = batch_size
345
+ self.max_seq_len = max_seq_len
346
+ self.pool: dict = None
347
+ self.normalize_embeddings = normalize_embeddings
348
+ self.mp_tensor_to_cuda = mp_tensor_to_cuda
349
+ self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
350
+ self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
351
+ self.instruction = instruction
352
+
353
+ if self.tokenizer.padding_side != 'right':
354
+ logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
355
+ self.tokenizer.padding_side = 'right'
356
+ if self.tokenizer.pad_token is None:
357
+ logger.warning(f"Set tokenizer.pad_token as eos_token {self.tokenizer.eos_token}")
358
+ self.tokenizer.pad_token='<|endoftext|>'
359
+
360
+ def start(self, target_devices: Optional[List[str]] = None):
361
+ """
362
+ Starts multi process to process the encoding with several, independent processes.
363
+ This method is recommended if you want to encode on multiple GPUs. It is advised
364
+ to start only one process per GPU. This method works together with encode_multi_process
365
+
366
+ :param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used
367
+ :return: Returns a dict with the target processes, an input queue and and output queue.
368
+ """
369
+ if target_devices is None:
370
+ if torch.cuda.is_available():
371
+ target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
372
+ else:
373
+ logger.info("CUDA is not available. Start 4 CPU worker")
374
+ target_devices = ['cpu']*4
375
+
376
+ logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
377
+ print('multi instruction', self.instruction)
378
+ ctx = mp.get_context('spawn')
379
+ input_queue = ctx.Queue()
380
+ output_queue = ctx.Queue()
381
+ processes = []
382
+
383
+ for cuda_id in target_devices:
384
+ p = ctx.Process(
385
+ target=self._encode_multi_process_worker,
386
+ args=(cuda_id, self, input_queue, output_queue),
387
+ daemon=True
388
+ )
389
+ p.start()
390
+ processes.append(p)
391
+
392
+ self.pool = {'input': input_queue, 'output': output_queue, 'processes': processes}
393
+
394
+ def stop(self):
395
+ """
396
+ Stops all processes started with start_multi_process_pool
397
+ """
398
+ for p in self.pool['processes']:
399
+ p.terminate()
400
+
401
+ for p in self.pool['processes']:
402
+ p.join()
403
+ p.close()
404
+
405
+ self.pool['input'].close()
406
+ self.pool['output'].close()
407
+
408
+ @staticmethod
409
+ def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
410
+ """
411
+ Internal working process to encode sentences in multi-process setup
412
+ """
413
+ while True:
414
+ try:
415
+ id, sentences, kwargs = input_queue.get()
416
+ kwargs.update(device=target_device, show_progress_bar=False, convert_to_numpy=True)
417
+ embeddings = model._encode(sentences, **kwargs)
418
+ results_queue.put([id, embeddings])
419
+ except queue.Empty:
420
+ break
421
+
422
+ def encode_multi_process(
423
+ self,
424
+ sentences: List[str],
425
+ **kwargs
426
+ ):
427
+ """
428
+ This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
429
+ and sent to individual processes, which encode these on the different GPUs. This method is only suitable
430
+ for encoding large sets of sentences
431
+
432
+ :param sentences: List of sentences
433
+ :param pool: A pool of workers started with SentenceTransformer.start_multi_process_pool
434
+ :param chunk_size: Sentences are chunked and sent to the individual processes. If none, it determine a sensible size.
435
+ :param kwargs: other keyword arguments for model.encode() such as batch_size
436
+ :return: Numpy matrix with all embeddings
437
+ """
438
+ part_size = math.ceil(len(sentences) / len(self.pool["processes"]))
439
+ chunk_size = part_size if part_size < 3200 else 3200 # for retrieval chunk 50000
440
+
441
+ logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")
442
+
443
+ input_queue = self.pool['input']
444
+ last_chunk_id = 0
445
+ chunk = []
446
+
447
+ for sentence in sentences:
448
+ chunk.append(sentence)
449
+ if len(chunk) >= chunk_size:
450
+ input_queue.put([last_chunk_id, chunk, kwargs])
451
+ last_chunk_id += 1
452
+ chunk = []
453
+
454
+ if len(chunk) > 0:
455
+ input_queue.put([last_chunk_id, chunk, kwargs])
456
+ last_chunk_id += 1
457
+
458
+ output_queue = self.pool['output']
459
+ results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
460
+ embeddings = np.concatenate([result[1] for result in results_list])
461
+ return embeddings
462
+
463
+ @staticmethod
464
+ def batch_to_device(batch, target_device):
465
+ """
466
+ send a pytorch batch to a device (CPU/GPU)
467
+ """
468
+ for key in batch:
469
+ if isinstance(batch[key], torch.Tensor):
470
+ batch[key] = batch[key].to(target_device)
471
+ return batch
472
+
473
+ def _text_length(self, text: Union[List[int], List[List[int]]]):
474
+ """
475
+ Help function to get the length for the input text. Text can be either
476
+ a list of ints (which means a single text as input), or a tuple of list of ints
477
+ (representing several text inputs to the model).
478
+ """
479
+
480
+ if isinstance(text, dict): #{key: value} case
481
+ return len(next(iter(text.values())))
482
+ elif not hasattr(text, '__len__'): #Object has no len() method
483
+ return 1
484
+ elif len(text) == 0 or isinstance(text[0], int): #Empty string or list of ints
485
+ return len(text)
486
+ else:
487
+ return sum([len(t) for t in text]) #Sum of length of individual strings
488
+
489
+ def _tokenize(self, sentences: List[str], is_query: bool):
490
+
491
+ batch_dict = tokenizer(sentences, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
492
+ batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
493
+ batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
494
+ batch_dict['is_causal'] = False
495
+ return batch_dict
496
+
497
+
498
+ def _encode(
499
+ self,
500
+ sentences: List[str],
501
+ is_query: bool,
502
+ convert_to_numpy: bool = True,
503
+ convert_to_tensor: bool = False,
504
+ device: str = None,
505
+ show_progress_bar: bool = True,
506
+ **kwargs
507
+ ):
508
+ """
509
+ Computes sentence embeddings
510
+
511
+ :param sentences: the sentences to embed
512
+ :param batch_size: the batch size used for the computation
513
+ :param show_progress_bar: Output a progress bar when encode sentences
514
+ :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
515
+ :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
516
+ :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
517
+ :param device: Which torch.device to use for the computation
518
+ :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
519
+
520
+ :return:
521
+ By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
522
+ """
523
+ self.model.eval()
524
+
525
+ if convert_to_tensor:
526
+ convert_to_numpy = False
527
+
528
+ input_was_string = False
529
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'): #Cast an individual sentence to a list with length 1
530
+ sentences = [sentences]
531
+ input_was_string = True
532
+
533
+ if device is None:
534
+ device = self._target_device
535
+
536
+ self.model.to(device)
537
+
538
+ all_embeddings = []
539
+ length_sorted_idx = np.argsort([-self._text_length(s) for s in sentences])
540
+ sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
541
+
542
+ for start_index in trange(0, len(sentences), self.batch_size, desc="Batches", disable=not show_progress_bar):
543
+ sentences_batch = sentences_sorted[start_index:start_index + self.batch_size]
544
+ features = self._tokenize(sentences_batch, is_query)
545
+ features = self.batch_to_device(features, device)
546
+
547
+ with torch.no_grad():
548
+ embeddings = self.model(**features)
549
+
550
+ if self.normalize_embeddings:
551
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
552
+
553
+ # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
554
+ if convert_to_numpy:
555
+ embeddings = embeddings.cpu()
556
+
557
+ all_embeddings.extend(embeddings)
558
+
559
+ all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
560
+
561
+ if convert_to_tensor:
562
+ all_embeddings = torch.stack(all_embeddings)
563
+ elif convert_to_numpy:
564
+ #all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
565
+ all_embeddings = np.asarray([emb.to(torch.float).numpy() for emb in all_embeddings])
566
+ if input_was_string:
567
+ all_embeddings = all_embeddings[0]
568
+
569
+ return all_embeddings
570
+
571
+ def encode(
572
+ self,
573
+ sentences: List[str],
574
+ is_query: Optional[bool] = None,
575
+ convert_to_tensor: bool = False,
576
+ **kwargs
577
+ ):
578
+ is_query = self.default_query if is_query is None else is_query
579
+ if is_query and self.instruction:
580
+ sentences = [self.instruction + sent for sent in sentences]
581
+ kwargs.update(is_query=is_query)
582
+ if self.pool is not None:
583
+ kwargs.update(show_progress_bar=False)
584
+ embeddings = self.encode_multi_process(sentences, **kwargs)
585
+ if convert_to_tensor:
586
+ embeddings = torch.from_numpy(embeddings)
587
+ if self.mp_tensor_to_cuda and torch.cuda.is_available():
588
+ embeddings = embeddings.to(torch.device('cuda')) # default 0-th gpu
589
+ return embeddings
590
+
591
+ return self._encode(sentences, convert_to_tensor=convert_to_tensor, **kwargs)
592
+
593
+ def encode_queries(self, queries: List[str], **kwargs):
594
+ is_query = self.default_query if self.force_default else True
595
+ return self.encode(queries, is_query=is_query, **kwargs)
596
+
597
+ def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
598
+ # borrowed from mteb.abstasks.AbsTaskRetrieval.DRESModel
599
+ if type(corpus) is dict:
600
+ sentences = [
601
+ (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
602
+ if "title" in corpus
603
+ else corpus["text"][i].strip()
604
+ for i in range(len(corpus["text"]))
605
+ ]
606
+ elif isinstance(corpus[0], dict):
607
+ sentences = [
608
+ (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
609
+ for doc in corpus
610
+ ]
611
+ else:
612
+ sentences = corpus
613
+ is_query = self.default_query if self.force_default else False
614
+ return self.encode(sentences, is_query=is_query, **kwargs)
615
+
616
+ def main(args):
617
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
618
+ encoder = Encoder(args.model, args.pooling)
619
+ model = Wrapper(
620
+ tokenizer, encoder,
621
+ batch_size=args.batch_size,
622
+ max_seq_len=args.max_seq_len,
623
+ normalize_embeddings=args.norm
624
+ )
625
+
626
+ if args.task == 'mteb':
627
+ task_names = MTEB_TASK_LIST
628
+ lang = ['en']
629
+ elif args.task == 'cmteb':
630
+ task_names = CMTEB_TASK_LIST
631
+ lang = ['zh','zh-CN']
632
+ else:
633
+ task_names = [args.task]
634
+ lang = ['en','zh','zh-CN']
635
+ for task in task_names:
636
+ evaluation = MTEB(tasks=[task], task_langs=lang)
637
+ task_cls = evaluation.tasks[0]
638
+ task_name: str = task_cls.description['name']
639
+ task_type: str = task_cls.description['type']
640
+ instruction = get_task_def_by_task_name_and_type(task_name, task_type)
641
+ model.instruction = get_detailed_instruct(instruction)
642
+ if task == 'MSMARCO':
643
+ eval_splits = ["dev"]
644
+ elif task in CMTEB_TASK_LIST:
645
+ eval_splits = task_cls.description['eval_splits']
646
+ else:
647
+ eval_splits = ["test"]
648
+
649
+ evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
650
+ print('\n')
651
+
652
+
653
+ if __name__ == "__main__":
654
+ _PARSER = argparse.ArgumentParser()
655
+ _PARSER.add_argument(
656
+ "-m", "--model", type=str, default=None
657
+ )
658
+ _PARSER.add_argument("--pooling", type=str, default='last')
659
+ _PARSER.add_argument("--output_dir", type=str, default=None)
660
+ _PARSER.add_argument("--default_type", type=str, default='query')
661
+ _PARSER.add_argument("--max_seq_len", type=int, default=512)
662
+ _PARSER.add_argument("-b", "--batch_size", type=int, default=32)
663
+ _PARSER.add_argument(
664
+ "-t", "--task", type=str, default=None # None for running default tasks
665
+ )
666
+ _PARSER.add_argument("--norm", action="store_true")
667
+ _ARGS = _PARSER.parse_args()
668
+ main(_ARGS)