gamingflexer commited on
Commit
ec6a480
1 Parent(s): d91a4d0

Add arXiv scrapper module

Browse files
Files changed (1) hide show
  1. src/scrapper/arxiv.py +66 -0
src/scrapper/arxiv.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from requests.adapters import HTTPAdapter, Retry
3
+ import logging
4
+ from typing import Union, Any, Optional
5
+ import re
6
+
7
+ """
8
+ Usage : get_paper_id("8-bit matrix multiplication for transformers at scale") -> 2106.09680
9
+ """
10
+
11
+ paper_id_re = re.compile(r'https://arxiv.org/abs/(\d+\.\d+)')
12
+
13
+ def retry_request_session(retries: Optional[int] = 5):
14
+ # we setup retry strategy to retry on common errors
15
+ retries = Retry(
16
+ total=retries,
17
+ backoff_factor=0.1,
18
+ status_forcelist=[
19
+ 408, # request timeout
20
+ 500, # internal server error
21
+ 502, # bad gateway
22
+ 503, # service unavailable
23
+ 504 # gateway timeout
24
+ ]
25
+ )
26
+ # we setup a session with the retry strategy
27
+ session = requests.Session()
28
+ session.mount('https://', HTTPAdapter(max_retries=retries))
29
+ return session
30
+
31
+ def get_paper_id(query: str, handle_not_found: bool = True):
32
+ """Get the paper ID from a query.
33
+
34
+ :param query: The query to search with
35
+ :type query: str
36
+ :param handle_not_found: Whether to return None if no paper is found,
37
+ defaults to True
38
+ :type handle_not_found: bool, optional
39
+ :return: The paper ID
40
+ :rtype: str
41
+ """
42
+ special_chars = {
43
+ ":": "%3A",
44
+ "|": "%7C",
45
+ ",": "%2C",
46
+ " ": "+"
47
+ }
48
+ # create a translation table from the special_chars dictionary
49
+ translation_table = query.maketrans(special_chars)
50
+ # use the translate method to replace the special characters
51
+ search_term = query.translate(translation_table)
52
+ # init requests search session
53
+ session = retry_request_session()
54
+ # get the search results
55
+ res = session.get(f"https://www.google.com/search?q={search_term}&sclient=gws-wiz-serp")
56
+ try:
57
+ # extract the paper id
58
+ paper_id = paper_id_re.findall(res.text)[0]
59
+ except IndexError:
60
+ if handle_not_found:
61
+ # if no paper is found, return None
62
+ return None
63
+ else:
64
+ # if no paper is found, raise an error
65
+ raise Exception(f'No paper found for query: {query}')
66
+ return paper_id