Spaces:
Running
Running
import os | |
import uuid | |
import warnings | |
from concurrent.futures import ThreadPoolExecutor | |
from functools import wraps | |
from pathlib import Path | |
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union | |
import gradio as gr | |
from fastapi import BackgroundTasks, HTTPException, Response, status | |
from huggingface_hub import ( | |
SpaceHardware, | |
SpaceStorage, | |
WebhookPayload, | |
WebhooksServer, | |
add_space_secret, | |
add_space_variable, | |
comment_discussion, | |
create_repo, | |
delete_repo, | |
get_discussion_details, | |
get_repo_discussions, | |
get_space_runtime, | |
get_space_variables, | |
repo_exists, | |
request_space_hardware, | |
request_space_storage, | |
snapshot_download, | |
space_info, | |
upload_folder, | |
) | |
from huggingface_hub.repocard import RepoCard | |
from huggingface_hub.utils import ( | |
RepositoryNotFoundError, | |
build_hf_headers, | |
get_session, | |
hf_raise_for_status, | |
) | |
from requests import HTTPError | |
SPACE_ID = os.environ.get("SPACE_ID")
# Ephemeral CI Spaces are named "<space_id>-ci-pr-<num>" (see `_get_ci_space_id`).
IS_EPHEMERAL_SPACE = SPACE_ID is not None and "-ci-pr-" in SPACE_ID

WEBHOOK_SECRET = os.environ.get("SPACE_CI_SECRET")
# The webhook secret is only needed by the main Space: ephemeral Spaces never register a webhook
# (see `configure_space_ci`). Ephemeral Spaces are therefore skipped here. Adding a secret to them
# would needlessly restart them. A Space created for an untrusted PR gets no secrets copied
# (see `configure_ephemeral_space`), so it presumably has no token allowed to call
# `add_space_secret`, which would then fail at import time.
if SPACE_ID is not None and not IS_EPHEMERAL_SPACE:  # Running in a main Space (i.e. not locally)
    if WEBHOOK_SECRET is None:  # No secret set yet => generate one => restart space
        WEBHOOK_SECRET = str(uuid.uuid4())
        add_space_secret(
            repo_id=SPACE_ID,
            key="SPACE_CI_SECRET",
            value=WEBHOOK_SECRET,
            description="This value is used by the SpaceCI. It is automatically generated and should not be changed.",
        )

# Settings applied to ephemeral Spaces. Filled by `configure_space_ci`, read by the sync tasks.
EPHEMERAL_SPACES_CONFIG: Dict[str, Any] = {}

# Draft and open PRs are considered as active (in opposition to closed and merged PRs)
ACTIVE_PR_STATUS = ("draft", "open")
def enable_space_ci() -> None:
    """Enable Space CI for the current Space based on config from the README.md file.

    Reads the `space_ci` section of the Space's README metadata and monkey-patches
    `gr.Blocks.launch`. The next `launch(...)` call then builds a `WebhooksServer` through
    `configure_space_ci` and launches that server instead, forwarding all launch arguments.
    Does nothing when not running in a Space, or when running in an ephemeral CI Space.

    Example:
    ```py
    import gradio as gr
    from gradio_space_ci import enable_space_ci

    enable_space_ci()

    with gr.Blocks() as demo:
        ...

    demo.launch()
    ```
    """
    if SPACE_ID is None:
        print("Not in a Space: Space CI disabled.")
        return
    if IS_EPHEMERAL_SPACE:
        print("In an ephemeral Space: Space CI disabled.")
        return

    card = RepoCard.load(repo_id_or_path=SPACE_ID, repo_type="space")
    # `or {}` also covers a `space_ci:` key that is present but empty (null in YAML).
    config = card.data.get("space_ci") or {}
    print(f"Enabling Space CI with config from README: {config}")

    old_launch = gr.Blocks.launch

    @wraps(old_launch)
    def new_launch(self: gr.Blocks, *args, **kwargs) -> None:
        server = configure_space_ci(
            blocks=self,
            trusted_authors=config.get("trusted_authors"),
            private=config.get("private", "auto"),
            variables=config.get("variables", "auto"),
            secrets=config.get("secrets"),
            hardware=config.get("hardware"),
            storage=config.get("storage"),
        )
        # De-monkey patch gradio (otherwise it will be called recursively)
        gr.Blocks.launch = old_launch
        # Forward keyword arguments as keywords. `*kwargs` would pass the key *names* as positionals.
        server.launch(*args, **kwargs)

    # Monkey patch gradio
    gr.Blocks.launch = new_launch
