shaocongma committed
Commit 70e35a5 · 1 parent: 677c576
Re-format references. Remove ArXiv API Search.

Files changed: utils/references.py (+121 -124)

utils/references.py (CHANGED)
@@ -1,14 +1,23 @@
 # Each `paper` is a dictionary containing:
-#       (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
+#       (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal (8) embeddings
 #
 # Generate references:
 #   `Reference` class:
 #       1. Read a given .bib file to collect papers; use `search_paper_abstract` method to fill missing abstract.
-#       2. Given some keywords; use
+#       2. Given some keywords; use Semantic Scholar API to find papers.
 #       3. Generate bibtex from the selected papers. --> to_bibtex()
 #       4. Generate prompts from the selected papers: --> to_prompts()
 #          A sample prompt: {"paper_id": "paper summary"}
 
+# todo: (1) citations & citedby of provided papers:
+#           load the pre-defined papers; use S2 to find all related works
+#           add all citations to `bib_papers`
+#           add all citedby to `bib_papers`
+#           use Semantic Scholar to find their embeddings
+#       (2) separate references:
+#           divide references into different groups to reduce the tokens count
+#           for generating different paragraph of related works, use different set of references
+
 import requests
 import re
 import bibtexparser
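Note: the header comment above fixes the shape of each `paper` record that the rest of the file passes around. A minimal illustration of one such dictionary, with placeholder values invented for this note (none of them come from the repository):

```python
# Hypothetical `paper` record matching the fields listed in the header comment.
# All values are placeholders for illustration only.
example_paper = {
    "paper_id": "smith2020learning",
    "title": "A Hypothetical Paper on Reinforcement Learning",
    "authors": "Alice Smith and Bob Jones",
    "year": "2020",
    "link": "https://www.semanticscholar.org/paper/xxxx",
    "abstract": "A one-paragraph summary of the paper.",
    "journal": "arXiv preprint",
    "embeddings": [0.12, -0.03, 0.47],  # real embedding vectors are much longer
}

# The sample prompt format mentioned above: {"paper_id": "paper summary"}
sample_prompt = {example_paper["paper_id"]: example_paper["abstract"]}
```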
@@ -30,12 +39,15 @@ def remove_newlines(serie):
 
 def search_paper_abstract(title):
     pg = ProxyGenerator()
-    success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
-
-
-
-
-
+    success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")  # todo: change this to env. var. for protection.
+    if success:
+        scholarly.use_proxy(pg)
+        # input the title of a paper, return its abstract
+        search_query = scholarly.search_pubs(title)
+        found_paper = next(search_query)
+    else:
+        raise RuntimeError("ScraperAPI fails.")
+    return remove_newlines(found_paper['bib']['abstract'])
 
 
 def load_papers_from_bibtex(bib_file_path):
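Note: the hard-coded ScraperAPI key above carries a todo to move it into an environment variable. A minimal sketch of that change, assuming a variable name of SCRAPER_API_KEY (the name and this wrapper are not part of the repository):

```python
# Sketch of the "change this to env. var." todo; SCRAPER_API_KEY is an assumed name.
import os

from scholarly import ProxyGenerator, scholarly


def search_paper_abstract_env(title):
    api_key = os.environ.get("SCRAPER_API_KEY")
    if not api_key:
        raise RuntimeError("SCRAPER_API_KEY is not set.")
    pg = ProxyGenerator()
    if not pg.ScraperAPI(api_key):
        raise RuntimeError("ScraperAPI fails.")
    scholarly.use_proxy(pg)
    found_paper = next(scholarly.search_pubs(title))
    return found_paper["bib"]["abstract"].replace("\n", " ")
```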
@@ -46,6 +58,7 @@ def load_papers_from_bibtex(bib_file_path):
     else:
         bib_papers = []
         for bibitem in bib_database.entries:
+            # Add each paper to `bib_papers`
             paper_id = bibitem.get("ID")
             title = bibitem.get("title")
             if title is None:
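Note: `load_papers_from_bibtex` iterates over `bib_database.entries`; the part of the function that builds `bib_database` sits outside this hunk. A sketch of how such a database is typically produced with bibtexparser's v1 API (the file name is a placeholder):

```python
# How `bib_database.entries` is usually obtained with bibtexparser (v1 API).
# "refs_example.bib" is a placeholder path, not a file in this repository.
import bibtexparser

with open("refs_example.bib", "r", encoding="utf-8") as f:
    bib_database = bibtexparser.load(f)

for bibitem in bib_database.entries:
    # Each entry is a dict with "ID", "ENTRYTYPE", and lower-cased bibtex fields.
    print(bibitem.get("ID"), bibitem.get("title"))
```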
@@ -68,7 +81,6 @@ def load_papers_from_bibtex(bib_file_path):
             bib_papers.append(result)
     return bib_papers
 
-
 ######################################################################################################################
 # Semantic Scholar (SS) API
 ######################################################################################################################
@@ -131,6 +143,9 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
         return authors_str, last_name
 
     def parse_search_results(search_results_ss):
+        if len(search_results_ss) == 0:
+            return []
+
         # turn the search result to a list of paper dictionary.
         papers_ss = []
         for raw_paper in search_results_ss:
@@ -140,16 +155,20 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
             authors_str, last_name = extract_author_info(raw_paper['authors'])
             year_str = str(raw_paper['year'])
             title = raw_paper['title']
+
             # some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
             journal = raw_paper['venue'].replace("&", "\\&")
             if not journal:
                 journal = "arXiv preprint"
+
             paper_id = extract_paper_id(last_name, year_str, title).lower()
             link = externalIds2link(raw_paper['externalIds'])
+
             if tldr and raw_paper['tldr'] is not None:
                 abstract = raw_paper['tldr']['text']
             else:
                 abstract = remove_newlines(raw_paper['abstract'])
+            embeddings = raw_paper['embedding']['vector']
             result = {
                 "paper_id": paper_id,
                 "title": title,
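Note: the parser above now reads `raw_paper['embedding']['vector']` alongside title, venue, year, authors, externalIds, abstract, and tldr. `ss_search` itself is not shown in this diff; the sketch below is only an assumption about the kind of Semantic Scholar Graph API request that would return those fields, not the repository's implementation:

```python
# Assumed shape of a Semantic Scholar search returning the fields parsed above;
# this is not the repository's `ss_search` implementation.
import requests


def ss_search_sketch(keyword, limit=3):
    url = "https://api.semanticscholar.org/graph/v1/paper/search"
    fields = "title,venue,year,authors,externalIds,abstract,tldr,embedding"
    resp = requests.get(url, params={"query": keyword, "limit": limit, "fields": fields})
    resp.raise_for_status()
    return resp.json()  # e.g. {"total": ..., "data": [{...paper fields...}, ...]}
```

The `raw_results.get("data")` call added later in this diff matches that `{"data": [...]}` response shape.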
@@ -157,134 +176,65 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
                 "link": link,
                 "authors": authors_str,
                 "year": year_str,
-                "journal": journal
+                "journal": journal,
+                "embeddings": embeddings
             }
             papers_ss.append(result)
         return papers_ss
 
     raw_results = ss_search(keyword, limit=counts)
     if raw_results is not None:
-        search_results = raw_results
+        search_results = raw_results.get("data")
+        if search_results is None:
+            search_results = []
     else:
         search_results = []
     results = parse_search_results(search_results)
     return results
 
-
-######################################################################################################################
-# ArXiv API
-######################################################################################################################
-def _collect_papers_arxiv(keyword, counts=3, tldr=False):
-    # Build the arXiv API query URL with the given keyword and other parameters
-    def build_query_url(keyword, results_limit=3, sort_by="relevance", sort_order="descending"):
-        base_url = "http://export.arxiv.org/api/query?"
-        query = f"search_query=all:{keyword}&start=0&max_results={results_limit}"
-        query += f"&sortBy={sort_by}&sortOrder={sort_order}"
-        return base_url + query
-
-    # Fetch search results from the arXiv API using the constructed URL
-    def fetch_search_results(query_url):
-        response = requests.get(query_url)
-        return response.text
-
-    # Parse the XML content of the API response to extract paper information
-    def parse_results(content):
-        from xml.etree import ElementTree as ET
-
-        root = ET.fromstring(content)
-        namespace = "{http://www.w3.org/2005/Atom}"
-        entries = root.findall(f"{namespace}entry")
-
-        results = []
-        for entry in entries:
-            title = entry.find(f"{namespace}title").text
-            link = entry.find(f"{namespace}id").text
-            summary = entry.find(f"{namespace}summary").text
-            summary = remove_newlines(summary)
-
-            # Extract the authors
-            authors = entry.findall(f"{namespace}author")
-            author_list = []
-            for author in authors:
-                name = author.find(f"{namespace}name").text
-                author_list.append(name)
-            authors_str = " and ".join(author_list)
-
-            # Extract the year
-            published = entry.find(f"{namespace}published").text
-            year = published.split("-")[0]
-
-            founds = re.search(r'\d+\.\d+', link)
-            if founds is None:
-                # some links are not standard; such as "https://arxiv.org/abs/cs/0603127v1".
-                # will be solved in the future.
-                continue
-            else:
-                arxiv_id = founds.group(0)
-            journal = f"arXiv preprint arXiv:{arxiv_id}"
-            result = {
-                "paper_id": arxiv_id,
-                "title": title,
-                "link": link,
-                "abstract": summary,
-                "authors": authors_str,
-                "year": year,
-                "journal": journal
-            }
-            results.append(result)
-
-        return results
-
-    query_url = build_query_url(keyword, counts)
-    content = fetch_search_results(query_url)
-    results = parse_results(content)
-    return results
-
-
 ######################################################################################################################
 # References Class
 ######################################################################################################################
 
 class References:
-    def __init__(self
-        if load_papers:
-
-
-
-
-        else:
-
-
-    def
+    def __init__(self):
+        # if load_papers:
+        #     # todo: (1) too large bibtex may make have issues on token limitations; may truncate to 5 or 10
+        #     #       (2) google scholar didn't give a full abstract for some papers ...
+        #     #       (3) may use langchain to support long input
+        #     self.papers = load_papers_from_bibtex(load_papers)
+        # else:
+        self.papers = {}
+
+    def load_papers(self, bibtex, keyword):
+        self.papers[keyword] = load_papers_from_bibtex(bibtex)
+
+    def generate_keywords_dict(self):
+        keywords_dict = {}
+        for k in self.papers:
+            keywords_dict[k] = len(self.papers[k])
+        return keywords_dict
+
+    def collect_papers(self, keywords_dict, tldr=False):
         """
         keywords_dict:
             {"machine learning": 5, "language model": 2};
            the first is the keyword, the second is how many references are needed.
         """
-        match method:
-            case "arxiv":
-                process = _collect_papers_arxiv
-            case "ss":
-                process = _collect_papers_ss
-            case _:
-                raise NotImplementedError("Other sources have not been not supported yet.")
         for key, counts in keywords_dict.items():
-            self.papers =
+            self.papers[key] = _collect_papers_ss(key, counts, tldr)
+
+        # Remove duplicated references  # todo: remove duplicated references in tex_processing procedure.
 
-
-
-
-            paper_id = paper["paper_id"]
-            if paper_id not in seen:
-                seen.add(paper_id)
-                papers.append(paper)
-        self.papers = papers
+    def find_relevant(self, max_refs=30):
+        # todo: use embeddings to evaluate
+        pass
 
     def to_bibtex(self, path_to_bibtex="ref.bib"):
         """
         Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
         """
-        papers = self.
+        papers = self._get_papers(keyword = "_all")
 
         # clear the bibtex file
         with open(path_to_bibtex, "w", encoding="utf-8") as file:
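Note: `find_relevant` is committed as a stub with a todo to "use embeddings to evaluate". One possible direction, sketched here and not part of the commit, is to rank the stored papers by cosine similarity against a query embedding:

```python
# Sketch only: rank papers by cosine similarity of their stored "embeddings" vectors.
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def find_relevant_sketch(papers_by_keyword, query_vector, max_refs=30):
    ranked = []
    for papers in papers_by_keyword.values():
        for paper in papers:
            ranked.append((cosine_similarity(paper["embeddings"], query_vector), paper))
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return [paper for _, paper in ranked[:max_refs]]
```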
@@ -308,31 +258,78 @@ class References:
             file.write("\n\n")
         return paper_ids
 
-    def
+    def _get_papers(self, keyword = "_all"):
+        if keyword == "_all":
+            papers = []
+            for k, v in self.papers.items():
+                papers = papers + v
+        else:
+            papers = self.papers["keyword"]
+        return papers
+
+    def to_prompts(self, keyword = "_all"):
         # `prompts`:
         #   {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
         #   this will be used to instruct GPT model to cite the correct bibtex entry.
+        papers = self._get_papers(keyword)
         prompts = {}
-        for paper in
+        for paper in papers:
             prompts[paper["paper_id"]] = paper["abstract"]
         return prompts
 
+    def to_json(self, keyword = "_all"):
+        papers = self._get_papers(keyword)
+        papers_json = {}
+        for paper in papers:
+            papers_json[paper["paper_id"]] = paper
+        return papers_json
+
+
 
 if __name__ == "__main__":
-    #
-    #
-    #
-    #
-    #
-    #
-    #
-    #
-    #
+    # r = ss_search("Deep Q-Networks")['data']
+    # print(r)
+    # papers_json = {}
+    # # for i in range(len(r)):
+    # #     r[i]
+    # #
+    # # with open("Output.txt", "w") as text_file:
+    # #     text_file.write("Purchase Amount: %s" % TotalAmount)
+    # embeddings = r[0]['embedding']['vector']
+    # print(embeddings)
+
+    refs = References()
+    keywords_dict = {
+        "Deep Q-Networks": 5,
+        "Actor-Critic Algorithms": 4,
+        "Exploration-Exploitation Trade-off": 3
+    }
+    refs.collect_papers(keywords_dict, method="ss", tldr=True)
+    for k in refs.papers:
+        papers = refs.papers[k]
+        print("keyword: ", k)
+        for paper in papers:
+            print(paper["paper_id"])
+
+    refs.to_json()
+    refs.to_bibtex()
+    refs.to_prompts()
+    # print(refs.papers)
+
+    # todo: test load_papers
+    #       write test covering `references.py`. / fix this as a stable version
+
     # for p in refs.papers:
     #     print(p["paper_id"])
     # print(len(refs.papers))
+    #
+    # papers_json = refs.to_json()
+    # # print(papers_json)
+    # with open("papers.json", "w", encoding='utf-8') as text_file:
+    #     text_file.write(f"{papers_json}")
+
 
-    bib = "D:\\Projects\\auto-draft\\latex_templates\\pre_refs.bib"
-    papers = load_papers_from_bibtex(bib)
-    for paper in papers:
-
+    # bib = "D:\\Projects\\auto-draft\\latex_templates\\pre_refs.bib"
+    # papers = load_papers_from_bibtex(bib)
+    # for paper in papers:
+    #     print(paper)
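Note: per the comment inside `to_prompts`, the returned `{paper_id: abstract}` mapping is meant to steer a GPT model toward citing the right bibtex entries. A hedged sketch of folding that mapping into an instruction string (the wording and sample data are illustrative, not taken from this repository):

```python
# Illustrative only: turn to_prompts() output into an instruction block for a language model.
prompts = {
    "smith2020learning": "A made-up summary of a reinforcement learning paper.",
    "doe2019actor": "A made-up summary of an actor-critic paper.",
}

references_block = "\n".join(f"[{pid}]: {summary}" for pid, summary in prompts.items())
instruction = (
    "Write a related-works paragraph. Cite sources with \\cite{<paper_id>}, "
    "using only the ids listed below.\n" + references_block
)
print(instruction)
```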
|