webhook / app.py
Wauplin's picture
Wauplin HF staff
update comment
9a9f2de
raw
history blame
8.68 kB
# Taken from https://huggingface.co/spaces/huggingface-projects/auto-retrain
import hmac
import logging
import os
from pathlib import Path
from typing import Literal, Optional

from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from huggingface_hub import (
    CommitOperationAdd,
    CommitOperationDelete,
    comment_discussion,
    create_commit,
    create_repo,
    delete_repo,
    get_repo_discussions,
    snapshot_download,
    space_info,
)
from huggingface_hub.repocard import RepoCard
from pydantic import BaseModel
from requests import HTTPError
# Module logger named after the module (standard `logging` convention) rather
# than after the file's filesystem path.
# NOTE(review): this file does not configure logging itself, so INFO records
# are only visible if the process running the app sets that up.
logger = logging.getLogger(__name__)

# Shared secret expected in the `X-Webhook-Secret` header of incoming webhook
# calls. When the variable is unset (None), every call is rejected with 403.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")

# Hub token passed as `token=` to the Hub API calls made by this app.
HF_TOKEN = os.getenv("HF_TOKEN")
class WebhookPayloadEvent(BaseModel):
    """`event` section of a Hub webhook payload."""

    # Kind of change; any other value fails validation.
    action: Literal["create", "update", "delete"]
    # What changed. `post_webhook` dispatches on values starting with
    # "discussion" or "repo.content"; everything else is ignored.
    scope: str
class WebhookPayloadRepo(BaseModel):
    """`repo` section of a Hub webhook payload: the repository that triggered it."""

    # Repository type; `post_webhook` rejects anything but "space" with HTTP 400.
    type: Literal["dataset", "model", "space"]
    # Repo id of the watched Space (presumably "namespace/name"); used as
    # `space_id` everywhere else in this file.
    name: str
    # Visibility of the watched repo; CI Spaces are created with the same value.
    private: bool
class WebhookPayloadDiscussion(BaseModel):
    """`discussion` section of a Hub webhook payload (a discussion or a PR)."""

    # Discussion/PR number; used to build the `refs/pr/{num}` revision.
    num: int
    # True for pull requests; plain discussions are ignored by `post_webhook`.
    isPullRequest: bool
    # Current state of the discussion/PR.
    status: Literal["open", "closed", "merged"]
class WebhookPayload(BaseModel):
    """Top-level JSON body of a Hub webhook call, as consumed by `post_webhook`."""

    event: WebhookPayloadEvent
    repo: WebhookPayloadRepo
    # Presumably absent for non-discussion events (post_webhook checks for None).
    # The explicit `= None` default keeps the key optional under pydantic v2 as
    # well: there, `Optional[X]` without a default only allows `null` but still
    # requires the key, so payloads without it would be rejected with HTTP 422.
    discussion: Optional[WebhookPayloadDiscussion] = None
# FastAPI application; the routes below are registered on it via decorators.
app = FastAPI()
@app.get("/")
async def home():
    """Serve the static landing page.

    The path is relative, so `home.html` is resolved against the process's
    current working directory.
    """
    landing_page = "home.html"
    return FileResponse(landing_page)
def _find_unsynced_open_prs(space_id: str) -> list:
    """Return the numbers of open PRs on `space_id` whose CI Space is out of date.

    Makes blocking Hub requests (the discussions listing is paginated lazily,
    and `is_pr_synced` is called once per open PR), so async callers must run
    it in a worker thread.
    """
    unsynced = []
    for discussion in get_repo_discussions(
        repo_id=space_id, repo_type="space", token=HF_TOKEN
    ):
        if not (discussion.is_pull_request and discussion.status == "open"):
            continue
        if not is_pr_synced(space_id=space_id, pr_num=discussion.num):
            unsynced.append(discussion.num)
    return unsynced