def configure_space_ci(
    blocks: Optional["gr.Blocks"] = None,
    trusted_authors: Optional[List[str]] = None,
    private: Union[bool, Literal["auto"]] = "auto",
    variables: Union[Dict[str, str], Literal["auto"]] = "auto",
    secrets: Optional[List[str]] = None,
    hardware: Union[SpaceHardware, Literal["auto"], None] = None,
    storage: Union[SpaceStorage, Literal["auto"], None] = None,
) -> WebhooksServer:
    """Create the `WebhooksServer` serving `blocks` and, in a main Space, set up Space CI.

    Outside of a Space, or inside an ephemeral CI Space, a plain `WebhooksServer` is returned and
    nothing else happens. Otherwise this function:
    - resolves the settings for ephemeral Spaces and stores them in `EPHEMERAL_SPACES_CONFIG`;
    - registers `trigger_ci_on_pr` on the server and makes sure a matching webhook exists on the Hub;
    - schedules `recover_after_restart` on `background_pool` to catch up on missed events.

    Args:
        blocks: Gradio app served alongside the webhook endpoint.
        trusted_authors: Extra users whose PRs get a fully configured ephemeral Space. The members of
            the Space's organization are added when they can be listed; otherwise the namespace
            itself (the owning user) is added. The caller's list is not modified.
        private: Privacy of ephemeral Spaces; "auto" copies the current Space's privacy.
        variables: Variables set on ephemeral Spaces; "auto" copies the current Space's variables.
        secrets: Names of environment variables copied as secrets; missing ones are skipped with a warning.
        hardware: Hardware of ephemeral Spaces; "auto" copies the current Space's hardware.
        storage: Storage of ephemeral Spaces; "auto" copies the current Space's storage.

    Returns:
        The configured (not yet launched) `WebhooksServer`.
    """
    if SPACE_ID is None or IS_EPHEMERAL_SPACE:
        # Runs locally => don't configure webhook
        # Runs in an ephemeral Space => don't configure webhook
        return WebhooksServer(ui=blocks)

    # Authors
    # Work on a copy: the list is extended below and must not leak back into the caller's object.
    trusted_authors = list(trusted_authors or [])
    namespace = SPACE_ID.split("/")[0]
    try:  # Check if namespace is an organization => in this case all members are allowed to trigger CI by default
        response = get_session().get(
            f"https://huggingface.co/api/organizations/{namespace}/members", headers=build_hf_headers()
        )
        response.raise_for_status()
        trusted_authors += [user["user"] for user in response.json()]
    except Exception:  # Otherwise, it's a single user => only this user is allowed to trigger CI by default
        # Deliberate best-effort: any failure of the lookup falls back to trusting the namespace only.
        trusted_authors += [namespace]
    trusted_authors = sorted(set(trusted_authors))
    EPHEMERAL_SPACES_CONFIG["trusted_authors"] = trusted_authors

    # Private
    if private == "auto":
        private = space_info(SPACE_ID).private
    EPHEMERAL_SPACES_CONFIG["private"] = private

    # Variables
    if variables == "auto":
        variables = {value.key: value.value for value in get_space_variables(SPACE_ID).values()}
    EPHEMERAL_SPACES_CONFIG["variables"] = variables

    # Secrets (values are read from this Space's environment)
    secrets_with_values: Dict[str, str] = {}
    if secrets is not None:
        for secret in secrets:
            secret_value = os.environ.get(secret)
            if secret_value is None:
                warnings.warn(f"Secret {secret} not found in environment variables. Will skip it in ephemeral Space.")
                continue
            secrets_with_values[secret] = secret_value
    EPHEMERAL_SPACES_CONFIG["secrets"] = secrets_with_values

    # Hardware and storage
    if hardware == "auto" or storage == "auto":
        runtime = get_space_runtime(SPACE_ID)
        if hardware == "auto":
            hardware = runtime.hardware
        if storage == "auto":
            storage = runtime.storage
    EPHEMERAL_SPACES_CONFIG["hardware"] = hardware
    EPHEMERAL_SPACES_CONFIG["storage"] = storage

    # Summary
    print(
        "Ephemeral Spaces config:"
        f"\n - trusted authors: {trusted_authors}"
        f"\n - private: {private}"
        f"\n - secrets: {', '.join(sorted(secrets_with_values.keys()))}"
        f"\n - variables: {variables}"
        f"\n - storage: {storage}"
        f"\n - hardware: {hardware}"
    )

    # Configure webhook
    server = WebhooksServer(ui=blocks, webhook_secret=WEBHOOK_SECRET)
    server.add_webhook()(trigger_ci_on_pr)
    configure_webhook_on_hub()

    # Recover missed webhooks (loop through PRs in the background)
    background_pool.submit(recover_after_restart, space_id=SPACE_ID)
    return server
