File size: 3,935 Bytes
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e08d6ea
 
 
 
 
 
 
 
 
 
 
7a143c1
7e4123a
 
 
 
 
 
e08d6ea
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e08d6ea
7e4123a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import datasets
import pandas as pd
from huggingface_hub import HfApi

from utils import push_to_hf_hub
from paper.download import download_pdf_from_arxiv
from paper.download import get_papers_from_arxiv_ids
from paper.parser import extract_text_and_figures
from gen.gemini import get_basic_qa, get_deep_qa

def _filter_function(example, ids):
    ids_e = example['Requested arXiv IDs']
    for iid in ids:
        if iid in ids_e:
            ids_e.remove(iid)
            example['Requested arXiv IDs'] = ids_e

    print(example)
    return example

def _extract_arxiv_id(text):
    print(text)
    start = text.find("[") + 1
    end = text.find("]")

    # Extract the text between brackets
    if start != -1 and end != -1:
        return text[start:end]
    else:
        return text

def process_arxiv_ids(gemini_api, hf_repo_id, req_hf_repo_id, hf_token, restart_repo_id, how_many=1):
    arxiv_ids = []

    ds1 = datasets.load_dataset(req_hf_repo_id)
    for d in ds1['train']:
        req_arxiv_ids = d['Requested arXiv IDs']
        if len(req_arxiv_ids) > 0 and req_arxiv_ids[0] != "top":
            arxiv_ids = arxiv_ids + [_extract_arxiv_id(req_arxiv_ids[0])]

    arxiv_ids = arxiv_ids[:how_many]

    if arxiv_ids is not None and len(arxiv_ids) > 0:
        print(f"1. Get metadata for the papers [{arxiv_ids}]")
        papers = get_papers_from_arxiv_ids(arxiv_ids)
        print("...DONE")
        
        print("2. Generating QAs for the paper")
        for paper in papers:
            try:
                title = paper['title']
                target_date = paper['target_date']
                abstract = paper['paper']['summary']
                arxiv_id = paper['paper']['id']
                authors = paper['paper']['authors']

                print(f"...PROCESSING ON[{arxiv_id}, {title}]")
                print(f"......Downloading the paper PDF")
                filename = download_pdf_from_arxiv(arxiv_id)
                print(f"......DONE")

                print(f"......Extracting text and figures")
                texts, figures = extract_text_and_figures(filename)
                text =' '.join(texts)
                print(f"......DONE")

                print(f"......Generating the seed(basic) QAs")
                qnas = get_basic_qa(text, gemini_api_key=gemini_api, trucate=30000)
                qnas['title'] = title
                qnas['abstract'] = abstract
                qnas['authors'] = ','.join(authors)
                qnas['arxiv_id'] = arxiv_id
                qnas['target_date'] = target_date
                qnas['full_text'] = text
                print(f"......DONE")

                print(f"......Generating the follow-up QAs")
                qnas = get_deep_qa(text, qnas, gemini_api_key=gemini_api, trucate=30000)
                del qnas["qna"]
                print(f"......DONE")

                print(f"......Exporting to HF Dataset repo at [{hf_repo_id}]")
                df = pd.DataFrame([qnas])
                ds = datasets.Dataset.from_pandas(df)
                ds = ds.cast_column("target_date", datasets.features.Value("timestamp[s]"))                
                push_to_hf_hub(ds, hf_repo_id, hf_token)
                print(f"......DONE")

                print(f"......Updating request arXiv HF Dataset repo at [{req_hf_repo_id}]")
                ds1 = ds1['train'].map(
                    lambda example: _filter_function(example, [f"[{arxiv_id}] {title}"])
                ).filter(
                    lambda example: len(example['Requested arXiv IDs']) > 0
                )
                ds1.push_to_hub(req_hf_repo_id, token=hf_token)
                            
                print(f"......DONE")
            except Exception as e:
                print(f".......failed due to exception {e}")
                continue

        HfApi(token=hf_token).restart_space(
            repo_id=restart_repo_id, token=hf_token
        )