davidheineman
committed on
Commit
•
8e2f8d0
1
Parent(s):
6d2b619
add knn init
Browse files- .gitignore +3 -1
- knn_db_init.py +81 -0
.gitignore
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
__pycache__
|
2 |
experiments
|
3 |
.openai-secret
|
4 |
-
.mongodb-secret
|
|
|
|
|
|
1 |
__pycache__
|
2 |
experiments
|
3 |
.openai-secret
|
4 |
+
.mongodb-secret
|
5 |
+
demo.mov
|
6 |
+
.DS_Store
|
knn_db_init.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
from openai import OpenAI
|
4 |
+
|
5 |
+
# Load the OpenAI API key from the local secret file.
# .strip() drops the trailing newline most editors append to the file;
# a key containing a newline produces a malformed Authorization header.
with open('.openai-secret', 'r') as f:
    OPENAI_API_KEY = f.read().strip()
|
6 |
+
|
7 |
+
|
8 |
+
def read_dataset():
    """Load the articles dataset from ``updateddata.json``.

    Returns:
        The parsed JSON payload (a list of article dicts, each expected
        to carry an ``'abstract'`` key — see ``main``).
    """
    print("Reading dataset")
    # Context manager closes the handle on all paths; the original
    # opened the file and never closed it.
    with open('updateddata.json') as json_file:
        return json.load(json_file)
|
13 |
+
|
14 |
+
|
15 |
+
def main():
    """
    Take the articles JSON and add a vector-embedding field to each entry.

    Abstracts are embedded in batches because the OpenAI API is rate
    limited to 1M tokens / min; a hardcoded sleep separates batches.
    In total, takes about an hour to run.
    """
    data = read_dataset()
    print(len(data))

    # ~750 abstracts of <=2048 chars each stays under the per-minute limit.
    batch_size = 750

    client = OpenAI(api_key=OPENAI_API_KEY)

    newjson = []

    # Iterate over every batch, including a final partial one. The original
    # `while max <= len(data)` loop (with max starting at 750) never ran at
    # all for datasets smaller than one batch, silently producing no embeds.
    for start in range(0, len(data), batch_size):
        end = min(start + batch_size, len(data))
        print("------------")
        print("startind", start, "endind", end)

        paper_subset = data[start:end]

        # Truncate each abstract and JSON-escape it to plain ASCII before
        # sending it to the embeddings endpoint.
        abstract_list = [
            json.dumps(paper['abstract'][:2048], ensure_ascii=True)
            for paper in paper_subset
        ]

        # Rough size estimate, printed for rate-limit monitoring only.
        print("numtokens:", sum(len(a[:2048]) for a in abstract_list))

        abstract_list = [a.replace("\n", " ").strip() for a in abstract_list]

        res = client.embeddings.create(
            model="text-embedding-3-small",
            input=abstract_list,
            encoding_format="float",
        )
        resdata = res.dict()['data']

        print("Successful API call")

        # Attach each embedding to its source paper; the endpoint returns
        # results in input order.
        for paper, item in zip(paper_subset, resdata):
            paper['embed'] = item['embedding']

        print("Added embeds")

        # NOTE(review): batches are appended, not extended, so the output
        # file is a list of batches (list of lists) — preserved from the
        # original; confirm downstream consumers expect this shape.
        newjson.append(paper_subset)

        # Sleep between batches only (the rate-limit window resets each
        # minute); no need to wait after the final batch.
        if end < len(data):
            time.sleep(61)

    with open("datawithembeds.json", 'w') as f:
        json.dump(newjson, f)
|
78 |
+
|
79 |
+
|
80 |
+
# Script entry point: run the embedding pipeline only when executed
# directly, not when imported as a module.
if __name__ == '__main__':
    main()
|