###
# Recovery logic
###
# Check if there are any PRs that need to be synced.
# We might have missed some events while the server was down.
# => called once at startup (see configure_space_ci)

# A single worker, so queued tasks run one after the other. `recover_after_restart` itself is
# submitted to this pool. It only submits follow-up tasks and never waits on them, so they
# simply queue up behind it.
background_pool = ThreadPoolExecutor(max_workers=1)


def recover_after_restart(space_id: str) -> None:
    """Catch up on PR events that may have been missed while the server was down.

    For every pull request of `space_id`:
    - active (draft/open) but not in sync with its ephemeral Space -> queue `sync_ci_space`;
    - merged/closed while its ephemeral Space still exists -> queue `delete_ci_space`.
    Work is queued on `background_pool`, not executed inline.
    """
    print("Looping through PRs to check if any needs to be synced.")
    pull_requests = get_repo_discussions(repo_id=space_id, repo_type="space", discussion_type="pull_request")
    for pr in pull_requests:
        if pr.status in ACTIVE_PR_STATUS:
            if is_pr_synced(space_id=space_id, pr_num=pr.num):
                continue
            # Active PR whose ephemeral Space lags behind
            print(f"Recovery. Found an open PR that is not synced: {pr.url}. Syncing it.")
            background_pool.submit(sync_ci_space, space_id=space_id, pr_num=pr.num)
        elif pr.status in ("merged", "closed"):
            leftover_space = _get_ci_space_id(space_id=space_id, pr_num=pr.num)
            if not repo_exists(repo_id=leftover_space, repo_type="space"):
                continue
            # Finished PR whose ephemeral Space was never cleaned up
            print(f"Recovery. Found a closed PR with an active CI space: {pr.url}. Deleting it.")
            background_pool.submit(delete_ci_space, space_id=space_id, pr_num=pr.num)
###
# Define webhook on the Hub logic
###


def configure_webhook_on_hub() -> None:
    """Make sure a Hub webhook targets this Space's `/webhooks/trigger_ci_on_pr` endpoint.

    Existing webhooks are matched by URL only. If none matches, a webhook watching `SPACE_ID` on the
    "repo" and "discussion" domains is created, using `WEBHOOK_SECRET` as its secret.

    Raises:
        RuntimeError: if the `SPACE_HOST` environment variable is missing or empty.
    """
    space_host = os.environ.get("SPACE_HOST")
    if not space_host:
        # Previously an opaque AttributeError (None) or a broken "https:///..." URL (empty string).
        raise RuntimeError("SPACE_HOST environment variable is not set: cannot build the webhook URL.")
    url = "https://" + space_host.strip("/") + "/webhooks/trigger_ci_on_pr"

    # Check if webhook already exists
    if any(webhook.get("url") == url for webhook in list_webhooks()):
        print("Webhook already configured")
        return

    # If not => create it
    create_webhook(
        watched=[{"type": "space", "name": SPACE_ID}], url=url, domains=["repo", "discussion"], secret=WEBHOOK_SECRET
    )
    print("New webhook configured!")
###
# Webhook logic
###


