Spaces:
Running
Running
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily.""" | |
import atexit | |
import inspect | |
import os | |
from functools import wraps | |
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional | |
from .utils import experimental, is_fastapi_available, is_gradio_available | |
if TYPE_CHECKING:
    # Only evaluated by static type checkers: these names back the string annotations used below.
    import gradio as gr
    from fastapi import Request

if is_fastapi_available():
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse
else:
    # Will fail at runtime if FastAPI is not available
    FastAPI = Request = JSONResponse = None  # type: ignore [misc, assignment]

# Shared server instance used by `webhook_endpoint`; created on first use by `_get_global_app`.
_global_app: Optional["WebhooksServer"] = None
# True when the `SPACE_ID` environment variable is not set.
_is_local = os.environ.get("SPACE_ID") is None
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to be
    called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
            webhook endpoints are opened without any security.

    Example:

        ```python
        import gradio as gr
        from huggingface_hub import WebhooksServer, WebhookPayload

        with gr.Blocks() as ui:
            ...

        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

        @app.add_webhook("/say_hello")
        async def hello(payload: WebhookPayload):
            return {"message": "hello"}

        app.launch()
        ```
    """

    def __new__(cls, *args, **kwargs) -> "WebhooksServer":
        """Create the instance, failing early if `gradio` or `fastapi` is missing.

        Raises:
            ImportError: If `gradio` or `fastapi` is not installed.
        """
        if not is_gradio_available():
            raise ImportError(
                "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
                " first."
            )
        if not is_fastapi_available():
            raise ImportError(
                "You must have `fastapi` installed to use `WebhooksServer`. Please run `pip install --upgrade fastapi`"
                " first."
            )
        return super().__new__(cls)

    def __init__(
        self,
        ui: Optional["gr.Blocks"] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        """Store the UI and secret, and start with an empty webhook registry.

        The secret falls back to the `WEBHOOK_SECRET` environment variable when not passed explicitly.
        """
        self._ui = ui

        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        # Maps the absolute route (e.g. "/webhooks/my_hook") to the function registered for it.
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Returns:
            When used without arguments (`@app.add_webhook`), the decorated function itself. When called with a
            path (`@app.add_webhook("/path")`), a decorator that registers the function and returns it unchanged.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.launch()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func
            # Return the function so the decorated name stays bound to it instead of `None`.
            return func

        return _inner_post

    def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
        """Launch the Gradio app and register webhooks to the underlying FastAPI server.

        Input parameters are forwarded to Gradio when launching the app.

        Args:
            prevent_thread_lock (`bool`, defaults to `False`):
                If `False`, the calling thread is blocked (via `ui.block_thread()`) once the webhooks are registered.
            **launch_kwargs:
                Extra keyword arguments passed to `ui.launch`. `share` defaults to `True` when running locally
                (i.e. when the `SPACE_ID` environment variable is not set).
        """
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        #   - as non-blocking so that webhooks can be added afterwards
        #   - as shared if launch locally (to debug webhooks)
        launch_kwargs.setdefault("share", _is_local)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        space_host = os.environ.get("SPACE_HOST")
        url = "https://" + space_host if space_host is not None else (ui.share_url or ui.local_url)
        url = url.strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        if not prevent_thread_lock:
            ui.block_thread()

    def _get_default_ui(self) -> "gr.Blocks":
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        import gradio as gr

        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Huggingface Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )
            # Build the text in two steps: the settings link is shown in both environments, then the
            # environment-specific hint is appended. (Written as a single `a + b if cond else c`
            # expression, the link would be dropped on Spaces because `+` binds tighter than `if/else`.)
            instructions = "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
            if _is_local:
                instructions += (
                    "\nYou app is running locally. Please look at the logs to check the full URL you need to set."
                )
            else:
                instructions += (
                    "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                    " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
                )
            gr.Markdown(instructions)
        return ui
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator that registers a function as a webhook endpoint on a shared [`WebhooksServer`].

    This is the quick-start helper: every decorated function is attached to the same module-wide server, so the
    decorator can be applied several times to expose several endpoints. For a custom landing page or an explicit
    webhook secret, instantiate [`WebhooksServer`] yourself instead.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        path (`str`, optional):
            The URL path to register the webhook function. If not provided, the function name will be used as the path.
            In any case, all webhooks are registered under `/webhooks`.

    Returns:
        The decorated function (when used as `@webhook_endpoint`) or a decorator (when used as
        `@webhook_endpoint("/path")`). The decorated function gains a `launch()` attribute that starts the
        server immediately.

    Examples:
        Default usage: the function name is used as the path and the server is started automatically at exit
        (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: start the server manually, which is useful when running in a notebook.

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually
        trigger_training.launch()
        ```
    """

    def _register(func: Callable) -> Callable:
        server = _get_global_app()
        server.add_webhook(path)(func)

        # Only the very first registration schedules the deferred start, so it is registered once.
        is_first_webhook = len(server.registered_webhooks) == 1
        if is_first_webhook:
            atexit.register(server.launch)

        def _launch_now():
            # Cancel the at-exit start, then run the server right away.
            atexit.unregister(server.launch)
            server.launch()

        setattr(func, "launch", _launch_now)
        return func

    # Bare `@webhook_endpoint`: `path` is the decorated function, so restart without a path and apply.
    if callable(path):
        return webhook_endpoint()(path)
    return _register
def _get_global_app() -> WebhooksServer:
    """Return the module-wide `WebhooksServer`, creating it on the first call."""
    global _global_app
    if _global_app is not None:
        return _global_app
    _global_app = WebhooksServer()
    return _global_app
def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None: | |
if webhook_secret is None: | |
print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.") | |
print( | |
"To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: " | |
"\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`" | |
) | |
print( | |
"For more details about webhook secrets, please refer to" | |
" https://huggingface.co/docs/hub/webhooks#webhook-secret." | |
) | |
else: | |
print("Webhook secret is correctly defined.") | |
def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str: | |
"""Returns the anchor to a given webhook in the docs (experimental)""" | |
return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post" | |
def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route
    parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
    object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
    the function will return a 403 error (forbidden).

    Inspired by https://stackoverflow.com/a/33112180.

    Args:
        func (`Callable`):
            The webhook function to protect. Both sync and async functions are supported.
        webhook_secret (`str`):
            The secret that the "x-webhook-secret" request header must match.

    Returns:
        An async function carrying `func`'s metadata (name, docstring) whose signature accepts a `request`
        parameter in addition to the parameters of `func`.
    """
    # Local import: keeps this block self-contained without touching the module import section.
    import hmac

    initial_sig = inspect.signature(func)
    expects_request = "request" in initial_sig.parameters
    # Compare as bytes: `hmac.compare_digest` rejects non-ASCII `str` values, and header content is untrusted.
    expected_secret = webhook_secret.encode("utf-8")

    # `wraps` keeps the original name/docstring and lets signature introspection resolve to `func`'s own
    # signature when no explicit `__signature__` is set below (i.e. when `func` already accepts `request`).
    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        # Constant-time comparison so response timing does not leak how much of the secret matched.
        if not hmac.compare_digest(request_secret.encode("utf-8"), expected_secret):
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if expects_request:
            kwargs["request"] = request

        # Handle both sync and async routes
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        else:
            return func(**kwargs)

    # Update signature to include request
    if not expects_request:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func