Spaces:
Runtime error
Runtime error
gamingflexer
committed on
Commit
•
ec6a480
1
Parent(s):
d91a4d0
Add arXiv scrapper module
Browse files- src/scrapper/arxiv.py +66 -0
src/scrapper/arxiv.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from requests.adapters import HTTPAdapter, Retry
|
3 |
+
import logging
|
4 |
+
from typing import Union, Any, Optional
|
5 |
+
import re
|
6 |
+
|
7 |
+
"""
|
8 |
+
Usage : get_paper_id("8-bit matrix multiplication for transformers at scale") -> 2106.09680
|
9 |
+
"""
|
10 |
+
|
11 |
+
# Captures the numeric arXiv identifier (the "<digits>.<digits>" part) from an
# abstract URL. The dots in the host name are escaped so that they only match
# a literal "." instead of any character.
paper_id_re = re.compile(r'https://arxiv\.org/abs/(\d+\.\d+)')
|
12 |
+
|
13 |
+
def retry_request_session(retries: Optional[int] = 5):
    """Build a ``requests`` session that retries on common transient errors.

    :param retries: Total number of retries handed to the retry strategy,
        defaults to 5
    :type retries: Optional[int], optional
    :return: A session whose ``https://`` requests go through a retrying
        adapter
    :rtype: requests.Session
    """
    # response status codes that trigger another attempt
    retryable_statuses = [
        408,  # request timeout
        500,  # internal server error
        502,  # bad gateway
        503,  # service unavailable
        504,  # gateway timeout
    ]
    retry_strategy = Retry(
        total=retries,
        backoff_factor=0.1,
        status_forcelist=retryable_statuses,
    )
    # only https traffic is routed through the retrying adapter
    adapter = HTTPAdapter(max_retries=retry_strategy)
    http_session = requests.Session()
    http_session.mount('https://', adapter)
    return http_session
|
30 |
+
|
31 |
+
def get_paper_id(query: str, handle_not_found: bool = True) -> Optional[str]:
    """Get the arXiv paper ID for a query by searching Google.

    The query is sent to Google search and the first arXiv abstract link
    found in the returned page is used to extract the paper ID.

    :param query: The query to search with
    :type query: str
    :param handle_not_found: Whether to return None if no paper is found
        (an exception is raised otherwise), defaults to True
    :type handle_not_found: bool, optional
    :raises Exception: If no paper is found and ``handle_not_found`` is False
    :return: The paper ID, or None if no paper is found and
        ``handle_not_found`` is True
    :rtype: Optional[str]
    """
    # the context manager closes the session so pooled connections are released
    with retry_request_session() as session:
        # let requests URL-encode the query: a hand-written table only covers
        # a few characters, and e.g. "&", "#" or "+" would corrupt the URL
        res = session.get(
            "https://www.google.com/search",
            params={"q": query, "sclient": "gws-wiz-serp"},
            # without a timeout a stalled connection would block forever
            timeout=10,
        )
        # the first arXiv abstract link in the result page wins
        match = paper_id_re.search(res.text)

    if match is not None:
        return match.group(1)
    if handle_not_found:
        # if no paper is found, return None
        return None
    # if no paper is found, raise an error
    raise Exception(f'No paper found for query: {query}')
|