# papers / update_scheduler.py
# Author: hysts (HF staff)
# Last commit: 9fb4b90 "Migrate from yapf to black" (file size 4.26 kB)
import datetime
import pathlib
import re
import tempfile
import pandas as pd
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, Repository
from huggingface_hub.utils import RepositoryNotFoundError
class SpaceRestarter:
    """Restart a Hugging Face Space through the Hub API."""

    def __init__(self, space_id: str):
        """Validate the token and the target Space.

        Args:
            space_id: Hub repository ID of the Space to restart.

        Raises:
            ValueError: If the token picked up by ``HfApi()`` lacks write
                permission, or if the Space cannot be found.
        """
        self.api = HfApi()
        if self.api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as err:
            # Chain explicitly so the traceback shows the Hub error as the cause.
            raise ValueError("The Space ID does not exist.") from err
        self.space_id = space_id

    def restart(self) -> None:
        """Ask the Hub to restart the configured Space."""
        self.api.restart_space(self.space_id)


# Matches github.com/<owner>/<repo>, optionally followed by /tree|blob/<ref>/<path>.
# Whitespace of any kind (summaries contain newlines), closing brackets and
# commas terminate a link.
_GITHUB_LINK_PATTERN = re.compile(
    r"https://github\.com/[^/\s]+/[^/)},\s]+(?:/(?:tree|blob)/[^/\s]+/[^/)},\s]+)?"
)


def find_github_links(summary: str) -> str:
    """Return the single GitHub link mentioned in ``summary``.

    Args:
        summary: Free text (a paper abstract) to scan.

    Returns:
        The link with one trailing sentence period removed, or ``""`` when
        the text contains no GitHub link.

    Raises:
        RuntimeError: If the text mentions more than one *distinct* link.
    """
    links = []
    for link in _GITHUB_LINK_PATTERN.findall(summary):
        # A link that ends a sentence picks up the full stop.
        if link.endswith("."):
            link = link[:-1]
        links.append(link)
    # The same repository is often mentioned twice; only distinct links are
    # ambiguous. dict.fromkeys de-duplicates while keeping first-seen order.
    links = list(dict.fromkeys(links))
    if not links:
        return ""
    if len(links) != 1:
        raise RuntimeError(f"Found multiple GitHub links: {links}")
    return links[0]


class RepoUpdater:
    """Keep ``papers.csv`` in a local clone of a Hub repository up to date.

    The repository is cloned under the system temporary directory when the
    object is created. :meth:`update` merges the Hub daily-papers listings
    for yesterday and today into the CSV, and :meth:`push` sends local
    changes back to the Hub.
    """

    def __init__(self, repo_id: str, repo_type: str):
        """Clone (or reuse) ``repo_id`` locally and pull the latest revision.

        Args:
            repo_id: Hub repository ID; its last ``/``-separated component
                names the local clone directory.
            repo_type: Repository type forwarded to ``Repository``.

        Raises:
            ValueError: If the token picked up by ``HfApi()`` lacks write
                permission.
        """
        api = HfApi()
        if api.get_token_permission() != "write":
            raise ValueError("The HF token must have write permission.")
        name = api.whoami()["name"]

        # ``tempfile.tempdir`` is only a cache that stays ``None`` until
        # something calls ``gettempdir()``; ``Path(None)`` would raise.
        repo_dir = pathlib.Path(tempfile.gettempdir()) / repo_id.split("/")[-1]
        self.csv_path = repo_dir / "papers.csv"
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f"{name}@users.noreply.huggingface.co",
        )
        self.repo.git_pull()

    @staticmethod
    def _fetch_daily_papers(date: str, timeout: float = 30.0) -> list:
        """Return the Hub daily-papers listing for ``date`` (``YYYY-MM-DD``).

        A non-2xx response raises ``requests.HTTPError`` instead of letting an
        error payload be iterated as if it were a list of papers; ``timeout``
        keeps a stalled request from blocking the scheduler job forever.
        """
        response = requests.get(f"https://huggingface.co/api/daily_papers?date={date}", timeout=timeout)
        response.raise_for_status()
        return response.json()

    def update(self) -> None:
        """Add papers listed on the Hub yesterday and today (UTC) to the CSV.

        Papers whose ``arxiv_id`` is already in the CSV, or was already added
        during this call, are skipped. Papers whose summary mentions several
        distinct GitHub links are reported with ``print`` and skipped. The
        CSV file is always rewritten; nothing is committed or pushed here.
        """
        # One timestamp, so "yesterday" and "today" cannot straddle midnight;
        # UTC to match the timezone the cron schedule is declared in.
        now = datetime.datetime.now(datetime.timezone.utc)
        dates = [
            (now - datetime.timedelta(days=1)).strftime("%Y-%m-%d"),
            now.strftime("%Y-%m-%d"),
        ]
        daily_papers = [{"date": date, "papers": self._fetch_daily_papers(date)} for date in dates]

        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna("")
        known_ids = set(df["arxiv_id"])
        new_rows = []
        for day in daily_papers:
            for paper in day["papers"]:
                arxiv_id = paper["paper"]["id"]
                if arxiv_id in known_ids:
                    continue
                try:
                    github = find_github_links(paper["paper"]["summary"])
                except RuntimeError as e:
                    print(e)
                    continue
                # A paper listed on both days must only be added once.
                known_ids.add(arxiv_id)
                new_rows.append({"date": day["date"], "arxiv_id": arxiv_id, "github": github})

        # Appending to the existing frame keeps its columns (and header) even
        # when the CSV has no rows yet.
        if new_rows:
            df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Push the local clone's changes to the Hub via ``Repository.push_to_hub``."""
        self.repo.push_to_hub()


class UpdateScheduler:
    """Run a daily refresh of the papers CSV on a UTC cron schedule.

    Each run merges new papers into the CSV. If that leaves the local clone
    with changes, they are pushed. Otherwise the Space is restarted instead.
    NOTE(review): presumably a push already redeploys the Space, which is why
    no explicit restart follows it; confirm.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str, cron_second: str = "0"):
        """Set up the Space restarter, the CSV repo updater and the cron job.

        Args:
            space_id: Hub ID of the Space; it is both restarted and used as
                the repository holding ``papers.csv``.
            cron_hour: APScheduler cron ``hour`` field.
            cron_minute: APScheduler cron ``minute`` field.
            cron_second: APScheduler cron ``second`` field.
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type="space")
        self.scheduler = BackgroundScheduler()
        cron_fields = dict(hour=cron_hour, minute=cron_minute, second=cron_second)
        self.scheduler.add_job(func=self._update, trigger="cron", timezone="UTC", **cron_fields)

    def _update(self) -> None:
        """Scheduled job: refresh the CSV, then push changes or restart the Space."""
        updater = self.repo_updater
        updater.update()
        if not updater.repo.is_repo_clean():
            updater.push()
            return
        self.space_restarter.restart()

    def start(self) -> None:
        """Start the background scheduler so the cron job begins firing."""
        self.scheduler.start()