xiaowenbin committed on
Commit
cd2135c
1 Parent(s): e02795d

Upload mteb_eval_openai.py

Browse files
Files changed (1) hide show
  1. mteb_eval_openai.py +94 -0
mteb_eval_openai.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import hashlib
5
+ import numpy as np
6
+ import requests
7
+
8
+
9
# Runtime configuration, read once from the environment at import time.
#
# OPENAI_BASE_URL: optional override for the API host; when non-empty it is
#   used in place of the ``base_url`` argument of ``request_openai_emb``.
#   A trailing slash is dropped so that joining it with a path that starts
#   with "/" (e.g. "/v1/embeddings") does not produce a double slash.
# OPENAI_API_KEY: bearer token sent in the Authorization header.
# EMB_CACHE_DIR: directory holding one cached embedding file per text.
OPENAI_BASE_URL = os.environ.get('OPENAI_BASE_URL', '').rstrip('/')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
EMB_CACHE_DIR = os.environ.get('EMB_CACHE_DIR', '.cache/embs')
# The cache directory must exist before any embedding is read or written.
os.makedirs(EMB_CACHE_DIR, exist_ok=True)
13
+
14
+
15
def uuid_for_text(text):
    """Return the hex MD5 digest of ``text`` encoded as UTF-8.

    The digest is used as a stable identifier for the text (it is the
    file name of the text's cached embedding), not for any security
    purpose.
    """
    return hashlib.md5(text.encode('utf8')).hexdigest()
17
+
18
def request_openai_emb(texts, model="text-embedding-3-large",
                       base_url='https://api.openai.com', prefix_url='/v1/embeddings',
                       timeout=4, retry=3, interval=2, caching=True):
    """Fetch embeddings for ``texts`` from an OpenAI-style embeddings endpoint.

    Args:
        texts: A single string or a sequence of at most 256 strings.
        model: Model name sent in the request payload.
        base_url: API host, used only when the ``OPENAI_BASE_URL`` module
            constant is empty.
        prefix_url: Path appended to the host.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        retry: Maximum number of request attempts.
        interval: Seconds to sleep between failed attempts.
        caching: When true, embeddings are read from and written to
            ``EMB_CACHE_DIR``, one text file per input text.

    Returns:
        A list of 1-D ``numpy`` arrays aligned with ``texts``, or an empty
        list when the request failed on every attempt.

    Raises:
        ValueError: If more than 256 texts are passed.
    """
    if isinstance(texts, str):
        texts = [texts]
    if len(texts) > 256:
        raise ValueError(f"at most 256 texts per request, got {len(texts)}")
    if not texts:
        # Nothing to embed; avoid sending an empty request.
        return []

    # One slot per input text so cached and freshly fetched embeddings can be
    # merged without losing alignment with ``texts``.
    embs = [None] * len(texts)
    cache_files = []
    if caching:
        # NOTE(review): the cache key depends on the text only, not on
        # ``model``; use a separate EMB_CACHE_DIR per model to avoid mixing.
        cache_files = [f"{EMB_CACHE_DIR}/{uuid_for_text(text)}" for text in texts]
        for i, emb_file in enumerate(cache_files):
            if os.path.isfile(emb_file) and os.path.getsize(emb_file) > 0:
                embs[i] = np.loadtxt(emb_file)

    # Only texts without a cached embedding are sent to the API. Previously a
    # partial cache hit skipped the request entirely and returned [].
    missing = [i for i, emb in enumerate(embs) if emb is None]
    if not missing:
        return embs

    url = f"{OPENAI_BASE_URL}{prefix_url}" if OPENAI_BASE_URL else f"{base_url}{prefix_url}"
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {"input": [texts[i] for i in missing], "model": model}

    fetched = []
    while retry > 0 and not fetched:
        try:
            r = requests.post(url, headers=headers, json=payload,
                              timeout=timeout)
            # Surface HTTP errors with their status instead of a KeyError.
            r.raise_for_status()
            res = r.json()
            # Order by the item's "index" field when present so results line
            # up with the request's input order.
            items = sorted(res["data"], key=lambda x: x.get("index", 0))
            candidate = [np.array(x["embedding"]) for x in items]
            if len(candidate) != len(missing):
                raise ValueError(
                    f"expected {len(missing)} embeddings, got {len(candidate)}")
            # Assign only after the whole response parsed, so a failure
            # midway never leaves a partial result that ends the retry loop.
            fetched = candidate
        except Exception as e:
            print(f"request openai, retry {retry}, error: {e}", file=sys.stderr)
            if retry > 1:
                # No point sleeping after the final attempt.
                time.sleep(interval)
        retry -= 1

    if not fetched:
        return []

    for i, emb in zip(missing, fetched):
        embs[i] = emb
        if caching:
            # Write to a temporary file and rename, so an interrupted run
            # cannot leave a truncated cache entry behind.
            tmp_file = f"{cache_files[i]}.tmp{os.getpid()}"
            np.savetxt(tmp_file, emb)
            os.replace(tmp_file, cache_files[i])

    return embs
62
+
63
+
64
class OpenaiEmbModel:
    """Model wrapper exposing ``encode`` backed by ``request_openai_emb``."""

    def encode(self, sentences, batch_size=32, **kwargs):
        """Embed ``sentences`` in batches and return one embedding per sentence.

        Args:
            sentences: Sequence of strings to embed.
            batch_size: Requested batch size; clamped to the range 1..64.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A list of embeddings in the same order as ``sentences``.

        Raises:
            RuntimeError: If a batch comes back with a different number of
                embeddings than texts (e.g. the request failed on every retry).
        """
        # Clamp to at least 1 so range() below never gets a zero or negative
        # step, and to at most 64 texts per request.
        batch_size = max(1, min(64, batch_size))
        # Materialize so slicing yields plain lists of strings.
        sentences = list(sentences)

        embs = []
        for i in range(0, len(sentences), batch_size):
            batch_texts = sentences[i:i+batch_size]
            batch_embs = request_openai_emb(batch_texts,
                                            caching=True, retry=3, interval=2)
            # An explicit raise (not ``assert``) so the check survives ``-O``;
            # a short batch would silently misalign embeddings and texts.
            if len(batch_texts) != len(batch_embs):
                raise RuntimeError("The batch of texts and embs DONT match!")
            embs.extend(batch_embs)

        return embs
78
+
79
+
80
# MTEB is used below but was never imported; without this import the script
# fails with NameError before any evaluation starts.
from mteb import MTEB

model = OpenaiEmbModel()
# Name used for the results folder. It was referenced but never defined.
# It mirrors the default ``model`` of ``request_openai_emb``, which is the
# model ``OpenaiEmbModel.encode`` actually queries; keep the two in sync.
model_name = "text-embedding-3-large"

######
# test
#####
#embs = model.encode(['全国', '北京'])
#print(embs)

# Task types to evaluate.
task_list = ['Classification', 'Clustering', 'Reranking', 'Retrieval', 'STS', 'PairClassification']
# Languages to evaluate.
task_langs = ["zh", "zh-CN"]

evaluation = MTEB(task_types=task_list, task_langs=task_langs)
evaluation.run(model, output_folder=f"results/zh/{model_name.split('/')[-1]}")