davidheineman commited on
Commit
8e2f8d0
1 Parent(s): 6d2b619

add knn init

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. knn_db_init.py +81 -0
.gitignore CHANGED
@@ -1,4 +1,6 @@
1
  __pycache__
2
  experiments
3
  .openai-secret
4
- .mongodb-secret
 
 
 
1
  __pycache__
2
  experiments
3
  .openai-secret
4
+ .mongodb-secret
5
+ demo.mov
6
+ .DS_Store
knn_db_init.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ from openai import OpenAI
4
+
5
+ with open('.openai-secret', 'r') as f: OPENAI_API_KEY = f.read()
6
+
7
+
8
+ def read_dataset():
9
+ print("Reading dataset")
10
+ json_file = open('updateddata.json')
11
+ data = json.load(json_file)
12
+ return data
13
+
14
+
15
+ def main():
16
+ """
17
+ The purpse of this script is to take the articles json and add a field of the vector embeddings.
18
+ There is a hardcoded sleep because the openAI API is rate limited to 1M tokens / min.
19
+ In total, takes about an hour to run.
20
+ """
21
+
22
+ data = read_dataset()
23
+ print(len(data))
24
+
25
+ mini = 0
26
+ max = 750
27
+
28
+ client = OpenAI(api_key=OPENAI_API_KEY)
29
+
30
+ newjson = []
31
+
32
+ while max <= len(data):
33
+ print("------------")
34
+ print("startind", mini, "endind", max)
35
+
36
+ paper_subset = data[mini:max]
37
+ abstract_list = []
38
+ for paper in paper_subset:
39
+ abstract = paper['abstract'][0:2048]
40
+ abstract = json.dumps(abstract, ensure_ascii=True)
41
+ abstract_list.append(abstract)
42
+
43
+ totallen = 0
44
+ for thinig in abstract_list:
45
+ totallen+= len(thinig[0:2048])
46
+ print("numtokens:", totallen)
47
+
48
+ abstract_list = [x.replace("\n"," ") for x in abstract_list]
49
+ abstract_list = [x.strip() for x in abstract_list]
50
+
51
+ res = client.embeddings.create(
52
+ model="text-embedding-3-small",
53
+ input=abstract_list,
54
+ encoding_format="float"
55
+ )
56
+
57
+ resdata = res.dict()['data']
58
+
59
+ print("Successful API call")
60
+
61
+ for i in range(len(paper_subset)):
62
+ paper_subset[i]['embed'] = resdata[i]['embedding']
63
+
64
+ print("Added embeds")
65
+
66
+ newjson.append(paper_subset)
67
+
68
+ if (max == len(data)): break
69
+
70
+ time.sleep(61)
71
+
72
+ mini += 750
73
+ max += 750
74
+ max = min(max, len(data))
75
+
76
+ with open("datawithembeds.json", 'w') as f:
77
+ json.dump(newjson, f)
78
+
79
+
80
+ if __name__ == '__main__':
81
+ main()