File size: 4,259 Bytes
b4eb3ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fb4b90
 
b4eb3ca
 
 
9fb4b90
b4eb3ca
 
 
 
 
 
 
9fb4b90
b4eb3ca
9fb4b90
b4eb3ca
9fb4b90
b4eb3ca
9fb4b90
b4eb3ca
 
 
 
 
 
 
 
9fb4b90
 
d1e8757
9fb4b90
b4eb3ca
9fb4b90
 
b4eb3ca
8cbdd6f
b4eb3ca
 
 
9fb4b90
 
b4eb3ca
 
 
9fb4b90
 
e2797b8
 
9fb4b90
 
e2797b8
 
9fb4b90
 
e2797b8
 
b4eb3ca
 
9fb4b90
b4eb3ca
 
 
e2797b8
9fb4b90
 
e2797b8
9fb4b90
e2797b8
 
 
9fb4b90
e2797b8
 
 
 
9fb4b90
 
 
 
 
 
 
 
b4eb3ca
8cbdd6f
b4eb3ca
 
 
 
 
 
9fb4b90
b4eb3ca
9fb4b90
b4eb3ca
 
9fb4b90
 
 
 
 
 
 
 
b4eb3ca
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import datetime
import pathlib
import re
import tempfile

import pandas as pd
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, Repository
from huggingface_hub.utils import RepositoryNotFoundError