async def trigger_ci_on_pr(payload: WebhookPayload, task_queue: BackgroundTasks) -> Response:
    """Handle a Hub webhook event by scheduling the matching CI task(s) in the background.

    - PR created (draft/open) and not yet synced -> `sync_ci_space`;
    - PR updated to merged/closed -> `delete_ci_space`;
    - content pushed to the Space (any branch) -> `sync_ci_space` for every active PR out of sync.

    Returns:
        202 if at least one task was scheduled, 200 otherwise.

    Raises:
        HTTPException: 400 if the event is not about a Space.
    """
    if payload.repo.type != "space":
        raise HTTPException(400, f"Must be a Space, not {payload.repo.type}")

    space_id = payload.repo.name
    event = payload.event
    discussion = payload.discussion
    # Among discussion events, only pull requests matter (regular discussions are ignored).
    is_pr_event = event.scope.startswith("discussion") and discussion is not None and discussion.isPullRequest
    # NOTE(review): `is_pr_synced` / `get_repo_discussions` are blocking HTTP calls made from this
    # async handler; they hold up the event loop while they run.
    has_task = False

    if is_pr_event and event.action == "create" and discussion.status in ACTIVE_PR_STATUS:
        # Means "a new PR has been opened"
        if not is_pr_synced(space_id=space_id, pr_num=discussion.num):
            # New PR! Sync task scheduled
            task_queue.add_task(sync_ci_space, space_id=space_id, pr_num=discussion.num)
            has_task = True
    elif is_pr_event and event.action == "update" and discussion.status in ("merged", "closed"):
        # Means "a PR has been merged or closed"
        task_queue.add_task(delete_ci_space, space_id=space_id, pr_num=discussion.num)
        has_task = True
    elif event.scope.startswith("repo.content") and event.action == "update":
        # Means "some content has been pushed to the Space" (any branch).
        # Is it a commit on a PR? => loop through all PRs and check if new changes happened.
        # As in `recover_after_restart`, let the Hub return pull requests only instead of every discussion.
        for pr in get_repo_discussions(repo_id=space_id, repo_type="space", discussion_type="pull_request"):
            if (
                pr.is_pull_request
                and pr.status in ACTIVE_PR_STATUS
                and not is_pr_synced(space_id=space_id, pr_num=pr.num)
            ):
                # Found a PR that is not yet synced
                task_queue.add_task(sync_ci_space, space_id=space_id, pr_num=pr.num)
                has_task = True

    if has_task:
        return Response("Task scheduled to sync/delete Space", status_code=status.HTTP_202_ACCEPTED)
    return Response("No task scheduled", status_code=status.HTTP_200_OK)
###
# Internal logic
###


def is_pr_synced(space_id: str, pr_num: int) -> bool:
    """Tell whether the ephemeral Space of PR `pr_num` already mirrors the PR's latest commit.

    Compares the `synced_sha` metadata field of the ephemeral Space's README with the current
    head sha of `refs/pr/{pr_num}` on `space_id`. Returns False if loading that README raises an
    `HTTPError` (for instance when the ephemeral Space does not exist yet).
    """
    try:
        ci_card = RepoCard.load(
            repo_id_or_path=_get_ci_space_id(space_id=space_id, pr_num=pr_num),
            repo_type="space",
        )
    except HTTPError:
        return False

    synced_sha = getattr(ci_card.data, "synced_sha", None)
    head_sha = space_info(repo_id=space_id, revision=f"refs/pr/{pr_num}").sha
    return synced_sha == head_sha
def sync_ci_space(space_id: str, pr_num: int) -> None:
    """Create or update the ephemeral CI Space mirroring PR `pr_num` of `space_id`, then comment on the PR.

    Steps:
    1. skip if already synced;
    2. create the ephemeral Space if needed, and configure it when newly created and the author is trusted;
    3. download the PR revision and tag its README with the synced sha and an "(ephemeral #N)" title;
    4. upload it to the ephemeral Space, deleting remote files no longer present;
    5. post a notification on the PR.
    """
    print(f"New task: sync ephemeral env for {space_id} (PR {pr_num})")
    if is_pr_synced(space_id=space_id, pr_num=pr_num):
        print("Already synced. Nothing to do.")
        return

    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)

    # Create a temporary space for CI if it didn't exist
    is_new = create_ephemeral_space(space_id=space_id, pr_num=pr_num)

    # Configure ephemeral Space if trusted author (only done when it is freshly created)
    is_configured = False
    if is_new:
        is_configured = configure_ephemeral_space(space_id=space_id, pr_num=pr_num)

    # Download space codebase from PR revision
    snapshot_path = Path(snapshot_download(repo_id=space_id, revision=f"refs/pr/{pr_num}", repo_type="space"))

    # Patch the README: record the synced commit and mark the title as ephemeral
    readme_path = snapshot_path / "README.md"
    card = RepoCard.load(readme_path)
    setattr(card.data, "synced_sha", snapshot_path.name)  # the snapshot folder is named after the commit sha
    base_title = card.data.get("title") or space_id.split("/")[-1]
    card.data.title = f"{base_title} (ephemeral #{pr_num})"
    try:
        # In the HF cache, snapshot files are usually symlinks to blobs shared between revisions.
        # Writing through the link would corrupt the cached blob: a later snapshot reusing it would
        # start from an already-patched README (e.g. a doubled "(ephemeral #N)" suffix).
        # Remove the link first so that `card.save` writes a standalone file.
        readme_path.unlink()
        card.save(readme_path)

        # Sync space codebase with PR revision
        upload_folder(
            repo_id=ci_space_id,
            repo_type="space",
            commit_message=f"Sync CI Space with PR {pr_num}.",
            folder_path=snapshot_path,
            delete_patterns="*",
        )
    finally:
        # Always drop the patched README, even if the upload failed, so that a later
        # `snapshot_download` of this revision does not pick up the modified copy.
        readme_path.unlink(missing_ok=True)

    # Post a comment on the PR
    if is_new and is_configured:
        notify_pr(space_id=space_id, pr_num=pr_num, action="created_and_configured")
    elif is_new:
        notify_pr(space_id=space_id, pr_num=pr_num, action="created_not_configured")
    else:
        notify_pr(space_id=space_id, pr_num=pr_num, action="updated")
