shaocongma committed
Commit 70e35a5
Parent: 677c576

Re-format references. Remove ArXiv API Search.

Files changed (1)
  1. utils/references.py +121 -124
utils/references.py CHANGED
@@ -1,14 +1,23 @@
 # Each `paper` is a dictionary containing:
-# (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
+# (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal (8) embeddings
 #
 # Generate references:
 # `Reference` class:
 # 1. Read a given .bib file to collect papers; use `search_paper_abstract` method to fill missing abstract.
-# 2. Given some keywords; use ArXiv or Semantic Scholar API to find papers.
+# 2. Given some keywords; use Semantic Scholar API to find papers.
 # 3. Generate bibtex from the selected papers. --> to_bibtex()
 # 4. Generate prompts from the selected papers: --> to_prompts()
 #    A sample prompt: {"paper_id": "paper summary"}

+# todo: (1) citations & citedby of provided papers:
+#           load the pre-defined papers; use S2 to find all related works
+#           add all citations to `bib_papers`
+#           add all citedby to `bib_papers`
+#           use Semantic Scholar to find their embeddings
+#       (2) separate references:
+#           divide references into different groups to reduce the tokens count
+#           for generating different paragraph of related works, use different set of references
+
 import requests
 import re
 import bibtexparser
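Note: for orientation, a single `paper` entry with the fields listed in the header comment (including the new `embeddings` field) would look roughly like the sketch below. The values are made up for illustration and are not part of this commit.

example_paper = {
    "paper_id": "smith2020learning",    # key built from last name, year, and title via extract_paper_id
    "title": "Learning to Learn",
    "authors": "Jane Smith and Wei Chen",
    "year": "2020",
    "link": "https://arxiv.org/abs/2001.00001",
    "abstract": "One-line abstract with newlines removed ...",
    "journal": "arXiv preprint",
    "embeddings": [0.12, -0.03, 0.88],  # Semantic Scholar paper embedding vector (truncated here)
}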
@@ -30,12 +39,15 @@ def remove_newlines(serie):
 
 def search_paper_abstract(title):
     pg = ProxyGenerator()
-    success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
-    scholarly.use_proxy(pg)
-    # input the title of a paper, return its abstract
-    search_query = scholarly.search_pubs(title)
-    paper = next(search_query)
-    return remove_newlines(paper['bib']['abstract'])
+    success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")  # todo: change this to env. var. for protection.
+    if success:
+        scholarly.use_proxy(pg)
+        # input the title of a paper, return its abstract
+        search_query = scholarly.search_pubs(title)
+        found_paper = next(search_query)
+    else:
+        raise RuntimeError("ScraperAPI fails.")
+    return remove_newlines(found_paper['bib']['abstract'])
 
 
 def load_papers_from_bibtex(bib_file_path):
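Note: the new in-line todo asks for the hard-coded ScraperAPI key to be moved to an environment variable. A minimal sketch of that follow-up change is below; the variable name SCRAPER_API_KEY is an assumption, and `remove_newlines` is the module's existing helper.

import os

from scholarly import ProxyGenerator, scholarly


def search_paper_abstract(title):
    # Read the key from the environment instead of committing it to the repo.
    api_key = os.environ.get("SCRAPER_API_KEY")  # hypothetical env. var. name
    if not api_key:
        raise RuntimeError("SCRAPER_API_KEY is not set.")
    pg = ProxyGenerator()
    if not pg.ScraperAPI(api_key):
        raise RuntimeError("ScraperAPI fails.")
    scholarly.use_proxy(pg)
    found_paper = next(scholarly.search_pubs(title))
    return remove_newlines(found_paper['bib']['abstract'])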
@@ -46,6 +58,7 @@ def load_papers_from_bibtex(bib_file_path):
     else:
         bib_papers = []
         for bibitem in bib_database.entries:
+            # Add each paper to `bib_papers`
             paper_id = bibitem.get("ID")
             title = bibitem.get("title")
             if title is None:
@@ -68,7 +81,6 @@ def load_papers_from_bibtex(bib_file_path):
             bib_papers.append(result)
         return bib_papers
 
-
 ######################################################################################################################
 # Semantic Scholar (SS) API
 ######################################################################################################################
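Note: `ss_search` itself is not touched by this commit, so it does not appear in the diff. For context, a helper like it would typically call the Semantic Scholar Graph API and request the fields that `parse_search_results` reads (title, venue, year, authors, externalIds, abstract, tldr, embedding). The sketch below is an assumption about its shape, not the project's actual implementation.

import requests


def ss_search_sketch(keyword, limit=10):
    # Hypothetical stand-in for `ss_search`; endpoint and field names follow the
    # public Semantic Scholar Graph API paper search route.
    url = "https://api.semanticscholar.org/graph/v1/paper/search"
    fields = "title,venue,year,authors,externalIds,abstract,tldr,embedding"
    response = requests.get(url, params={"query": keyword, "limit": limit, "fields": fields})
    if response.status_code != 200:
        return None
    return response.json()  # {"total": ..., "offset": ..., "data": [...]}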
@@ -131,6 +143,9 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
         return authors_str, last_name
 
     def parse_search_results(search_results_ss):
+        if len(search_results_ss) == 0:
+            return []
+
         # turn the search result to a list of paper dictionary.
         papers_ss = []
         for raw_paper in search_results_ss:
@@ -140,16 +155,20 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
             authors_str, last_name = extract_author_info(raw_paper['authors'])
             year_str = str(raw_paper['year'])
             title = raw_paper['title']
+
             # some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
             journal = raw_paper['venue'].replace("&", "\\&")
             if not journal:
                 journal = "arXiv preprint"
+
             paper_id = extract_paper_id(last_name, year_str, title).lower()
             link = externalIds2link(raw_paper['externalIds'])
+
             if tldr and raw_paper['tldr'] is not None:
                 abstract = raw_paper['tldr']['text']
             else:
                 abstract = remove_newlines(raw_paper['abstract'])
+            embeddings = raw_paper['embedding']['vector']
             result = {
                 "paper_id": paper_id,
                 "title": title,
@@ -157,134 +176,65 @@
                 "link": link,
                 "authors": authors_str,
                 "year": year_str,
-                "journal": journal
+                "journal": journal,
+                "embeddings": embeddings
             }
             papers_ss.append(result)
         return papers_ss
 
     raw_results = ss_search(keyword, limit=counts)
     if raw_results is not None:
-        search_results = raw_results['data']
+        search_results = raw_results.get("data")
+        if search_results is None:
+            search_results = []
     else:
         search_results = []
     results = parse_search_results(search_results)
     return results
 
-
-######################################################################################################################
-# ArXiv API
-######################################################################################################################
-def _collect_papers_arxiv(keyword, counts=3, tldr=False):
-    # Build the arXiv API query URL with the given keyword and other parameters
-    def build_query_url(keyword, results_limit=3, sort_by="relevance", sort_order="descending"):
-        base_url = "http://export.arxiv.org/api/query?"
-        query = f"search_query=all:{keyword}&start=0&max_results={results_limit}"
-        query += f"&sortBy={sort_by}&sortOrder={sort_order}"
-        return base_url + query
-
-    # Fetch search results from the arXiv API using the constructed URL
-    def fetch_search_results(query_url):
-        response = requests.get(query_url)
-        return response.text
-
-    # Parse the XML content of the API response to extract paper information
-    def parse_results(content):
-        from xml.etree import ElementTree as ET
-
-        root = ET.fromstring(content)
-        namespace = "{http://www.w3.org/2005/Atom}"
-        entries = root.findall(f"{namespace}entry")
-
-        results = []
-        for entry in entries:
-            title = entry.find(f"{namespace}title").text
-            link = entry.find(f"{namespace}id").text
-            summary = entry.find(f"{namespace}summary").text
-            summary = remove_newlines(summary)
-
-            # Extract the authors
-            authors = entry.findall(f"{namespace}author")
-            author_list = []
-            for author in authors:
-                name = author.find(f"{namespace}name").text
-                author_list.append(name)
-            authors_str = " and ".join(author_list)
-
-            # Extract the year
-            published = entry.find(f"{namespace}published").text
-            year = published.split("-")[0]
-
-            founds = re.search(r'\d+\.\d+', link)
-            if founds is None:
-                # some links are not standard; such as "https://arxiv.org/abs/cs/0603127v1".
-                # will be solved in the future.
-                continue
-            else:
-                arxiv_id = founds.group(0)
-            journal = f"arXiv preprint arXiv:{arxiv_id}"
-            result = {
-                "paper_id": arxiv_id,
-                "title": title,
-                "link": link,
-                "abstract": summary,
-                "authors": authors_str,
-                "year": year,
-                "journal": journal
-            }
-            results.append(result)
-
-        return results
-
-    query_url = build_query_url(keyword, counts)
-    content = fetch_search_results(query_url)
-    results = parse_results(content)
-    return results
-
-
 ######################################################################################################################
 # References Class
 ######################################################################################################################
 
 class References:
-    def __init__(self, load_papers=""):
-        if load_papers:
-            # todo: (1) too large bibtex may make have issues on token limitations; may truncate to 5 or 10
-            #       (2) google scholar didn't give a full abstract for some papers ...
-            #       (3) may use langchain to support long input
-            self.papers = load_papers_from_bibtex(load_papers)
-        else:
-            self.papers = []
-
-    def collect_papers(self, keywords_dict, method="arxiv", tldr=False):
+    def __init__(self):
+        # if load_papers:
+        #     # todo: (1) too large bibtex may make have issues on token limitations; may truncate to 5 or 10
+        #     #       (2) google scholar didn't give a full abstract for some papers ...
+        #     #       (3) may use langchain to support long input
+        #     self.papers = load_papers_from_bibtex(load_papers)
+        # else:
+        self.papers = {}
+
+    def load_papers(self, bibtex, keyword):
+        self.papers[keyword] = load_papers_from_bibtex(bibtex)
+
+    def generate_keywords_dict(self):
+        keywords_dict = {}
+        for k in self.papers:
+            keywords_dict[k] = len(self.papers[k])
+        return keywords_dict
+
+    def collect_papers(self, keywords_dict, tldr=False):
         """
         keywords_dict:
             {"machine learning": 5, "language model": 2};
             the first is the keyword, the second is how many references are needed.
         """
-        match method:
-            case "arxiv":
-                process = _collect_papers_arxiv
-            case "ss":
-                process = _collect_papers_ss
-            case _:
-                raise NotImplementedError("Other sources have not been not supported yet.")
         for key, counts in keywords_dict.items():
-            self.papers = self.papers + process(key, counts, tldr)
+            self.papers[key] = _collect_papers_ss(key, counts, tldr)
+
+        # Remove duplicated references  # todo: remove duplicated references in tex_processing procedure.
 
-        seen = set()
-        papers = []
-        for paper in self.papers:
-            paper_id = paper["paper_id"]
-            if paper_id not in seen:
-                seen.add(paper_id)
-                papers.append(paper)
-        self.papers = papers
+    def find_relevant(self, max_refs=30):
+        # todo: use embeddings to evaluate
+        pass
 
     def to_bibtex(self, path_to_bibtex="ref.bib"):
        """
        Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
        """
-        papers = self.papers
+        papers = self._get_papers(keyword = "_all")
 
         # clear the bibtex file
         with open(path_to_bibtex, "w", encoding="utf-8") as file:
@@ -308,31 +258,78 @@ class References:
                 file.write("\n\n")
         return paper_ids
 
-    def to_prompts(self):
+    def _get_papers(self, keyword = "_all"):
+        if keyword == "_all":
+            papers = []
+            for k, v in self.papers.items():
+                papers = papers + v
+        else:
+            papers = self.papers["keyword"]
+        return papers
+
+    def to_prompts(self, keyword = "_all"):
         # `prompts`:
         #   {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
         # this will be used to instruct GPT model to cite the correct bibtex entry.
+        papers = self._get_papers(keyword)
         prompts = {}
-        for paper in self.papers:
+        for paper in papers:
            prompts[paper["paper_id"]] = paper["abstract"]
        return prompts
 
+    def to_json(self, keyword = "_all"):
+        papers = self._get_papers(keyword)
+        papers_json = {}
+        for paper in papers:
+            papers_json[paper["paper_id"]] = paper
+        return papers_json
+
+
 
 if __name__ == "__main__":
-    # refs = References()
-    # keywords_dict = {
-    #     "Deep Q-Networks": 15,
-    #     "Policy Gradient Methods": 24,
-    #     "Actor-Critic Algorithms": 4,
-    #     "Model-Based Reinforcement Learning": 13,
-    #     "Exploration-Exploitation Trade-off": 7
-    # }
-    # refs.collect_papers(keywords_dict, method="ss", tldr=True)
+    # r = ss_search("Deep Q-Networks")['data']
+    # print(r)
+    # papers_json = {}
+    # # for i in range(len(r)):
+    # #     r[i]
+    # #
+    # # with open("Output.txt", "w") as text_file:
+    # #     text_file.write("Purchase Amount: %s" % TotalAmount)
+    # embeddings = r[0]['embedding']['vector']
+    # print(embeddings)
+
+    refs = References()
+    keywords_dict = {
+        "Deep Q-Networks": 5,
+        "Actor-Critic Algorithms": 4,
+        "Exploration-Exploitation Trade-off": 3
+    }
+    refs.collect_papers(keywords_dict, method="ss", tldr=True)
+    for k in refs.papers:
+        papers = refs.papers[k]
+        print("keyword: ", k)
+        for paper in papers:
+            print(paper["paper_id"])
+
+    refs.to_json()
+    refs.to_bibtex()
+    refs.to_prompts()
+    # print(refs.papers)
+
+    # todo: test load_papers
+    #       write test covering `references.py`. / fix this as a stable version
+
     # for p in refs.papers:
     #     print(p["paper_id"])
     # print(len(refs.papers))
+    #
+    # papers_json = refs.to_json()
+    # # print(papers_json)
+    # with open("papers.json", "w", encoding='utf-8') as text_file:
+    #     text_file.write(f"{papers_json}")
+
 
-    bib = "D:\\Projects\\auto-draft\\latex_templates\\pre_refs.bib"
-    papers = load_papers_from_bibtex(bib)
-    for paper in papers:
-        print(paper)
+    # bib = "D:\\Projects\\auto-draft\\latex_templates\\pre_refs.bib"
+    # papers = load_papers_from_bibtex(bib)
+    # for paper in papers:
+    #     print(paper)
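Note: `find_relevant` is committed here only as a stub ("todo: use embeddings to evaluate"), and each Semantic Scholar result now carries an `embeddings` vector. One possible way to fill that stub later, sketched purely as an illustration: rank papers by cosine similarity against a query embedding obtained elsewhere (e.g. from a seed paper). This is not the project's implementation.

import numpy as np


def rank_by_embedding(papers, query_vector, max_refs=30):
    # Score each paper by cosine similarity between its stored embedding and the
    # query embedding, then keep the top `max_refs`. Papers without embeddings
    # (e.g. those loaded from a .bib file) are skipped.
    def cosine(a, b):
        a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

    scored = [(cosine(p["embeddings"], query_vector), p) for p in papers if p.get("embeddings")]
    scored.sort(key=lambda item: item[0], reverse=True)
    return [paper for _, paper in scored[:max_refs]]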
 
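Note: a minimal usage sketch of the refactored `References` class, mirroring the new `__main__` block. The committed `__main__` still passes `method="ss"`, but the new `collect_papers` signature no longer takes a `method` argument, so the call below drops it.

refs = References()
refs.collect_papers({"Deep Q-Networks": 5, "Actor-Critic Algorithms": 4}, tldr=True)

# `refs.papers` is now a dict keyed by keyword, holding one list of paper dicts per keyword.
for keyword, papers in refs.papers.items():
    print(keyword, [p["paper_id"] for p in papers])

paper_ids = refs.to_bibtex("ref.bib")   # writes ref.bib and returns the bibtex keys
prompts = refs.to_prompts()             # {"paper_id": "abstract", ...} across all keywords
papers_json = refs.to_json()            # {"paper_id": {full paper dict}, ...}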