@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle a Hub webhook call and schedule CI Space sync/deletion tasks.

    - "discussion*" scope, action "create", open PR: schedule `sync_ci_space`
      unless the CI Space is already up to date.
    - "discussion*" scope, action "update", merged or closed PR: schedule
      `delete_ci_space`.
    - "repo.content*" scope, action "update": schedule `sync_ci_space` for
      every open PR whose CI Space is out of date.
    Anything else is ignored. Scheduled tasks run after the response is sent.

    Returns:
        `{"processed": True}` once the payload has been dispatched.

    Raises:
        HTTPException: 401 if the webhook secret header is missing; 403 if it
            does not match `WEBHOOK_SECRET` (or `WEBHOOK_SECRET` is unset);
            400 if the webhook does not concern a Space.
    """
    logger.info("Received new hook!")
    if x_webhook_secret is None:
        logger.warning("HTTP 401: No webhook secret")
        raise HTTPException(401)
    # Constant-time comparison so the secret cannot be recovered from response
    # timings. An unset WEBHOOK_SECRET rejects every request.
    if WEBHOOK_SECRET is None or not hmac.compare_digest(
        x_webhook_secret.encode(), WEBHOOK_SECRET.encode()
    ):
        logger.warning("HTTP 403: wrong webhook secret")
        raise HTTPException(403)
    if payload.repo.type != "space":
        logger.warning("HTTP 400: not a space")
        raise HTTPException(400, f"Must be a Space, not {payload.repo.type}")

    space_id = payload.repo.name
    scope = payload.event.scope
    action = payload.event.action
    discussion = payload.discussion
    is_pr_event = (
        scope.startswith("discussion")
        and discussion is not None
        and discussion.isPullRequest
    )

    if is_pr_event and action == "create" and discussion.status == "open":
        # New PR!
        # is_pr_synced performs blocking HTTP calls: keep them off the event loop.
        synced = await run_in_threadpool(
            is_pr_synced, space_id=space_id, pr_num=discussion.num
        )
        if not synced:
            task_queue.add_task(
                sync_ci_space,
                space_id=space_id,
                pr_num=discussion.num,
                private=payload.repo.private,
            )
            logger.info("New PR! Sync task scheduled")
        else:
            logger.info("New comment on PR but CI space already synced")
    elif (
        is_pr_event
        and action == "update"
        and discussion.status in ("merged", "closed")
    ):
        # PR merged or closed!
        task_queue.add_task(
            delete_ci_space,
            space_id=space_id,
            pr_num=discussion.num,
        )
        logger.info("PR is merged (or closed)! Delete task scheduled")
    elif scope.startswith("repo.content") and action == "update":
        # New repo change. Is it a commit on a PR?
        # => check every open PR for new changes (in a worker thread: blocking I/O)
        logger.info("New repo content update. Checking PRs state.")
        unsynced_prs = await run_in_threadpool(_find_unsynced_open_prs, space_id)
        for pr_num in unsynced_prs:
            task_queue.add_task(
                sync_ci_space,
                space_id=space_id,
                pr_num=pr_num,
                private=payload.repo.private,
            )
            logger.info("Scheduled update for PR %s.", pr_num)
        logger.info("Done looping over PRs.")
    else:
        logger.info("Webhook ignored.")
    logger.info("Done.")
    return {"processed": True}
def is_pr_synced(space_id: str, pr_num: int) -> bool:
    """Return whether the CI Space of PR `pr_num` matches the PR's head commit.

    The last synced commit is read from the `synced_sha` metadata field of the
    CI Space's README card (written by `sync_ci_space`). It is compared with
    the current sha of revision `refs/pr/{pr_num}` of `space_id`.

    A CI Space whose card cannot be loaded over HTTP, or that has no
    `synced_sha`, is reported as not synced.
    """
    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    try:
        card = RepoCard.load(
            repo_id_or_path=ci_space_id, repo_type="space", token=HF_TOKEN
        )
    except HTTPError:
        # Any HTTP failure (e.g. the CI Space does not exist yet) => not synced.
        last_synced_sha = None
    else:
        last_synced_sha = getattr(card.data, "synced_sha", None)

    # Pass the token like every other Hub call in this app so private Spaces work.
    info = space_info(
        repo_id=space_id, revision=f"refs/pr/{pr_num}", token=HF_TOKEN
    )
    # A missing synced sha never counts as a match, even if `info.sha` is None.
    return last_synced_sha is not None and last_synced_sha == info.sha
def _ensure_ci_space(ci_space_id: str, private: bool) -> bool:
    """Create the Docker CI Space `ci_space_id`; return False if it already existed."""
    try:
        create_repo(
            ci_space_id,
            repo_type="space",
            space_sdk="docker",
            private=private,
            token=HF_TOKEN,
        )
    except HTTPError as err:
        if err.response.status_code != 409:  # 409 == repo already exists
            raise
        return False
    return True


def _build_sync_operations(snapshot_path: Path) -> list:
    """Build commit operations replacing the CI Space content with `snapshot_path`.

    Everything in the CI Space is deleted first, then every file of the
    snapshot is re-added. `README.md` is re-serialized with a `synced_sha`
    metadata field set to the snapshot folder name (the downloaded commit sha);
    all other files are uploaded unchanged.
    """
    ops = [CommitOperationDelete(".", is_folder=True)]  # aggressive, but simple
    for local_file in snapshot_path.glob("**/*"):
        if not local_file.is_file():
            continue
        repo_path = str(local_file.relative_to(snapshot_path))
        if repo_path == "README.md":
            readme = RepoCard.load(local_file)
            readme.data.synced_sha = snapshot_path.name
            content = str(readme).encode()
        else:
            content = local_file
        ops.append(CommitOperationAdd(path_in_repo=repo_path, path_or_fileobj=content))
    return ops


def sync_ci_space(space_id: str, pr_num: int, private: bool) -> None:
    """Mirror revision `refs/pr/{pr_num}` of `space_id` into its CI Space.

    The CI Space is created if needed (with visibility `private`), overwritten
    in a single commit with the PR's files, and the PR then receives a
    "create" or "update" notification comment.
    """
    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    created = _ensure_ci_space(ci_space_id, private)

    local_snapshot = Path(
        snapshot_download(
            repo_id=space_id,
            revision=f"refs/pr/{pr_num}",
            repo_type="space",
            token=HF_TOKEN,
        )
    )
    create_commit(
        repo_id=ci_space_id,
        repo_type="space",
        operations=_build_sync_operations(local_snapshot),
        commit_message=f"Sync CI Space with PR {pr_num}.",
        token=HF_TOKEN,
    )

    notify_pr(space_id=space_id, pr_num=pr_num, action="create" if created else "update")
def delete_ci_space(space_id: str, pr_num: int) -> None:
    """Delete the CI Space of PR `pr_num` and post a notification on the PR.

    A missing CI Space (HTTP 404) is treated as already deleted: nothing is
    deleted and no comment is posted. This can happen, e.g., when the Space was
    never created, or when another "update" event on the same closed/merged PR
    scheduled a second deletion. Other HTTP errors propagate.
    """
    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    try:
        delete_repo(repo_id=ci_space_id, repo_type="space", token=HF_TOKEN)
    except HTTPError as err:
        if err.response is not None and err.response.status_code == 404:
            logger.info("CI Space %s not found: nothing to delete.", ci_space_id)
            return
        raise
    # Notify about deletion
    notify_pr(space_id=space_id, pr_num=pr_num, action="delete")
def notify_pr(
    space_id: str, pr_num: int, action: Literal["create", "update", "delete"]
) -> None:
    """Post an automated comment on PR `pr_num` of `space_id` describing `action`.

    "create" and "update" messages link to the CI Space; "delete" uses a fixed
    message.

    Raises:
        ValueError: if `action` is not "create", "update" or "delete" (raised
            before anything is posted).
    """
    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    if action == "delete":
        message = NOTIFICATION_TEMPLATE_DELETE
    elif action in ("create", "update"):
        template = (
            NOTIFICATION_TEMPLATE_CREATE
            if action == "create"
            else NOTIFICATION_TEMPLATE_UPDATE
        )
        message = template.format(ci_space_id=ci_space_id)
    else:
        raise ValueError(f"Status {action} not handled.")

    comment_discussion(
        repo_id=space_id,
        repo_type="space",
        discussion_num=pr_num,
        comment=message,
        token=HF_TOKEN,
    )
def _get_ci_space_id(space_id: str, pr_num: int) -> str:
return f"{space_id}-ci-pr-{pr_num}"
# Comment posted when a CI Space is first created for a PR.
# Placeholder: {ci_space_id} (filled via str.format in notify_pr).
NOTIFICATION_TEMPLATE_CREATE = """\
Hey there!
Following the creation of this PR, a temporary test Space [{ci_space_id}](https://huggingface.co/spaces/{ci_space_id}) has been launched.
Any changes pushed to this PR will be synced with the test Space.
(This is an automated message)
"""
# Comment posted when an existing CI Space is re-synced with new PR commits.
# Placeholder: {ci_space_id} (filled via str.format in notify_pr).
NOTIFICATION_TEMPLATE_UPDATE = """\
Hey there!
Following new commits that happened in this PR, the temporary test Space [{ci_space_id}](https://huggingface.co/spaces/{ci_space_id}) has been updated.
(This is an automated message)
"""
# Comment posted after the CI Space is deleted; used verbatim (no placeholders).
NOTIFICATION_TEMPLATE_DELETE = """\
Hey there!
PR is now merged/closed. The temporary test Space has been deleted.
(This is an automated message)
"""