def create_ephemeral_space(space_id: str, pr_num: int) -> bool:
    """Create the ephemeral CI Space for PR `pr_num` unless it already exists.

    Its privacy comes from `EPHEMERAL_SPACES_CONFIG["private"]`.

    Returns:
        True if this call created the Space, False if it already existed (HTTP 409).

    Raises:
        HTTPError: any other error answered by the Hub.
    """
    target_repo = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    make_private: bool = EPHEMERAL_SPACES_CONFIG["private"]

    try:
        create_repo(
            target_repo,
            repo_type="space",
            space_sdk="docker",  # Will be overwritten by sync
            private=make_private,
            exist_ok=False,
        )
    except HTTPError as err:
        already_exists = err.response is not None and err.response.status_code == 409
        if not already_exists:
            raise
        return False
    return True
def configure_ephemeral_space(space_id: str, pr_num: int) -> bool:
    """Apply variables, secrets, hardware and storage to the ephemeral Space of PR `pr_num`.

    Settings are read from `EPHEMERAL_SPACES_CONFIG` (filled by `configure_space_ci`). They are only
    applied when the PR author is a trusted author, so that untrusted code never gets the secrets,
    hardware or storage.

    Returns:
        True if the Space was configured, False if the PR author is not trusted.
    """
    # Config values
    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    trusted_authors: List[str] = EPHEMERAL_SPACES_CONFIG["trusted_authors"]
    variables: Dict[str, str] = EPHEMERAL_SPACES_CONFIG["variables"]
    secrets: Dict[str, str] = EPHEMERAL_SPACES_CONFIG["secrets"]
    hardware: Optional[SpaceHardware] = EPHEMERAL_SPACES_CONFIG["hardware"]
    storage: Optional[SpaceStorage] = EPHEMERAL_SPACES_CONFIG["storage"]  # was wrongly typed as SpaceHardware

    # Check if trusted author
    details = get_discussion_details(repo_id=space_id, repo_type="space", discussion_num=pr_num)
    if details.author not in trusted_authors:
        return False  # not a trusted author => do NOT set secrets, hardware, storage, etc.

    # Configure space
    for key, value in variables.items():
        add_space_variable(ci_space_id, key, value)
    for key, value in secrets.items():
        add_space_secret(ci_space_id, key, value)

    # Request hardware/storage for space
    if hardware is not None and hardware != SpaceHardware.CPU_BASIC:
        # Any non-default hardware: ask for the ephemeral Space to sleep after 5 minutes (sleep_time in seconds)
        request_space_hardware(ci_space_id, hardware, sleep_time=5 * 60)
    if storage is not None:
        request_space_storage(ci_space_id, storage)
    return True
def delete_ci_space(space_id: str, pr_num: int) -> None:
    """Delete the ephemeral Space of PR `pr_num` and announce it on the PR.

    If the ephemeral Space did not exist, nothing is deleted and no comment is posted.
    """
    print(f"New task: delete ephemeral env for {space_id} (PR {pr_num})")
    try:
        delete_repo(repo_id=_get_ci_space_id(space_id=space_id, pr_num=pr_num), repo_type="space")
    except RepositoryNotFoundError:
        return  # nothing was deleted => nothing to announce
    notify_pr(space_id=space_id, pr_num=pr_num, action="deleted")
