import os
import copy
import datasets
import pandas as pd
from collections import defaultdict

from datetime import datetime, timedelta
from background import process_arxiv_ids
from apscheduler.schedulers.background import BackgroundScheduler

def _count_nans(row):
    """Return the number of fields in a record whose value is None."""
    count = 0

    for v in row.values():
        if v is None:
            count += 1

    return count

def _initialize_requested_arxiv_ids(request_ds):
    """Flatten the request dataset into a single-column DataFrame of requested arXiv IDs."""
    requested_arxiv_ids = []

    for request_d in request_ds['train']:
        requested_arxiv_ids.extend(request_d['Requested arXiv IDs'])

    return pd.DataFrame({'Requested arXiv IDs': requested_arxiv_ids})

def _initialize_paper_info(source_ds):
    """Index papers by title, date, and arXiv ID, keeping the most complete record per title."""
    title2qna, date2qna = {}, {}
    date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    arxivid2data = {}
    count = 0

    for data in source_ds["train"]:
        date = data["target_date"].strftime("%Y-%m-%d")

        if date in date2qna:
            # If a record with the same title already exists for this date,
            # keep whichever record has fewer missing (None) fields.
            existing = next(
                (paper for paper in date2qna[date] if paper["title"] == data["title"]),
                None,
            )
            if existing is None:
                date2qna[date].append(data)
            elif _count_nans(existing) > _count_nans(data):
                date2qna[date].remove(existing)
                date2qna[date].append(data)
        else:
            date2qna[date] = [data]

    for date in date2qna:
        year, month, day = date.split("-")
        for paper in date2qna[date]:
            title2qna[paper["title"]] = paper
            arxivid2data[paper["arxiv_id"]] = {"idx": count, "paper": paper}
            date_dict[year][month][day].append(paper)
            count += 1

    titles = list(title2qna.keys())

    return titles, date_dict, arxivid2data

def initialize_data(source_data_repo_id, request_data_repo_id):
    """Load the source and request datasets and build the in-memory lookup structures."""
    global date_dict, arxivid2data
    global requested_arxiv_ids_df

    source_ds = datasets.load_dataset(source_data_repo_id)
    request_ds = datasets.load_dataset(request_data_repo_id)

    titles, date_dict, arxivid2data = _initialize_paper_info(source_ds)
    requested_arxiv_ids_df = _initialize_requested_arxiv_ids(request_ds)

    return (
        titles, date_dict, requested_arxiv_ids_df, arxivid2data
    )

def update_dataframe(request_data_repo_id):
    """Reload the request dataset and rebuild the requested-IDs DataFrame."""
    request_ds = datasets.load_dataset(request_data_repo_id)
    return _initialize_requested_arxiv_ids(request_ds)

def get_secrets():
    """Read API keys and repository IDs from environment variables."""
    global gemini_api_key
    global hf_token
    global request_arxiv_repo_id
    global dataset_repo_id

    gemini_api_key = os.getenv("GEMINI_API_KEY")
    hf_token = os.getenv("HF_TOKEN")
    dataset_repo_id = os.getenv("SOURCE_DATA_REPO_ID")
    request_arxiv_repo_id = os.getenv("REQUEST_DATA_REPO_ID")
    restart_repo_id = os.getenv("RESTART_TARGET_SPACE_REPO_ID", "chansung/paper_qa")

    return (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id
    )
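

# Hedged usage sketch (not part of the original module logic): the imports of
# BackgroundScheduler, process_arxiv_ids, datetime, and timedelta suggest this module is
# combined with a periodic background job elsewhere. The wiring below is an assumption about
# how the pieces fit together; in particular, the arguments that process_arxiv_ids expects are
# defined in background.py and may differ from what is shown here.
if __name__ == "__main__":
    (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id,
    ) = get_secrets()

    # Build the in-memory indices once at startup.
    titles, date_dict, requested_arxiv_ids_df, arxivid2data = initialize_data(
        dataset_repo_id, request_arxiv_repo_id
    )

    # Periodically process newly requested arXiv IDs in the background.
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        process_arxiv_ids,
        trigger="interval",
        minutes=30,
        next_run_time=datetime.now() + timedelta(seconds=10),
        # args=(...),  # pass whatever process_arxiv_ids actually expects (see background.py)
    )
    scheduler.start()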