Spaces:
Runtime error
Runtime error
First oauth test
Browse files- README.md +5 -2
- app.py +27 -0
- auth.py +78 -0
- requirements.txt +2 -0
- start.py +3 -0
README.md
CHANGED
@@ -3,8 +3,11 @@ title: Gradio Oauth Test
|
|
3 |
emoji: 🏆
|
4 |
colorFrom: pink
|
5 |
colorTo: pink
|
6 |
-
sdk:
|
7 |
-
|
|
|
|
|
|
|
8 |
---
|
9 |
|
10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
3 |
emoji: 🏆
|
4 |
colorFrom: pink
|
5 |
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.36.1
|
8 |
+
python_version: 3.10.6
|
9 |
+
app_file: start.py
|
10 |
+
hf_oauth: true
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from auth import get_app
|
3 |
+
|
4 |
+
|
5 |
+
TEMPLATE = """
|
6 |
+
### Name: {name}
|
7 |
+
### Username: {preferred_username}
|
8 |
+
### Profile: {profile}
|
9 |
+
### Website: {website}
|
10 |
+
|
11 |
+
![Profile Picture]({picture})
|
12 |
+
|
13 |
+
You can manage your connected applications in your [settings](https://huggingface.co/settings/connected-applications).
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
def show_profile(request: gr.Request) -> str:
|
18 |
+
return TEMPLATE.format(**request.request.session["user"])
|
19 |
+
|
20 |
+
|
21 |
+
with gr.Blocks() as demo:
|
22 |
+
greet_btn = gr.Button("Show profile")
|
23 |
+
output = gr.Markdown()
|
24 |
+
greet_btn.click(fn=show_profile, outputs=output)
|
25 |
+
|
26 |
+
fastapi_app = get_app()
|
27 |
+
app = gr.mount_gradio_app(fastapi_app, demo, path="/gradio")
|
auth.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import httpx
|
3 |
+
|
4 |
+
from authlib.integrations.starlette_client import OAuth
|
5 |
+
from fastapi import FastAPI
|
6 |
+
from fastapi.requests import Request
|
7 |
+
from fastapi.responses import RedirectResponse
|
8 |
+
from starlette.middleware.sessions import SessionMiddleware
|
9 |
+
|
10 |
+
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
|
11 |
+
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
|
12 |
+
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
|
13 |
+
OAUTH_SCOPES = "profile" # TODO: remove when openid is fixed (honor nonce)
|
14 |
+
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
|
15 |
+
|
16 |
+
for value in (OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_SCOPES, OPENID_PROVIDER_URL):
|
17 |
+
if value is None:
|
18 |
+
raise ValueError("Missing environment variable")
|
19 |
+
|
20 |
+
AUTHORIZE_URL = OPENID_PROVIDER_URL + "/oauth/authorize"
|
21 |
+
ACCESS_TOKEN_URL = OPENID_PROVIDER_URL + "/oauth/token"
|
22 |
+
USER_INFO_URL = OPENID_PROVIDER_URL + "/oauth/userinfo"
|
23 |
+
|
24 |
+
oauth = OAuth()
|
25 |
+
oauth.register(
|
26 |
+
name="huggingface",
|
27 |
+
client_id=OAUTH_CLIENT_ID,
|
28 |
+
client_secret=OAUTH_CLIENT_SECRET,
|
29 |
+
access_token_url=ACCESS_TOKEN_URL,
|
30 |
+
authorize_url=AUTHORIZE_URL,
|
31 |
+
api_base_url=OPENID_PROVIDER_URL,
|
32 |
+
client_kwargs={"scope": OAUTH_SCOPES},
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
async def landing(request: Request):
|
37 |
+
if request.session.get("user"):
|
38 |
+
return RedirectResponse("/gradio")
|
39 |
+
else:
|
40 |
+
return RedirectResponse(request.url_for("oauth_login"))
|
41 |
+
|
42 |
+
|
43 |
+
async def oauth_login(request: Request):
|
44 |
+
redirect_uri = request.url_for("oauth_redirect_callback")
|
45 |
+
return await oauth.huggingface.authorize_redirect(request, redirect_uri)
|
46 |
+
|
47 |
+
|
48 |
+
async def oauth_redirect_callback(request: Request):
|
49 |
+
token = await oauth.huggingface.authorize_access_token(request)
|
50 |
+
|
51 |
+
async with httpx.AsyncClient() as client:
|
52 |
+
resp = await client.get(USER_INFO_URL, headers={"Authorization": f"Bearer {token['access_token']}"})
|
53 |
+
user_info = resp.json()
|
54 |
+
|
55 |
+
request.session["user"] = user_info # TODO: we should store token instead
|
56 |
+
return RedirectResponse(request.url_for("landing"))
|
57 |
+
|
58 |
+
|
59 |
+
async def check_oauth(request: Request, call_next):
|
60 |
+
if request.url.path in (
|
61 |
+
"/",
|
62 |
+
"/auth/huggingface",
|
63 |
+
"/auth/callback",
|
64 |
+
): # not protected
|
65 |
+
return await call_next(request)
|
66 |
+
if request.session.get("user"): # authenticated
|
67 |
+
return await call_next(request)
|
68 |
+
return RedirectResponse("/")
|
69 |
+
|
70 |
+
|
71 |
+
def get_app() -> FastAPI:
|
72 |
+
app = FastAPI()
|
73 |
+
app.middleware("http")(check_oauth)
|
74 |
+
app.add_middleware(SessionMiddleware, secret_key="session-secret-key") # TODO: make this is secret key
|
75 |
+
app.get("/")(landing)
|
76 |
+
app.get("/auth/huggingface")(oauth_login)
|
77 |
+
app.get("/auth/callback")(oauth_redirect_callback)
|
78 |
+
return app
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
authlib==1.2.1
|
2 |
+
itsdangerous==2.1.2
|
start.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
|
3 |
+
subprocess.run("uvicorn app:app --host 0.0.0.0 --port 7860", shell=True)
|