def notify_pr(
    space_id: str,
    pr_num: int,
    action: Literal["created_not_configured", "created_and_configured", "updated", "deleted"],
) -> None:
    """Post an automated comment on PR `pr_num` describing what happened to its ephemeral Space.

    Raises:
        ValueError: if `action` is not one of the supported values.
    """
    ci_space_id = _get_ci_space_id(space_id=space_id, pr_num=pr_num)
    # Templates that link to the ephemeral Space and need its id filled in
    linked_templates = {
        "created_not_configured": NOTIFICATION_TEMPLATE_CREATED_NOT_CONFIGURED,
        "created_and_configured": NOTIFICATION_TEMPLATE_CREATED_AND_CONFIGURED,
        "updated": NOTIFICATION_TEMPLATE_UPDATED,
    }
    if action == "deleted":
        # The Space is gone: the message has no link, hence no placeholder
        message = NOTIFICATION_TEMPLATE_DELETED
    elif action in linked_templates:
        message = linked_templates[action].format(ci_space_id=ci_space_id)
    else:
        raise ValueError(f"Status {action} not handled.")
    comment_discussion(repo_id=space_id, repo_type="space", discussion_num=pr_num, comment=message)
def _get_ci_space_id(space_id: str, pr_num: int) -> str:
    """Return the repo id of the ephemeral Space for PR `pr_num`, e.g. ("user/app", 3) -> "user/app-ci-pr-3"."""
    return "{}-ci-pr-{}".format(space_id, pr_num)
# Markdown comments posted on PRs by `notify_pr`. `{ci_space_id}` placeholders are filled with `str.format`.
_NOTIFICATION_FOOTER = "_(This is an automated message.)_\n"
_NOTIFICATION_CREATED_INTRO = (
    "Following the creation of this PR, an ephemeral Space "
    "[{ci_space_id}](https://huggingface.co/spaces/{ci_space_id}) has been started. "
    "Any changes pushed to this PR will be synced with the test Space.\n"
)

NOTIFICATION_TEMPLATE_CREATED_AND_CONFIGURED = (
    _NOTIFICATION_CREATED_INTRO
    + "Since this PR has been created by a trusted author, the ephemeral Space has been configured "
    "with the correct hardware, storage, and secrets.\n"
    + _NOTIFICATION_FOOTER
)

NOTIFICATION_TEMPLATE_CREATED_NOT_CONFIGURED = (
    _NOTIFICATION_CREATED_INTRO
    + "Since this PR has not been created by a trusted author, the ephemeral Space has not been configured "
    "with the correct hardware, storage, and secrets. An admin must configure it manually.\n"
    + _NOTIFICATION_FOOTER
)

NOTIFICATION_TEMPLATE_UPDATED = (
    "Following new commits that happened in this PR, the ephemeral Space "
    "[{ci_space_id}](https://huggingface.co/spaces/{ci_space_id}) has been updated.\n"
    + _NOTIFICATION_FOOTER
)

NOTIFICATION_TEMPLATE_DELETED = "PR is now merged/closed. The ephemeral Space has been deleted.\n" + _NOTIFICATION_FOOTER
### TO MOVE TO ITS OWN MODULE
# Minimal helpers for the Hub webhooks settings API (`/api/settings/webhooks`).
# Taken from https://github.com/huggingface/huggingface_hub/issues/1808#issuecomment-1802341663

# HTTP headers (carrying the Hub token, when one is configured) sent by every webhook helper below.
# Built once, when the module is imported.
headers: Dict[str, str] = build_hf_headers()
# One entry of a webhook's watch list. Examples:
#   {"type": "user", "name": "julien-c"}
#   {"type": "org", "name": "HuggingFaceH4"}
#   {"type": "model", "name": "HuggingFaceH4/zephyr-7b-beta"}
#   {"type": "dataset", "name": "HuggingFaceH4/ultrachat_200k"}
#   {"type": "space", "name": "HuggingFaceH4/zephyr-chat"}
WatchedItem = TypedDict(
    "WatchedItem",
    {
        "type": Literal["model", "dataset", "space", "org", "user"],
        "name": str,
    },
)

