import os
import copy
import datasets
import pandas as pd
from datasets import Dataset
from collections import defaultdict

from datetime import datetime, timedelta
from background import process_arxiv_ids
from utils import create_hf_hub
from apscheduler.schedulers.background import BackgroundScheduler

def _count_nans(row):
    """Count how many fields in a record are None."""
    return sum(1 for v in row.values() if v is None)

def _initialize_requested_arxiv_ids(request_ds):
    """Flatten the per-row lists of requested arXiv IDs into a DataFrame."""
    requested_arxiv_ids = []

    for request_d in request_ds["train"]:
        requested_arxiv_ids.extend(request_d["Requested arXiv IDs"])

    return pd.DataFrame({"Requested arXiv IDs": requested_arxiv_ids})
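
# For reference, a hypothetical request record of the shape assumed above:
# each row's "Requested arXiv IDs" field holds a list, so for example
#   [{"Requested arXiv IDs": ["2401.12345", "2402.00001"]},
#    {"Requested arXiv IDs": ["top"]}]
# flattens into a single one-column DataFrame with three rows.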

def _initialize_paper_info(source_ds):
    """Index source papers by title, date, and arXiv ID."""
    title2qna, date2qna = {}, {}
    date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    arxivid2data = {}
    count = 0

    # Group papers by date, keeping the most complete record (fewest None
    # fields) when the same title appears more than once on a date.
    for data in source_ds["train"]:
        date = data["target_date"].strftime("%Y-%m-%d")

        if date in date2qna:
            # Iterate over a snapshot so the stored list can be mutated safely.
            papers = copy.deepcopy(date2qna[date])
            has_better_duplicate = False
            for paper in papers:
                if paper["title"] == data["title"]:
                    if _count_nans(paper) > _count_nans(data):
                        # The stored record is less complete; replace it.
                        date2qna[date].remove(paper)
                    else:
                        has_better_duplicate = True

            if not has_better_duplicate:
                date2qna[date].append(data)
            del papers
        else:
            date2qna[date] = [data]

    # Build the title, arXiv-ID, and year/month/day indices.
    for date in date2qna:
        year, month, day = date.split("-")
        for paper in date2qna[date]:
            title2qna[paper["title"]] = paper
            arxivid2data[paper["arxiv_id"]] = {"idx": count, "paper": paper}
            date_dict[year][month][day].append(paper)
            count += 1

    titles = [f"[{v['arxiv_id']}] {k}" for k, v in title2qna.items()]

    return titles, date_dict, arxivid2data
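
# For reference, a hypothetical source record of the shape the indexing above
# assumes (the field names come from the accesses in _initialize_paper_info;
# everything else is a guess):
#   {
#       "target_date": datetime.datetime(2024, 1, 15, 0, 0),
#       "arxiv_id": "2401.12345",
#       "title": "Some Paper Title",
#       ...  # QnA fields, any of which may be None
#   }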

def initialize_data(source_data_repo_id, request_data_repo_id):
    """Load both Hub datasets and build the in-memory lookup structures."""
    global date_dict, arxivid2data
    global requested_arxiv_ids_df

    source_ds = datasets.load_dataset(source_data_repo_id)
    request_ds = datasets.load_dataset(request_data_repo_id)

    titles, date_dict, arxivid2data = _initialize_paper_info(source_ds)
    requested_arxiv_ids_df = _initialize_requested_arxiv_ids(request_ds)

    return titles, date_dict, requested_arxiv_ids_df, arxivid2data

def update_dataframe(request_data_repo_id):
    """Reload the request dataset and rebuild the requested-IDs DataFrame."""
    request_ds = datasets.load_dataset(request_data_repo_id)
    return _initialize_requested_arxiv_ids(request_ds)

def initialize_repos(
    source_data_repo_id, request_data_repo_id, hf_token
):
    """Create the Hub dataset repos if missing; seed the request repo on first run."""
    if create_hf_hub(source_data_repo_id, hf_token) is False:
        print(f"{source_data_repo_id} repository already exists")

    if create_hf_hub(request_data_repo_id, hf_token) is False:
        print(f"{request_data_repo_id} repository already exists")
    else:
        # Seed the freshly created repo with a placeholder row.
        df = pd.DataFrame(data={"Requested arXiv IDs": [["top"]]})
        ds = Dataset.from_pandas(df)
        ds.push_to_hub(request_data_repo_id, token=hf_token)

def get_secrets():
    """Read API keys and repo IDs from environment variables."""
    global gemini_api_key
    global hf_token
    global request_arxiv_repo_id
    global dataset_repo_id

    gemini_api_key = os.getenv("GEMINI_API_KEY")
    hf_token = os.getenv("HF_TOKEN")
    dataset_repo_id = os.getenv("SOURCE_DATA_REPO_ID")
    request_arxiv_repo_id = os.getenv("REQUEST_DATA_REPO_ID")
    restart_repo_id = os.getenv("RESTART_TARGET_SPACE_REPO_ID", "chansung/paper_qa")

    return (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id,
    )
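
# A minimal sketch of how these pieces might be wired together on startup,
# assuming process_arxiv_ids is the periodic job that ingests newly requested
# papers. Its argument list below is an assumption; the real signature is
# defined in background.py.
if __name__ == "__main__":
    (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id,
    ) = get_secrets()

    initialize_repos(dataset_repo_id, request_arxiv_repo_id, hf_token)
    titles, date_dict, requested_arxiv_ids_df, arxivid2data = initialize_data(
        dataset_repo_id, request_arxiv_repo_id
    )

    # Poll for newly requested arXiv IDs once an hour (the interval is arbitrary).
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        process_arxiv_ids,
        trigger="interval",
        minutes=60,
        args=[gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id],  # assumed
    )
    scheduler.start()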