File size: 15,196 Bytes
254a3c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 |
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
import atexit
import inspect
import os
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from .utils import experimental, is_gradio_available
from .utils._deprecation import _deprecate_method
if TYPE_CHECKING:
import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
_global_app: Optional["WebhooksServer"] = None
_is_local = os.getenv("SYSTEM") != "spaces"
@experimental
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Hugging Face webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to be
    called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
            webhook endpoints are opened without any security.

    Example:

        ```python
        import gradio as gr
        from huggingface_hub import WebhooksServer, WebhookPayload

        with gr.Blocks() as ui:
            ...

        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

        @app.add_webhook("/say_hello")
        async def hello(payload: WebhookPayload):
            return {"message": "hello"}

        app.launch()
        ```
    """

    def __new__(cls, *args, **kwargs) -> "WebhooksServer":
        # Fail early, at instantiation time, with an actionable message if the optional dependency is missing.
        if not is_gradio_available():
            raise ImportError(
                "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
                " first."
            )
        return super().__new__(cls)

    def __init__(
        self,
        ui: Optional["gr.Blocks"] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        self._ui = ui
        # An explicit argument takes precedence over the `WEBHOOK_SECRET` environment variable.
        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        # Maps the absolute route path (e.g. "/webhooks/say_hello") to the user's webhook function.
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Returns:
            `Callable`: a decorator that registers the function and returns it unchanged. When used without
            parentheses (`@app.add_webhook`), the decorated function itself is returned.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.launch()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        @wraps(FastAPI.post)
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func
            # Hand the function back so the decorated name stays bound to it (instead of `None`).
            return func

        return _inner_post

    def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
        """Launch the Gradio app and register webhooks to the underlying FastAPI server.

        Input parameters are forwarded to Gradio when launching the app.

        Args:
            prevent_thread_lock (`bool`, defaults to `False`):
                If `True`, return right after the server is started instead of blocking the main thread.
            **launch_kwargs:
                Forwarded to `gradio.Blocks.launch`. `share` defaults to `True` when running locally and
                `False` on a Space.
        """
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        #   - as non-blocking so that webhooks can be added afterwards
        #   - as shared if launch locally (to debug webhooks)
        launch_kwargs.setdefault("share", _is_local)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        url = (ui.share_url or ui.local_url).strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        if not prevent_thread_lock:
            ui.block_thread()

    @_deprecate_method(version="0.23", message="Use `WebhooksServer.launch` instead.")
    def run(self) -> None:
        """Deprecated alias of [`~WebhooksServer.launch`] (blocking)."""
        return self.launch()

    def _get_default_ui(self) -> "gr.Blocks":
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        import gradio as gr

        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Huggingface Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )
            # The conditional is parenthesized on purpose: without it, `A + B if cond else C` parses as
            # `(A + B) if cond else C` and the "Go to ..." sentence would be dropped when running on a Space.
            gr.Markdown(
                "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
                + (
                    "\nYour app is running locally. Please look at the logs to check the full URL you need to set."
                    if _is_local
                    else (
                        "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                        " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
                    )
                )
            )
        return ui
@experimental
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.

    This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret),
    you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using
    this decorator multiple times.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        path (`str`, optional):
            The URL path to register the webhook function. If not provided, the function name will be used as the path.
            In any case, all webhooks are registered under `/webhooks`.

    Returns:
        `Callable`: a decorator returning the original function, with an extra `run` attribute that starts the
        shared server immediately (instead of at interpreter exit).

    Examples:
        The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
        The server will be started automatically at exit (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
        are running it in a notebook.

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually
        trigger_training.run()
        ```
    """
    if callable(path):
        # If path is a function, it means it was used as a decorator without arguments
        return webhook_endpoint()(path)

    @wraps(WebhooksServer.add_webhook)
    def _inner(func: Callable) -> Callable:
        app = _get_global_app()
        app.add_webhook(path)(func)
        if len(app.registered_webhooks) == 1:
            # First webhook on the shared server: start it at interpreter exit (registered only once).
            # `launch` is used rather than the deprecated `run` so users are not shown a deprecation
            # warning for a method they never called themselves.
            atexit.register(app.launch)

        @wraps(app.launch)
        def _run_now():
            # Run the app directly (without waiting atexit)
            atexit.unregister(app.launch)
            app.launch()

        func.run = _run_now  # type: ignore
        return func

    return _inner
def _get_global_app() -> WebhooksServer:
    """Return the module-level `WebhooksServer` shared by `webhook_endpoint`, creating it on the first call."""
    global _global_app
    if _global_app is not None:
        return _global_app
    _global_app = WebhooksServer()
    return _global_app
def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
if webhook_secret is None:
print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
print(
"To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
"\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
)
print(
"For more details about webhook secrets, please refer to"
" https://huggingface.co/docs/hub/webhooks#webhook-secret."
)
else:
print("Webhook secret is correctly defined.")
def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
"""Returns the anchor to a given webhook in the docs (experimental)"""
return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI relies on route
    parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
    object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret" header, the function will return a 401 error (unauthorized). If the header is sent but is
    incorrect, the function will return a 403 error (forbidden). The header is compared to the secret in constant time.

    Inspired by https://stackoverflow.com/a/33112180.

    Args:
        func (`Callable`):
            The user's webhook function, sync or async.
        webhook_secret (`str`):
            The value the `x-webhook-secret` header must match.

    Returns:
        `Callable`: an async function exposing `func`'s signature, with a leading `request` parameter added when
        `func` does not already declare one.
    """
    import hmac

    initial_sig = inspect.signature(func)
    # Encode once: comparing bytes lets `hmac.compare_digest` accept any secret, including non-ASCII ones.
    expected_secret = webhook_secret.encode("utf-8")

    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        # Constant-time comparison: a plain `!=` on untrusted input can leak, through response timing,
        # how many leading characters of the secret were guessed correctly.
        if not hmac.compare_digest(request_secret.encode("utf-8"), expected_secret):
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if "request" in initial_sig.parameters:
            kwargs["request"] = request

        # Handle both sync and async routes.
        # NOTE(review): a sync `func` is called directly inside this coroutine, so it presumably blocks the
        # event loop while it runs — consider offloading it to a thread; TODO confirm with maintainers.
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        return func(**kwargs)

    # Update signature to include request (only if the user function does not already declare it)
    if "request" not in initial_sig.parameters:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func
|