Spaces:
Running
Running
jianghuyihei
commited on
Commit
•
e3a17c0
1
Parent(s):
a81bf47
delete async
Browse files- .gitattributes copy +0 -35
- LLM.py +43 -8
- agents.py +34 -73
- app.py +1 -1
- main.py +4 -9
- searcher/sementic_search.py +49 -121
.gitattributes copy
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLM.py
CHANGED
@@ -123,7 +123,13 @@ class openai_llm(base_llm):
|
|
123 |
input=text,
|
124 |
timeout= 180
|
125 |
)
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
except Exception as e:
|
128 |
print(f"get embbeding failed: {e}")
|
129 |
print(e)
|
@@ -147,7 +153,13 @@ class openai_llm(base_llm):
|
|
147 |
input=text,
|
148 |
timeout= 180
|
149 |
)
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
except Exception as e:
|
152 |
await asyncio.sleep(0.1)
|
153 |
print(f"get embbeding failed: {e}")
|
@@ -178,9 +190,32 @@ class openai_llm(base_llm):
|
|
178 |
|
179 |
|
180 |
if __name__ == "__main__":
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
input=text,
|
124 |
timeout= 180
|
125 |
)
|
126 |
+
embbeding = embbeding.data
|
127 |
+
if len(embbeding) == 0:
|
128 |
+
return None
|
129 |
+
elif len(embbeding) == 1:
|
130 |
+
return embbeding[0].embedding
|
131 |
+
else:
|
132 |
+
return [e.embedding for e in embbeding]
|
133 |
except Exception as e:
|
134 |
print(f"get embbeding failed: {e}")
|
135 |
print(e)
|
|
|
153 |
input=text,
|
154 |
timeout= 180
|
155 |
)
|
156 |
+
embbeding = embbeding.data
|
157 |
+
if len(embbeding) == 0:
|
158 |
+
return None
|
159 |
+
elif len(embbeding) == 1:
|
160 |
+
return embbeding[0].embedding
|
161 |
+
else:
|
162 |
+
return [e.embedding for e in embbeding]
|
163 |
except Exception as e:
|
164 |
await asyncio.sleep(0.1)
|
165 |
print(f"get embbeding failed: {e}")
|
|
|
190 |
|
191 |
|
192 |
if __name__ == "__main__":
|
193 |
+
import os
|
194 |
+
import yaml
|
195 |
+
|
196 |
+
def cal_cosine_similarity_matric(matric1, matric2):
|
197 |
+
if isinstance(matric1, list):
|
198 |
+
matric1 = np.array(matric1)
|
199 |
+
if isinstance(matric2, list):
|
200 |
+
matric2 = np.array(matric2)
|
201 |
+
if len(matric1.shape) == 1:
|
202 |
+
matric1 = matric1.reshape(1, -1)
|
203 |
+
if len(matric2.shape) == 1:
|
204 |
+
matric2 = matric2.reshape(1, -1)
|
205 |
+
dot_product = np.dot(matric1, matric2.T)
|
206 |
+
norm1 = np.linalg.norm(matric1, axis=1)
|
207 |
+
norm2 = np.linalg.norm(matric2, axis=1)
|
208 |
+
|
209 |
+
cos_sim = dot_product / np.outer(norm1, norm2)
|
210 |
+
scores = cos_sim.flatten()
|
211 |
+
# 返回一个list
|
212 |
+
return scores.tolist()
|
213 |
+
|
214 |
+
texts = ["What is the capital of France?","What is the capital of Spain?", "What is the capital of Italy?", "What is the capital of Germany?"]
|
215 |
+
text = "What is the capital of France?"
|
216 |
+
llm = openai_llm()
|
217 |
+
embbedings = llm.get_embbeding(texts)
|
218 |
+
embbeding = llm.get_embbeding(text)
|
219 |
+
|
220 |
+
scores = cal_cosine_similarity_matric(embbedings, embbeding)
|
221 |
+
print(scores)
|
agents.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import json
|
2 |
import time
|
3 |
-
import asyncio
|
4 |
-
import os
|
5 |
from searcher import Result,SementicSearcher
|
6 |
from LLM import openai_llm
|
7 |
from prompts import *
|
@@ -17,10 +15,10 @@ def get_llms():
|
|
17 |
cheap_llm = get_llm("gpt-4o-mini")
|
18 |
return main_llm,cheap_llm
|
19 |
|
20 |
-
|
21 |
prompt = get_judge_idea_all_prompt(idea0,idea1,topic)
|
22 |
messages = [{"role":"user","content":prompt}]
|
23 |
-
response =
|
24 |
novelty = extract(response,"novelty")
|
25 |
relevance = extract(response,"relevance")
|
26 |
significance = extract(response,"significance")
|
@@ -55,16 +53,16 @@ class DeepResearchAgent:
|
|
55 |
def wrap_messages(self,prompt):
|
56 |
return [{"role":"user","content":prompt}]
|
57 |
|
58 |
-
|
59 |
-
return
|
60 |
|
61 |
-
|
62 |
-
return
|
63 |
|
64 |
-
|
65 |
prompt = get_deep_search_query_prompt(topic,query)
|
66 |
messages = self.wrap_messages(prompt)
|
67 |
-
response =
|
68 |
search_query = extract(response,"queries")
|
69 |
try:
|
70 |
search_query = json.loads(search_query)
|
@@ -73,17 +71,17 @@ class DeepResearchAgent:
|
|
73 |
search_query = [query]
|
74 |
return search_query
|
75 |
|
76 |
-
|
77 |
self.topic = topic
|
78 |
print(f"begin to generate search query for {topic}")
|
79 |
-
search_query =
|
80 |
papers = []
|
81 |
for query in search_query:
|
82 |
failed_query = []
|
83 |
current_papers = []
|
84 |
cnt = 0
|
85 |
while len(current_papers) == 0 and cnt < 10:
|
86 |
-
paper =
|
87 |
if paper and len(paper) > 0 and paper[0]:
|
88 |
self.read_papers.add(paper[0].title)
|
89 |
current_papers.append(paper[0])
|
@@ -91,7 +89,7 @@ class DeepResearchAgent:
|
|
91 |
failed_query.append(query)
|
92 |
prompt = get_deep_rewrite_query_prompt(failed_query,topic)
|
93 |
messages = self.wrap_messages(prompt)
|
94 |
-
new_query =
|
95 |
new_query = extract(new_query,"query")
|
96 |
print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
|
97 |
query = new_query
|
@@ -104,67 +102,30 @@ class DeepResearchAgent:
|
|
104 |
print(f"failed to generate idea {topic}")
|
105 |
return None,None,None,None,None,None,None,None,None
|
106 |
|
107 |
-
|
108 |
-
results = await asyncio.gather(*tasks)
|
109 |
-
results = [result for result in results if result]
|
110 |
-
if len(results) ==0:
|
111 |
-
print(f"failed to generate idea {topic}")
|
112 |
-
return None,None,None,None,None,None,None,None,None
|
113 |
-
|
114 |
-
ideas,idea_chains,experiments,entities,trends,futures,humans,years = [[result[i] for result in results] for i in range(8)]
|
115 |
-
|
116 |
-
tasks = []
|
117 |
-
for i,idea_1 in enumerate(ideas):
|
118 |
-
for j,idea_2 in enumerate(ideas):
|
119 |
-
if i != j:
|
120 |
-
tasks.append(judge_idea(i,j,idea_1,idea_2,topic,self.llm))
|
121 |
-
results = await asyncio.gather(*tasks)
|
122 |
-
elo_scores = [0 for _ in range(len(ideas))]
|
123 |
-
elo_selected = 0
|
124 |
-
def change_winner_to_score(winner,score_1,score_2):
|
125 |
-
try:
|
126 |
-
winner = int(winner)
|
127 |
-
except:
|
128 |
-
return score_1+0.5,score_2+0.5
|
129 |
-
if winner == 0:
|
130 |
-
return score_1+1,score_2
|
131 |
-
if winner == 2:
|
132 |
-
return score_1+0.5,score_2+0.5
|
133 |
-
return score_1,score_2+1
|
134 |
-
for result in results:
|
135 |
-
i,j,novelty,relevance,significance,clarity,feasibility,effectiveness = result
|
136 |
-
for dimension in [novelty,relevance,significance,clarity,feasibility,effectiveness]:
|
137 |
-
elo_scores[i],elo_scores[j] = change_winner_to_score(dimension,elo_scores[i],elo_scores[j])
|
138 |
-
print(f"i:{i},j:{j},novelty:{novelty},relevance:{relevance},significance:{significance},clarity:{clarity},feasibility:{feasibility},effectiveness:{effectiveness}")
|
139 |
-
print(elo_scores)
|
140 |
-
try:
|
141 |
-
elo_selected = elo_scores.index(max(elo_scores))
|
142 |
-
except:
|
143 |
-
elo_selected = 0
|
144 |
|
145 |
-
idea,experiment,entities,idea_chain,trend,future,human,year = ideas[elo_selected],experiments[elo_selected],entities[elo_selected],idea_chains[elo_selected],trends[elo_selected],futures[elo_selected],humans[elo_selected],years[elo_selected]
|
146 |
print(f"successfully generated idea")
|
147 |
-
return idea,experiment,entities,idea_chain,
|
148 |
|
149 |
-
|
150 |
article = paper.article
|
151 |
if not article:
|
152 |
return None
|
153 |
paper_content = self.reader.read_paper_content(article)
|
154 |
prompt = get_deep_reference_prompt(paper_content,self.topic)
|
155 |
messages = self.wrap_messages(prompt)
|
156 |
-
response =
|
157 |
entities = extract(response,"entities")
|
158 |
idea = extract(response,"idea")
|
159 |
experiment = extract(response,"experiment")
|
160 |
references = extract(response,"references")
|
161 |
return idea,experiment,entities,references,paper.title
|
162 |
|
163 |
-
|
164 |
paper_content = self.reader.read_paper_content_with_ref(article)
|
165 |
prompt = get_deep_reference_prompt(paper_content,self.topic)
|
166 |
messages = self.wrap_messages(prompt)
|
167 |
-
response =
|
168 |
entities = extract(response,"entities")
|
169 |
idea = extract(response,"idea")
|
170 |
experiment = extract(response,"experiment")
|
@@ -172,7 +133,7 @@ class DeepResearchAgent:
|
|
172 |
return idea,experiment,entities,references
|
173 |
|
174 |
|
175 |
-
|
176 |
print(f"begin to deep research paper {paper.title}")
|
177 |
article = paper.article
|
178 |
if not article:
|
@@ -183,7 +144,7 @@ class DeepResearchAgent:
|
|
183 |
experiments = []
|
184 |
total_entities = []
|
185 |
years = []
|
186 |
-
idea,experiment,entities,references =
|
187 |
try:
|
188 |
references = json.loads(references)
|
189 |
except:
|
@@ -200,7 +161,7 @@ class DeepResearchAgent:
|
|
200 |
# search before
|
201 |
while len(idea_chain)<self.max_chain_length:
|
202 |
rerank_query = f"{self.topic} {current_title} {current_abstract}"
|
203 |
-
citation_paper =
|
204 |
if not citation_paper:
|
205 |
print(f"failed to find citation paper for {current_title}")
|
206 |
break
|
@@ -208,10 +169,10 @@ class DeepResearchAgent:
|
|
208 |
abstract = citation_paper.abstract
|
209 |
prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
|
210 |
messages = self.wrap_messages(prompt)
|
211 |
-
response =
|
212 |
relevant = extract(response,"relevant")
|
213 |
if relevant != "0":
|
214 |
-
result =
|
215 |
if not result:
|
216 |
break
|
217 |
idea,experiment,entities,_,_ = result
|
@@ -238,13 +199,13 @@ class DeepResearchAgent:
|
|
238 |
references.pop(0)
|
239 |
if reference in self.read_papers:
|
240 |
continue
|
241 |
-
search_paper =
|
242 |
if len(search_paper) > 0:
|
243 |
s_p = search_paper[0]
|
244 |
if s_p and s_p.title not in self.read_papers:
|
245 |
prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
|
246 |
messages = self.wrap_messages(prompt)
|
247 |
-
response =
|
248 |
relevant = extract(response,"relevant")
|
249 |
if relevant != "0" or len(idea_chain) < self.min_chain_length:
|
250 |
article = s_p.article
|
@@ -257,7 +218,7 @@ class DeepResearchAgent:
|
|
257 |
|
258 |
if not article:
|
259 |
rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
|
260 |
-
search_paper =
|
261 |
if not search_paper:
|
262 |
print(f"failed to find citation paper for {current_title}")
|
263 |
continue
|
@@ -273,10 +234,10 @@ class DeepResearchAgent:
|
|
273 |
if s_p and s_p.title not in self.read_papers:
|
274 |
prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
|
275 |
messages = self.wrap_messages(prompt)
|
276 |
-
response =
|
277 |
relevant = extract(response,"relevant")
|
278 |
if relevant == "1" or len(idea_chain) < self.min_chain_length:
|
279 |
-
article =
|
280 |
if not article:
|
281 |
continue
|
282 |
else:
|
@@ -290,7 +251,7 @@ class DeepResearchAgent:
|
|
290 |
paper_content = self.reader.read_paper_content_with_ref(article)
|
291 |
prompt = get_deep_reference_prompt(paper_content,self.topic)
|
292 |
messages = self.wrap_messages(prompt)
|
293 |
-
response =
|
294 |
idea = extract(response,"idea")
|
295 |
references = extract(response,"references")
|
296 |
experiment = extract(response,"experiment")
|
@@ -317,7 +278,7 @@ class DeepResearchAgent:
|
|
317 |
|
318 |
prompt = get_deep_trend_idea_chains_prompt(idea_chains,entities,self.topic)
|
319 |
messages = self.wrap_messages(prompt)
|
320 |
-
response =
|
321 |
trend = extract(response,"trend")
|
322 |
|
323 |
self.deep_research_chains.append({"idea_chains":idea_chains,"trend":trend,"topic":self.topic,"ideas":idea_chain,"experiments":experiments,"entities":total_entities,"years":years})
|
@@ -326,26 +287,26 @@ class DeepResearchAgent:
|
|
326 |
<entities> {{cleaned entities}}</entities>
|
327 |
"""
|
328 |
messages = self.wrap_messages(prompt)
|
329 |
-
response =
|
330 |
total_entities = extract(response,"entities")
|
331 |
bad_case = []
|
332 |
prompt = get_deep_generate_future_direciton_prompt(idea_chain,trend,self.topic,total_entities)
|
333 |
messages = self.wrap_messages(prompt)
|
334 |
-
response =
|
335 |
future = extract(response,"future")
|
336 |
human = extract(response,"human")
|
337 |
|
338 |
|
339 |
prompt = get_deep_generate_idea_prompt(idea_chains,trend,self.topic,total_entities,future,bad_case)
|
340 |
messages = self.wrap_messages(prompt)
|
341 |
-
response =
|
342 |
method = extract(response,"method")
|
343 |
novelty = extract(response,"novelty")
|
344 |
motivation = extract(response,"motivation")
|
345 |
idea = {"motivation":motivation,"novelty":novelty,"method":method}
|
346 |
prompt = get_deep_final_idea_prompt(idea_chains,trend,idea,self.topic)
|
347 |
messages = self.wrap_messages(prompt)
|
348 |
-
response =
|
349 |
final_idea = extract(response,"final_idea")
|
350 |
|
351 |
idea = final_idea
|
|
|
1 |
import json
|
2 |
import time
|
|
|
|
|
3 |
from searcher import Result,SementicSearcher
|
4 |
from LLM import openai_llm
|
5 |
from prompts import *
|
|
|
15 |
cheap_llm = get_llm("gpt-4o-mini")
|
16 |
return main_llm,cheap_llm
|
17 |
|
18 |
+
def judge_idea(i,j,idea0,idea1,topic,llm):
|
19 |
prompt = get_judge_idea_all_prompt(idea0,idea1,topic)
|
20 |
messages = [{"role":"user","content":prompt}]
|
21 |
+
response = llm.response(messages)
|
22 |
novelty = extract(response,"novelty")
|
23 |
relevance = extract(response,"relevance")
|
24 |
significance = extract(response,"significance")
|
|
|
53 |
def wrap_messages(self,prompt):
|
54 |
return [{"role":"user","content":prompt}]
|
55 |
|
56 |
+
def get_openai_response(self,messages):
|
57 |
+
return self.llm.response(messages)
|
58 |
|
59 |
+
def get_cheap_openai_response(self,messages):
|
60 |
+
return self.cheap_llm.response(messages,max_tokens = 16000)
|
61 |
|
62 |
+
def get_search_query(self,topic = None,query=None):
|
63 |
prompt = get_deep_search_query_prompt(topic,query)
|
64 |
messages = self.wrap_messages(prompt)
|
65 |
+
response = self.get_openai_response(messages)
|
66 |
search_query = extract(response,"queries")
|
67 |
try:
|
68 |
search_query = json.loads(search_query)
|
|
|
71 |
search_query = [query]
|
72 |
return search_query
|
73 |
|
74 |
+
def generate_idea_with_chain(self,topic):
|
75 |
self.topic = topic
|
76 |
print(f"begin to generate search query for {topic}")
|
77 |
+
search_query = self.get_search_query(topic=topic)
|
78 |
papers = []
|
79 |
for query in search_query:
|
80 |
failed_query = []
|
81 |
current_papers = []
|
82 |
cnt = 0
|
83 |
while len(current_papers) == 0 and cnt < 10:
|
84 |
+
paper = self.reader.search(query,1,paper_list=self.read_papers,llm=self.llm,rerank_query=f"{topic}",publicationDate=self.publicationData)
|
85 |
if paper and len(paper) > 0 and paper[0]:
|
86 |
self.read_papers.add(paper[0].title)
|
87 |
current_papers.append(paper[0])
|
|
|
89 |
failed_query.append(query)
|
90 |
prompt = get_deep_rewrite_query_prompt(failed_query,topic)
|
91 |
messages = self.wrap_messages(prompt)
|
92 |
+
new_query = self.get_openai_response(messages)
|
93 |
new_query = extract(new_query,"query")
|
94 |
print(f"Failed to search papers for {query}, regenerating query {new_query} to search papers.")
|
95 |
query = new_query
|
|
|
102 |
print(f"failed to generate idea {topic}")
|
103 |
return None,None,None,None,None,None,None,None,None
|
104 |
|
105 |
+
idea,idea_chain,experiment,entities,trend,future,human,year = self.deep_research_paper_with_chain(papers[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
|
|
107 |
print(f"successfully generated idea")
|
108 |
+
return idea,experiment,entities,idea_chain,idea,trend,future,human,year
|
109 |
|
110 |
+
def get_paper_idea_experiment_references_info(self,paper):
|
111 |
article = paper.article
|
112 |
if not article:
|
113 |
return None
|
114 |
paper_content = self.reader.read_paper_content(article)
|
115 |
prompt = get_deep_reference_prompt(paper_content,self.topic)
|
116 |
messages = self.wrap_messages(prompt)
|
117 |
+
response = self.get_cheap_openai_response(messages)
|
118 |
entities = extract(response,"entities")
|
119 |
idea = extract(response,"idea")
|
120 |
experiment = extract(response,"experiment")
|
121 |
references = extract(response,"references")
|
122 |
return idea,experiment,entities,references,paper.title
|
123 |
|
124 |
+
def get_article_idea_experiment_references_info(self,article):
|
125 |
paper_content = self.reader.read_paper_content_with_ref(article)
|
126 |
prompt = get_deep_reference_prompt(paper_content,self.topic)
|
127 |
messages = self.wrap_messages(prompt)
|
128 |
+
response = self.get_cheap_openai_response(messages)
|
129 |
entities = extract(response,"entities")
|
130 |
idea = extract(response,"idea")
|
131 |
experiment = extract(response,"experiment")
|
|
|
133 |
return idea,experiment,entities,references
|
134 |
|
135 |
|
136 |
+
def deep_research_paper_with_chain(self,paper:Result):
|
137 |
print(f"begin to deep research paper {paper.title}")
|
138 |
article = paper.article
|
139 |
if not article:
|
|
|
144 |
experiments = []
|
145 |
total_entities = []
|
146 |
years = []
|
147 |
+
idea,experiment,entities,references = self.get_article_idea_experiment_references_info(article)
|
148 |
try:
|
149 |
references = json.loads(references)
|
150 |
except:
|
|
|
161 |
# search before
|
162 |
while len(idea_chain)<self.max_chain_length:
|
163 |
rerank_query = f"{self.topic} {current_title} {current_abstract}"
|
164 |
+
citation_paper = self.reader.search_related_paper(current_title,need_reference=False,rerank_query=rerank_query,llm=self.llm,paper_list=idea_papers)
|
165 |
if not citation_paper:
|
166 |
print(f"failed to find citation paper for {current_title}")
|
167 |
break
|
|
|
169 |
abstract = citation_paper.abstract
|
170 |
prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
|
171 |
messages = self.wrap_messages(prompt)
|
172 |
+
response = self.get_openai_response(messages)
|
173 |
relevant = extract(response,"relevant")
|
174 |
if relevant != "0":
|
175 |
+
result = self.get_paper_idea_experiment_references_info(citation_paper)
|
176 |
if not result:
|
177 |
break
|
178 |
idea,experiment,entities,_,_ = result
|
|
|
199 |
references.pop(0)
|
200 |
if reference in self.read_papers:
|
201 |
continue
|
202 |
+
search_paper = self.reader.search(reference,3,llm=self.llm,publicationDate=self.publicationData,paper_list= idea_papers)
|
203 |
if len(search_paper) > 0:
|
204 |
s_p = search_paper[0]
|
205 |
if s_p and s_p.title not in self.read_papers:
|
206 |
prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
|
207 |
messages = self.wrap_messages(prompt)
|
208 |
+
response = self.get_openai_response(messages)
|
209 |
relevant = extract(response,"relevant")
|
210 |
if relevant != "0" or len(idea_chain) < self.min_chain_length:
|
211 |
article = s_p.article
|
|
|
218 |
|
219 |
if not article:
|
220 |
rerank_query = f"topic: {self.topic} Title: {current_title} Abstract: {current_abstract}"
|
221 |
+
search_paper = self.reader.search_related_paper(current_title,need_citation=False,rerank_query = rerank_query,llm=self.llm,paper_list=idea_papers)
|
222 |
if not search_paper:
|
223 |
print(f"failed to find citation paper for {current_title}")
|
224 |
continue
|
|
|
234 |
if s_p and s_p.title not in self.read_papers:
|
235 |
prompt = get_deep_judge_relevant_prompt(current_title,current_abstract,self.topic)
|
236 |
messages = self.wrap_messages(prompt)
|
237 |
+
response = self.get_openai_response(messages)
|
238 |
relevant = extract(response,"relevant")
|
239 |
if relevant == "1" or len(idea_chain) < self.min_chain_length:
|
240 |
+
article = s_p.article
|
241 |
if not article:
|
242 |
continue
|
243 |
else:
|
|
|
251 |
paper_content = self.reader.read_paper_content_with_ref(article)
|
252 |
prompt = get_deep_reference_prompt(paper_content,self.topic)
|
253 |
messages = self.wrap_messages(prompt)
|
254 |
+
response = self.get_cheap_openai_response(messages)
|
255 |
idea = extract(response,"idea")
|
256 |
references = extract(response,"references")
|
257 |
experiment = extract(response,"experiment")
|
|
|
278 |
|
279 |
prompt = get_deep_trend_idea_chains_prompt(idea_chains,entities,self.topic)
|
280 |
messages = self.wrap_messages(prompt)
|
281 |
+
response = self.get_openai_response(messages)
|
282 |
trend = extract(response,"trend")
|
283 |
|
284 |
self.deep_research_chains.append({"idea_chains":idea_chains,"trend":trend,"topic":self.topic,"ideas":idea_chain,"experiments":experiments,"entities":total_entities,"years":years})
|
|
|
287 |
<entities> {{cleaned entities}}</entities>
|
288 |
"""
|
289 |
messages = self.wrap_messages(prompt)
|
290 |
+
response = self.get_openai_response(messages)
|
291 |
total_entities = extract(response,"entities")
|
292 |
bad_case = []
|
293 |
prompt = get_deep_generate_future_direciton_prompt(idea_chain,trend,self.topic,total_entities)
|
294 |
messages = self.wrap_messages(prompt)
|
295 |
+
response = self.get_openai_response(messages)
|
296 |
future = extract(response,"future")
|
297 |
human = extract(response,"human")
|
298 |
|
299 |
|
300 |
prompt = get_deep_generate_idea_prompt(idea_chains,trend,self.topic,total_entities,future,bad_case)
|
301 |
messages = self.wrap_messages(prompt)
|
302 |
+
response = self.get_openai_response(messages)
|
303 |
method = extract(response,"method")
|
304 |
novelty = extract(response,"novelty")
|
305 |
motivation = extract(response,"motivation")
|
306 |
idea = {"motivation":motivation,"novelty":novelty,"method":method}
|
307 |
prompt = get_deep_final_idea_prompt(idea_chains,trend,idea,self.topic)
|
308 |
messages = self.wrap_messages(prompt)
|
309 |
+
response = self.get_openai_response(messages)
|
310 |
final_idea = extract(response,"final_idea")
|
311 |
|
312 |
idea = final_idea
|
app.py
CHANGED
@@ -332,7 +332,7 @@ def form_post(topic: str = Form(...)):
|
|
332 |
main_llm, cheap_llm = get_llms()
|
333 |
deep_research_agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1, max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
|
334 |
print(f"begin to generate idea of topic {topic}")
|
335 |
-
idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year =
|
336 |
idea_md = markdown.markdown(idea)
|
337 |
# 更新每日回复次数
|
338 |
reply_count += 1
|
|
|
332 |
main_llm, cheap_llm = get_llms()
|
333 |
deep_research_agent = DeepResearchAgent(llm=main_llm, cheap_llm=cheap_llm, improve_cnt=1, max_chain_length=5, min_chain_length=3, max_chain_numbers=1)
|
334 |
print(f"begin to generate idea of topic {topic}")
|
335 |
+
idea, related_experiments, entities, idea_chain, ideas, trend, future, human, year = deep_research_agent.generate_idea_with_chain(topic)
|
336 |
idea_md = markdown.markdown(idea)
|
337 |
# 更新每日回复次数
|
338 |
reply_count += 1
|
main.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
from agents import DeepResearchAgent,
|
2 |
import asyncio
|
3 |
import json
|
4 |
import argparse
|
5 |
|
|
|
6 |
if __name__ == '__main__':
|
7 |
|
8 |
argparser = argparse.ArgumentParser()
|
@@ -21,18 +22,12 @@ if __name__ == '__main__':
|
|
21 |
topic = args.topic
|
22 |
anchor_paper_path = args.anchor_paper_path
|
23 |
|
24 |
-
|
25 |
-
review_agent = ReviewAgent(save_file=args.save_file,llm=main_llm,cheap_llm=cheap_llm)
|
26 |
deep_research_agent = DeepResearchAgent(llm=main_llm,cheap_llm=cheap_llm,**vars(args))
|
27 |
|
28 |
print(f"begin to generate idea and experiment of topic {topic}")
|
29 |
-
idea,related_experiments,entities,idea_chain,ideas,trend,future,human,year=
|
30 |
-
experiment = asyncio.run(deep_research_agent.generate_experiment(idea,related_experiments,entities))
|
31 |
-
|
32 |
-
for i in range(args.improve_cnt):
|
33 |
-
experiment = asyncio.run(deep_research_agent.improve_experiment(review_agent,idea,experiment,entities))
|
34 |
|
35 |
print(f"succeed to generate idea and experiment of topic {topic}")
|
36 |
-
res = {"idea":idea,"
|
37 |
with open("result.json","w") as f:
|
38 |
json.dump(res,f)
|
|
|
1 |
+
from agents import DeepResearchAgent,get_llms
|
2 |
import asyncio
|
3 |
import json
|
4 |
import argparse
|
5 |
|
6 |
+
|
7 |
if __name__ == '__main__':
|
8 |
|
9 |
argparser = argparse.ArgumentParser()
|
|
|
22 |
topic = args.topic
|
23 |
anchor_paper_path = args.anchor_paper_path
|
24 |
|
|
|
|
|
25 |
deep_research_agent = DeepResearchAgent(llm=main_llm,cheap_llm=cheap_llm,**vars(args))
|
26 |
|
27 |
print(f"begin to generate idea and experiment of topic {topic}")
|
28 |
+
idea,related_experiments,entities,idea_chain,ideas,trend,future,human,year= deep_research_agent.generate_idea_with_chain(topic)
|
|
|
|
|
|
|
|
|
29 |
|
30 |
print(f"succeed to generate idea and experiment of topic {topic}")
|
31 |
+
res = {"idea":idea,"related_experiments":related_experiments,"entities":entities,"idea_chain":idea_chain,"ideas":ideas,"trend":trend,"future":future,"year":year,"human":human}
|
32 |
with open("result.json","w") as f:
|
33 |
json.dump(res,f)
|
searcher/sementic_search.py
CHANGED
@@ -7,7 +7,7 @@ import time
|
|
7 |
import aiohttp
|
8 |
import asyncio
|
9 |
import numpy as np
|
10 |
-
|
11 |
|
12 |
def get_content_between_a_b(start_tag, end_tag, text):
|
13 |
extracted_text = ""
|
@@ -31,29 +31,6 @@ def extract(text, type):
|
|
31 |
return text
|
32 |
else:
|
33 |
return ""
|
34 |
-
|
35 |
-
|
36 |
-
async def fetch(url):
|
37 |
-
await asyncio.sleep(1) # 异步的 sleep 而不是 time.sleep
|
38 |
-
try:
|
39 |
-
timeout = aiohttp.ClientTimeout(total=120)
|
40 |
-
connector = aiohttp.TCPConnector(limit_per_host=10) # 使用连接池
|
41 |
-
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
42 |
-
async with session.get(url) as response:
|
43 |
-
if response.status == 200:
|
44 |
-
content = await response.read() # Read the response content as bytes
|
45 |
-
return content
|
46 |
-
else:
|
47 |
-
print(f"Failed to fetch the URL: {url} with status code: {response.status}")
|
48 |
-
return None
|
49 |
-
except aiohttp.ClientError as e: # 更具体的异常捕获
|
50 |
-
print(f"An error occurred while fetching the URL: {url}")
|
51 |
-
print(e)
|
52 |
-
return None
|
53 |
-
except Exception as e:
|
54 |
-
print(f"An unexpected error occurred while fetching the URL: {url}")
|
55 |
-
print(e)
|
56 |
-
return None
|
57 |
|
58 |
def download(url):
|
59 |
try:
|
@@ -103,7 +80,7 @@ class SementicSearcher:
|
|
103 |
def __init__(self, ban_paper = []) -> None:
|
104 |
self.ban_paper = ban_paper
|
105 |
|
106 |
-
|
107 |
publicationDate=None, minCitationCount=0, year=None,
|
108 |
publicationTypes=None, fieldsOfStudy=None):
|
109 |
url = 'https://api.semanticscholar.org/graph/v1/paper/search'
|
@@ -124,7 +101,6 @@ class SementicSearcher:
|
|
124 |
# Load the API key from the configuration file
|
125 |
api_key = os.environ.get('SEMENTIC_SEARCH_API_KEY',None)
|
126 |
headers = {'x-api-key': api_key} if api_key else None
|
127 |
-
await asyncio.sleep(0.5)
|
128 |
try:
|
129 |
filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
|
130 |
response = requests.get(url, params=filtered_query_params, headers=headers)
|
@@ -135,7 +111,7 @@ class SementicSearcher:
|
|
135 |
elif response.status_code == 429:
|
136 |
time.sleep(1)
|
137 |
print(f"Request failed with status code {response.status_code}: begin to retry")
|
138 |
-
return
|
139 |
else:
|
140 |
print(f"Request failed with status code {response.status_code}: {response.text}")
|
141 |
return None
|
@@ -145,6 +121,23 @@ class SementicSearcher:
|
|
145 |
|
146 |
def cal_cosine_similarity(self, vec1, vec2):
|
147 |
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
def read_arxiv_from_path(self, pdf_path):
|
150 |
def is_pdf(binary_data):
|
@@ -163,97 +156,41 @@ class SementicSearcher:
|
|
163 |
return None
|
164 |
return article_dict
|
165 |
|
166 |
-
|
167 |
paper_content = f"""
|
168 |
Title: {paper['title']}
|
169 |
Abstract: {paper['abstract']}
|
170 |
"""
|
171 |
-
paper_embbeding =
|
172 |
paper_embbeding = np.array(paper_embbeding)
|
173 |
score = self.cal_cosine_similarity(query_embedding,paper_embbeding)
|
174 |
return [paper,score]
|
175 |
|
176 |
|
177 |
-
|
|
|
|
|
|
|
178 |
if len(paper_list) >= 50:
|
179 |
-
paper_list = paper_list
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
url = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}'
|
196 |
-
fields = process_fields(fields)
|
197 |
-
paper_data_query_params = {'fields': fields}
|
198 |
-
try:
|
199 |
-
async with aiohttp.ClientSession() as session:
|
200 |
-
filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
|
201 |
-
headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
|
202 |
-
async with session.get(url, params=filtered_query_params, headers=headers) as response:
|
203 |
-
if response.status == 200:
|
204 |
-
response_data = await response.json()
|
205 |
-
return response_data
|
206 |
-
else:
|
207 |
-
await asyncio.sleep(0.01)
|
208 |
-
print(f"Request failed with status code {response.status}: {await response.text()}")
|
209 |
-
return None
|
210 |
-
except Exception as e:
|
211 |
-
print(f"Failed to get paper details for paper ID: {paper_id}")
|
212 |
-
return None
|
213 |
-
|
214 |
-
async def batch_retrieve_papers_async(self, paper_ids, fields = semantic_fields):
|
215 |
-
url = 'https://api.semanticscholar.org/graph/v1/paper/batch'
|
216 |
-
paper_data_query_params = {'fields': process_fields(fields)}
|
217 |
-
paper_ids_json = {"ids": paper_ids}
|
218 |
-
try:
|
219 |
-
async with aiohttp.ClientSession() as session:
|
220 |
-
filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
|
221 |
-
headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
|
222 |
-
async with session.post(url, json=paper_ids_json, params=filtered_query_params, headers=headers) as response:
|
223 |
-
if response.status == 200:
|
224 |
-
response_data = await response.json()
|
225 |
-
return response_data
|
226 |
-
else:
|
227 |
-
await asyncio.sleep(0.01)
|
228 |
-
print(f"Request failed with status code {response.status}: {await response.text()}")
|
229 |
-
return None
|
230 |
-
except Exception as e:
|
231 |
-
print(f"Failed to batch retrieve papers for paper IDs: {paper_ids}")
|
232 |
-
return None
|
233 |
-
|
234 |
-
async def search_paper_from_title_async(self, query,fields = ["title","paperId"]):
|
235 |
-
url = 'https://api.semanticscholar.org/graph/v1/paper/search/match'
|
236 |
-
fields = process_fields(fields)
|
237 |
-
query_params = {'query': query, 'fields': fields}
|
238 |
-
try:
|
239 |
-
async with aiohttp.ClientSession() as session:
|
240 |
-
filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
|
241 |
-
headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
|
242 |
-
async with session.get(url, params=filtered_query_params, headers=headers) as response:
|
243 |
-
if response.status == 200:
|
244 |
-
response_data = await response.json()
|
245 |
-
return response_data
|
246 |
-
else:
|
247 |
-
await asyncio.sleep(0.01)
|
248 |
-
print(f"Request failed with status code {response.status}: {await response.text()}")
|
249 |
-
return None
|
250 |
-
except Exception as e:
|
251 |
-
await asyncio.sleep(0.01)
|
252 |
-
print(f"Failed to search paper from title: {query}")
|
253 |
-
return None
|
254 |
|
255 |
|
256 |
-
|
257 |
if rerank_query:
|
258 |
rerank_query_embbeding = llm.get_embbeding(rerank_query)
|
259 |
rerank_query_embbeding = np.array(rerank_query_embbeding)
|
@@ -270,7 +207,7 @@ Abstract: {paper['abstract']}
|
|
270 |
readed_papers = [paper.title for paper in paper_list]
|
271 |
|
272 |
print(f"Searching for papers related to the query: <{query}>")
|
273 |
-
results =
|
274 |
if not results or "data" not in results:
|
275 |
return []
|
276 |
|
@@ -293,8 +230,7 @@ Abstract: {paper['abstract']}
|
|
293 |
paper_candidates = results
|
294 |
|
295 |
if llm and rerank_query:
|
296 |
-
paper_candidates =
|
297 |
-
paper_candidates = [paper[0] for paper in paper_candidates if paper]
|
298 |
|
299 |
if need_download:
|
300 |
for result in paper_candidates:
|
@@ -326,10 +262,10 @@ Abstract: {paper['abstract']}
|
|
326 |
break
|
327 |
return final_results
|
328 |
|
329 |
-
|
330 |
-
print(f"Searching for the related papers of <{title}
|
331 |
fileds = ["title","abstract","citations.title","citations.abstract","citations.citationCount","references.title","references.abstract","references.citationCount","citations.isOpenAccess","citations.openAccessPdf","references.isOpenAccess","references.openAccessPdf","citations.year","references.year"]
|
332 |
-
results =
|
333 |
related_papers = []
|
334 |
related_papers_title = []
|
335 |
if not results or "data" not in results:
|
@@ -367,8 +303,7 @@ Abstract: {paper['abstract']}
|
|
367 |
if rerank_query and llm:
|
368 |
rerank_query_embbeding = llm.get_embbeding(rerank_query)
|
369 |
rerank_query_embbeding = np.array(rerank_query_embbeding)
|
370 |
-
related_papers =
|
371 |
-
related_papers = [paper[0] for paper in related_papers]
|
372 |
related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
|
373 |
else:
|
374 |
related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
|
@@ -385,13 +320,6 @@ Abstract: {paper['abstract']}
|
|
385 |
return result
|
386 |
return None
|
387 |
|
388 |
-
|
389 |
-
async def download_pdf_async(self, pdf_link):
|
390 |
-
content = await fetch(pdf_link)
|
391 |
-
if not content:
|
392 |
-
return None
|
393 |
-
else:
|
394 |
-
return content
|
395 |
|
396 |
def download_pdf(self, pdf_link):
|
397 |
content = download(pdf_link)
|
|
|
7 |
import aiohttp
|
8 |
import asyncio
|
9 |
import numpy as np
|
10 |
+
import random
|
11 |
|
12 |
def get_content_between_a_b(start_tag, end_tag, text):
|
13 |
extracted_text = ""
|
|
|
31 |
return text
|
32 |
else:
|
33 |
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def download(url):
|
36 |
try:
|
|
|
80 |
def __init__(self, ban_paper=None) -> None:
    """Create the searcher.

    Args:
        ban_paper: Optional list stored on the instance as ``ban_paper``.
            When omitted, a fresh empty list is created per instance; a
            literal ``[]`` default would be shared by every instance built
            without an argument, so a mutation through one would leak into
            the others.
    """
    # A caller-supplied list is stored as-is (same object), as before.
    self.ban_paper = [] if ban_paper is None else ban_paper
|
82 |
|
83 |
+
def search_papers(self, query, limit=5, offset=0, fields=["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citations.title","citations.abstract","citations.isOpenAccess","citations.openAccessPdf","citations.citationCount","citationCount","citations.year"],
|
84 |
publicationDate=None, minCitationCount=0, year=None,
|
85 |
publicationTypes=None, fieldsOfStudy=None):
|
86 |
url = 'https://api.semanticscholar.org/graph/v1/paper/search'
|
|
|
101 |
# Load the API key from the configuration file
|
102 |
api_key = os.environ.get('SEMENTIC_SEARCH_API_KEY',None)
|
103 |
headers = {'x-api-key': api_key} if api_key else None
|
|
|
104 |
try:
|
105 |
filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
|
106 |
response = requests.get(url, params=filtered_query_params, headers=headers)
|
|
|
111 |
elif response.status_code == 429:
|
112 |
time.sleep(1)
|
113 |
print(f"Request failed with status code {response.status_code}: begin to retry")
|
114 |
+
return self.search_papers(query, limit, offset, fields, publicationDate, minCitationCount, year, publicationTypes, fieldsOfStudy)
|
115 |
else:
|
116 |
print(f"Request failed with status code {response.status_code}: {response.text}")
|
117 |
return None
|
|
|
121 |
|
122 |
def cal_cosine_similarity(self, vec1, vec2):
    """Return the cosine similarity of two vectors.

    Args:
        vec1: First vector (list or ``numpy`` array).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (||vec1|| * ||vec2||)``. When either vector has
        zero norm the similarity is undefined; ``0.0`` is returned instead of
        NaN so that callers can still sort by the score.
    """
    vec1 = np.asarray(vec1)
    vec2 = np.asarray(vec2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        # Avoid 0/0 -> NaN (and the accompanying RuntimeWarning).
        return 0.0
    return np.dot(vec1, vec2) / norm_product
|
124 |
+
|
125 |
+
def cal_cosine_similarity_matric(self, matric1, matric2):
    """Return pairwise cosine similarities between the rows of two matrices.

    Args:
        matric1: A vector or a 2-D collection of row vectors (list, tuple or
            ``numpy`` array). A 1-D input is treated as a single row.
        matric2: Same as ``matric1``; rows must have the same length as the
            rows of ``matric1``.

    Returns:
        A flat Python list of floats in row-major order: the similarity of
        row ``i`` of ``matric1`` with row ``j`` of ``matric2`` is at index
        ``i * len(matric2) + j``. Pairs involving a zero-norm row score
        ``0.0`` rather than NaN/inf. An empty input yields ``[]``.
    """
    matric1 = np.asarray(matric1, dtype=float)
    matric2 = np.asarray(matric2, dtype=float)
    if matric1.size == 0 or matric2.size == 0:
        # Nothing to compare; avoids a shape error in the matrix product.
        return []
    if matric1.ndim == 1:
        matric1 = matric1.reshape(1, -1)
    if matric2.ndim == 1:
        matric2 = matric2.reshape(1, -1)

    dot_product = np.dot(matric1, matric2.T)
    norm_products = np.outer(
        np.linalg.norm(matric1, axis=1),
        np.linalg.norm(matric2, axis=1),
    )
    # Divide only where the denominator is non-zero; other cells stay 0.0.
    cos_sim = np.divide(
        dot_product,
        norm_products,
        out=np.zeros_like(dot_product),
        where=norm_products != 0,
    )
    return cos_sim.flatten().tolist()
|
141 |
|
142 |
def read_arxiv_from_path(self, pdf_path):
|
143 |
def is_pdf(binary_data):
|
|
|
156 |
return None
|
157 |
return article_dict
|
158 |
|
159 |
+
def get_paper_embbeding_and_score(self,query_embedding, paper,llm):
    """Embed one paper's title and abstract and score it against a query embedding.

    Args:
        query_embedding: Embedding of the query; passed unchanged to
            ``self.cal_cosine_similarity``.
        paper: Mapping read via ``paper['title']`` and ``paper['abstract']``.
        llm: Object exposing ``get_embbeding(text)``; its result is wrapped
            in ``np.array``.

    Returns:
        A two-element list ``[paper, score]``.
    """
    # NOTE(review): if paper['abstract'] is None the text "None" is what gets
    # embedded -- confirm whether that is intended.
    paper_content = f"""
Title: {paper['title']}
Abstract: {paper['abstract']}
"""
    paper_embbeding = llm.get_embbeding(paper_content)
    paper_embbeding = np.array(paper_embbeding)
    score = self.cal_cosine_similarity(query_embedding,paper_embbeding)
    return [paper,score]
|
168 |
|
169 |
|
170 |
+
def rerank_papers(self, query_embedding, paper_list, llm, max_candidates=50):
    """Order papers by embedding similarity to a query, most similar first.

    Args:
        query_embedding: Embedding of the rerank query.
        paper_list: Paper mappings read via ``paper['title']`` and
            ``paper['abstract']``. Falsy entries (``None``, ``{}``) are dropped.
        llm: Object exposing ``get_embbeding(list_of_texts)``; it is called
            once with all paper texts.
        max_candidates: When at least this many papers remain after
            filtering, a random sample of exactly this size is reranked
            instead of the full list (default 50, the previous fixed cap).

    Returns:
        The (possibly sampled) papers sorted by descending cosine similarity,
        or ``[]`` when no usable paper remains.
    """
    # Filter first, then test for emptiness: a list holding only falsy
    # entries must not reach the embedding call with an empty batch.
    candidates = [paper for paper in (paper_list or []) if paper]
    if len(candidates) >= max_candidates:
        candidates = random.sample(candidates, max_candidates)
    if not candidates:
        return []

    paper_contents = [
        f"\nTitle: {paper['title']}\nAbstract: {paper['abstract']}\n"
        for paper in candidates
    ]
    paper_contents_embbeding = np.array(llm.get_embbeding(paper_contents))
    scores = self.cal_cosine_similarity_matric(query_embedding, paper_contents_embbeding)

    # Sort the papers by score, highest first.
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return [paper for paper, _ in ranked]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
|
193 |
+
def search(self,query,max_results = 5 ,paper_list = None ,rerank_query = None,llm = None,year = None,publicationDate = None,need_download = True,fields = ["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citationCount"]):
|
194 |
if rerank_query:
|
195 |
rerank_query_embbeding = llm.get_embbeding(rerank_query)
|
196 |
rerank_query_embbeding = np.array(rerank_query_embbeding)
|
|
|
207 |
readed_papers = [paper.title for paper in paper_list]
|
208 |
|
209 |
print(f"Searching for papers related to the query: <{query}>")
|
210 |
+
results = self.search_papers(query,limit = 10 * max_results,year=year,publicationDate = publicationDate,fields = fields)
|
211 |
if not results or "data" not in results:
|
212 |
return []
|
213 |
|
|
|
230 |
paper_candidates = results
|
231 |
|
232 |
if llm and rerank_query:
|
233 |
+
paper_candidates = self.rerank_papers(rerank_query_embbeding, paper_candidates,llm)
|
|
|
234 |
|
235 |
if need_download:
|
236 |
for result in paper_candidates:
|
|
|
262 |
break
|
263 |
return final_results
|
264 |
|
265 |
+
def search_related_paper(self,title,need_citation = True,need_reference = True,rerank_query = None,llm = None,paper_list = []):
|
266 |
+
print(f"Searching for the related papers of <{title}>, need_citation: {need_citation}, need_reference: {need_reference}")
|
267 |
fileds = ["title","abstract","citations.title","citations.abstract","citations.citationCount","references.title","references.abstract","references.citationCount","citations.isOpenAccess","citations.openAccessPdf","references.isOpenAccess","references.openAccessPdf","citations.year","references.year"]
|
268 |
+
results = self.search_papers(title,limit = 3,fields=fileds)
|
269 |
related_papers = []
|
270 |
related_papers_title = []
|
271 |
if not results or "data" not in results:
|
|
|
303 |
if rerank_query and llm:
|
304 |
rerank_query_embbeding = llm.get_embbeding(rerank_query)
|
305 |
rerank_query_embbeding = np.array(rerank_query_embbeding)
|
306 |
+
related_papers = self.rerank_papers(rerank_query_embbeding, related_papers,llm)
|
|
|
307 |
related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
|
308 |
else:
|
309 |
related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
|
|
|
320 |
return result
|
321 |
return None
|
322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
|
324 |
def download_pdf(self, pdf_link):
|
325 |
content = download(pdf_link)
|