class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API."""

    def __init__(self, space_id: str):
        """Validate the HF token and the Space, then remember ``space_id``.

        Args:
            space_id: Hub id of the Space to restart (e.g. ``"user/space-name"``).

        Raises:
            ValueError: If the HF token lacks write permission, or if the Hub
                reports that the Space does not exist.
        """
        self.api = HfApi()
        if self.api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain explicitly so the Hub's original error stays visible as the cause.
            raise ValueError("The Space ID does not exist.") from err
        self.space_id = space_id

    def restart(self) -> None:
        """Ask the Hub to restart the Space."""
        self.api.restart_space(self.space_id)


def find_github_links(summary: str) -> str:
    """Return the single GitHub URL mentioned in ``summary``.

    Matches ``https://github.com/<owner>/<repo>``, optionally followed by
    ``/tree/<ref>/<path>`` or ``/blob/<ref>/<path>``. Any whitespace, ``)``,
    ``}`` or ``,`` ends the URL. A single trailing period (sentence end) is
    dropped. Repeated mentions of the same URL count as one link.

    Args:
        summary: Free text such as a paper abstract.

    Returns:
        The URL, or ``""`` if the text contains no GitHub link.

    Raises:
        RuntimeError: If more than one distinct GitHub link is found.
    """
    # Use ``\s`` rather than a literal space: abstracts contain line breaks, and a
    # newline right after a link must not be absorbed into the repo name.
    pattern = r"https://github\.com/[^/\s]+/[^/)},\s]+(?:/(?:tree|blob)/[^/\s]+/[^/)},\s]+)?"
    links = []
    for link in re.findall(pattern, summary):
        if link.endswith("."):
            link = link[:-1]
        links.append(link)
    # Deduplicate (keeping first-seen order) after normalization, so a repo cited
    # twice -- with or without a trailing period -- is not reported as "multiple".
    links = list(dict.fromkeys(links))
    if not links:
        return ""
    if len(links) > 1:
        raise RuntimeError(f"Found multiple GitHub links: {links}")
    return links[0]


class RepoUpdater:
    """Maintain ``papers.csv`` in a local clone of a Hugging Face repo.

    The repo is cloned into the system temporary directory. :meth:`update`
    appends new entries from the Hub's Daily Papers API to ``papers.csv``, and
    :meth:`push` pushes the local clone to the Hub.
    """

    # Seconds to wait for the Daily Papers API, so a stalled request cannot
    # block the caller (e.g. a scheduled job) indefinitely.
    _REQUEST_TIMEOUT = 30

    def __init__(self, repo_id: str, repo_type: str):
        """Clone ``repo_id`` locally and pull its latest state.

        Args:
            repo_id: Hub repo id, e.g. ``"user/name"``.
            repo_type: Repo type forwarded to ``Repository`` (e.g. ``"space"``).

        Raises:
            ValueError: If the HF token lacks write permission.
        """
        api = HfApi()
        if api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")

        name = api.whoami()["name"]

        # ``tempfile.tempdir`` stays None until something calls gettempdir(), so
        # call the function to always get a resolved directory.
        repo_dir = pathlib.Path(tempfile.gettempdir()) / repo_id.split("/")[-1]
        self.csv_path = repo_dir / "papers.csv"
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f"{name}@users.noreply.huggingface.co",
        )
        self.repo.git_pull()

    def _fetch_daily_papers(self, date: str) -> list:
        """Return the Daily Papers API listing for ``date`` (``YYYY-MM-DD``).

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer within the timeout.
        """
        response = requests.get(
            f"https://huggingface.co/api/daily_papers?date={date}",
            timeout=self._REQUEST_TIMEOUT,
        )
        # Fail loudly instead of iterating an error payload as if it were a list.
        response.raise_for_status()
        return response.json()

    def update(self) -> None:
        """Append yesterday's and today's (UTC) papers that are not yet in ``papers.csv``.

        The local clone is pulled before reading the CSV. Papers whose summary
        contains several distinct GitHub links are skipped, and the error is
        printed. The CSV is rewritten in place; nothing is committed or pushed here.
        """
        # Read the clock once (so the two dates cannot straddle midnight) and in
        # UTC, to agree with the UTC cron schedule used by the caller.
        now = datetime.datetime.now(datetime.timezone.utc)
        dates = [
            (now - datetime.timedelta(days=1)).strftime("%Y-%m-%d"),
            now.strftime("%Y-%m-%d"),
        ]
        daily_papers = [(date, self._fetch_daily_papers(date)) for date in dates]

        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna("")
        arxiv_ids = set(df["arxiv_id"])

        new_rows = []
        for date, papers in daily_papers:
            for paper in papers:
                arxiv_id = paper["paper"]["id"]
                if arxiv_id in arxiv_ids:
                    continue
                try:
                    github = find_github_links(paper["paper"]["summary"])
                except RuntimeError as e:
                    print(e)
                    continue
                new_rows.append({"date": date, "arxiv_id": arxiv_id, "github": github})
                # Remember it so a paper listed on both days is only added once.
                arxiv_ids.add(arxiv_id)

        # Append to the existing frame (rather than rebuilding from rows) so the
        # header survives even when the CSV has no rows and nothing is new.
        if new_rows:
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Push the local clone's commits to the Hub."""
        self.repo.push_to_hub()


class UpdateScheduler:
    """Run the papers.csv refresh for a Space on a UTC cron schedule.

    Each run updates the CSV. If that leaves the clone dirty, the change is
    pushed; otherwise the Space is restarted.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str, cron_second: str = "0"):
        """Build the restarter, the repo updater and the (not yet started) scheduler.

        Args:
            space_id: Hub id of the Space; it is also the repo holding papers.csv.
            cron_hour: Cron expression for the hour field.
            cron_minute: Cron expression for the minute field.
            cron_second: Cron expression for the second field.
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type="space")

        cron_fields = {
            "hour": cron_hour,
            "minute": cron_minute,
            "second": cron_second,
        }
        self.scheduler = BackgroundScheduler()
        self.scheduler.add_job(
            func=self._update,
            trigger="cron",
            timezone="UTC",
            **cron_fields,
        )

    def _update(self) -> None:
        """Refresh the CSV, then push if it changed or restart the Space if not."""
        updater = self.repo_updater
        updater.update()
        if not updater.repo.is_repo_clean():
            updater.push()
            return
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the background scheduler (returns immediately)."""
        self.scheduler.start()