# What a webhook subscribes to: "repo" = repo updates (code changes),
# "discussion" = discussion updates (issues, PRs, comments). Both may be requested.
DOMAIN_T = Literal["repo", "discussion"]
def get_webhook(webhook_id: str) -> Dict:
    """Fetch one webhook by id.

    Returns:
        The decoded JSON answer of `GET /api/settings/webhooks/{webhook_id}`.

    Error statuses are turned into exceptions by `hf_raise_for_status`.
    """
    endpoint = f"https://huggingface.co/api/settings/webhooks/{webhook_id}"
    reply = get_session().get(endpoint, headers=headers)
    hf_raise_for_status(reply)
    return reply.json()
def list_webhooks() -> List[Dict]:
    """Return every webhook visible with the current token, as decoded JSON dicts.

    Error statuses are turned into exceptions by `hf_raise_for_status`.
    """
    reply = get_session().get("https://huggingface.co/api/settings/webhooks", headers=headers)
    hf_raise_for_status(reply)
    all_webhooks = reply.json()
    return all_webhooks
def create_webhook(watched: List[WatchedItem], url: str, domains: List[DOMAIN_T], secret: Optional[str]) -> Dict:
    """Create a new webhook.

    Args:
        watched (List[WatchedItem]):
            List of items to watch. It can be users, orgs, models, datasets or spaces.
            See `WatchedItem` for more details.
        url (str):
            URL to send the payload to.
        domains (List[Literal["repo", "discussion"]]):
            List of domains to watch. It can be "repo", "discussion" or both.
        secret (str, optional):
            Secret to use to sign the payload.

    Returns:
        dict: The Hub's answer, with the created webhook under the `"webhook"` key (see example).

    Error statuses are turned into exceptions by `hf_raise_for_status`.

    Example:
    ```python
    >>> create_webhook(
    ...     watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}],
    ...     url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548",
    ...     domains=["repo", "discussion"],
    ...     secret="my-secret",
    ... )
    {
        "webhook": {
            "id": "654bbbc16f2ec14d77f109cc",
            "watched": [{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}],
            "url": "https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548",
            "secret": "my-secret",
            "domains": ["repo", "discussion"],
            "disabled": False,
        },
    }
    ```
    """
    print("Creating webhook")
    # Log the request without leaking the secret: only its type is printed.
    print({"watched": watched, "url": url, "domains": domains, "secret": str(type(secret))})
    response = get_session().post(
        "https://huggingface.co/api/settings/webhooks",
        json={"watched": watched, "url": url, "domains": domains, "secret": secret},
        headers=headers,
    )
    hf_raise_for_status(response)
    return response.json()
def update_webhook(
    webhook_id: str, watched: List[WatchedItem], url: str, domains: List[DOMAIN_T], secret: Optional[str]
) -> Dict:
    """Overwrite the whole configuration of an existing webhook.

    Takes the same arguments as `create_webhook`, plus the id of the webhook to change.
    Every field is replaced by the given value.

    Returns:
        The decoded JSON answer of the Hub.
    """
    body = {"watched": watched, "url": url, "domains": domains, "secret": secret}
    endpoint = f"https://huggingface.co/api/settings/webhooks/{webhook_id}"
    reply = get_session().post(endpoint, json=body, headers=headers)
    hf_raise_for_status(reply)
    return reply.json()
def enable_webhook(webhook_id: str) -> Dict:
    """Switch webhook `webhook_id` to "active" and return the Hub's decoded JSON answer."""
    endpoint = f"https://huggingface.co/api/settings/webhooks/{webhook_id}/enable"
    reply = get_session().post(endpoint, headers=headers)
    hf_raise_for_status(reply)
    return reply.json()
def disable_webhook(webhook_id: str) -> Dict:
    """Switch webhook `webhook_id` to "disabled" and return the Hub's decoded JSON answer."""
    endpoint = f"https://huggingface.co/api/settings/webhooks/{webhook_id}/disable"
    reply = get_session().post(endpoint, headers=headers)
    hf_raise_for_status(reply)
    return reply.json()
def delete_webhook(webhook_id: str) -> None:
    """Delete webhook `webhook_id`; error statuses are raised by `hf_raise_for_status`."""
    endpoint = f"https://huggingface.co/api/settings/webhooks/{webhook_id}"
    hf_raise_for_status(get_session().delete(endpoint, headers=headers))