Spaces:
Runtime error
Runtime error
neondaniel
committed on
Commit
•
640efa7
1
Parent(s):
f1a3e74
Update to print disallowed endpoints in-place in the model list
Browse files
Update configuration handling to put all clients in `clients` with backwards-compat. parsing
Troubleshoot radio button rendering
Refactor permissions configuration to support other oauth methods
app.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import os
|
2 |
import json
|
|
|
|
|
3 |
import gradio as gr
|
4 |
|
5 |
import uvicorn
|
@@ -11,7 +13,7 @@ from starlette.responses import RedirectResponse
|
|
11 |
from authlib.integrations.starlette_client import OAuth, OAuthError
|
12 |
from fastapi import FastAPI, Request
|
13 |
|
14 |
-
from shared import Client
|
15 |
|
16 |
app = FastAPI()
|
17 |
config = {}
|
@@ -51,41 +53,44 @@ def init_config():
|
|
51 |
global clients
|
52 |
global llm_host_names
|
53 |
config = json.loads(os.environ['CONFIG'])
|
54 |
-
|
55 |
-
for name in
|
56 |
-
|
57 |
-
continue
|
58 |
-
model_personas = config[name].get("personas", {})
|
59 |
client = Client(
|
60 |
-
api_url=os.environ.get(
|
61 |
-
|
62 |
-
api_key=os.environ.get(
|
63 |
-
|
64 |
personas=model_personas
|
65 |
)
|
66 |
clients[name] = client
|
67 |
-
llm_host_names = list(
|
68 |
|
69 |
|
70 |
-
def get_allowed_models(
|
71 |
"""
|
72 |
Get a list of allowed endpoints for a specified user domain. Allowed domains
|
73 |
are configured in each model's configuration and may optionally be overridden
|
74 |
in the Gradio demo configuration.
|
75 |
-
:param
|
76 |
-
:return: List of allowed endpoints from configuration
|
|
|
77 |
"""
|
78 |
-
overrides = config.get("
|
79 |
allowed_endpoints = []
|
80 |
for client in clients:
|
81 |
-
|
82 |
-
|
83 |
-
if
|
84 |
-
#
|
85 |
allowed_endpoints.append(client)
|
86 |
-
elif
|
87 |
-
|
|
|
88 |
allowed_endpoints.append(client)
|
|
|
|
|
|
|
89 |
return allowed_endpoints
|
90 |
|
91 |
|
@@ -107,7 +112,7 @@ def get_login_button(request: gr.Request) -> gr.Button:
|
|
107 |
:param request: Gradio request to evaluate
|
108 |
:return: Button for either login or logout action
|
109 |
"""
|
110 |
-
user = get_user(request)
|
111 |
print(f"Getting login button for {user}")
|
112 |
|
113 |
if user == "guest":
|
@@ -116,15 +121,39 @@ def get_login_button(request: gr.Request) -> gr.Button:
|
|
116 |
return gr.Button(f"Logout {user}", link="/logout")
|
117 |
|
118 |
|
119 |
-
def get_user(request: Request) ->
|
120 |
"""
|
121 |
Get a unique user email address for the specified request
|
122 |
:param request: FastAPI Request object with user session data
|
123 |
:return: String user email address or "guest"
|
124 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if not request:
|
126 |
-
return "guest"
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
return user
|
129 |
|
130 |
|
@@ -232,25 +261,25 @@ def get_model_options(request: gr.Request) -> List[gr.Radio]:
|
|
232 |
# `user` is a valid Google email address or 'guest'
|
233 |
user = get_user(request.request)
|
234 |
else:
|
235 |
-
user = "guest"
|
236 |
-
print(f"Getting models for {user}")
|
237 |
|
238 |
-
|
239 |
-
allowed_llm_host_names = get_allowed_models(domain)
|
240 |
|
241 |
radio_infos = [f"{name} ({clients[name].vllm_model_name})"
|
|
|
242 |
for name in allowed_llm_host_names]
|
243 |
# Components
|
244 |
-
radios = [gr.Radio(choices=clients[name].personas.keys(),
|
245 |
value=None, label=info) for name, info
|
246 |
in zip(allowed_llm_host_names, radio_infos)]
|
247 |
|
248 |
# Select the first available option by default
|
249 |
radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
|
250 |
print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
|
251 |
-
# Ensure we always have the same number of rows
|
252 |
-
while len(radios) < len(llm_host_names):
|
253 |
-
|
254 |
return radios
|
255 |
|
256 |
|
@@ -271,6 +300,17 @@ def init_gradio() -> gr.Blocks:
|
|
271 |
@gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
|
272 |
inputs=[radio_state, *radios], outputs=[radio_state, *radios])
|
273 |
def radio_click(state, *new_state):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
try:
|
275 |
changed_index = next(i for i in range(len(state))
|
276 |
if state[i] != new_state[i])
|
@@ -326,5 +366,5 @@ if __name__ == "__main__":
|
|
326 |
init_config()
|
327 |
init_oauth()
|
328 |
blocks = init_gradio()
|
329 |
-
app = gr.mount_gradio_app(app, blocks, '/'
|
330 |
uvicorn.run(app, host='0.0.0.0', port=7860)
|
|
|
1 |
import os
|
2 |
import json
|
3 |
+
from time import sleep
|
4 |
+
|
5 |
import gradio as gr
|
6 |
|
7 |
import uvicorn
|
|
|
13 |
from authlib.integrations.starlette_client import OAuth, OAuthError
|
14 |
from fastapi import FastAPI, Request
|
15 |
|
16 |
+
from shared import Client, User, OAuthProvider
|
17 |
|
18 |
app = FastAPI()
|
19 |
config = {}
|
|
|
53 |
global clients
|
54 |
global llm_host_names
|
55 |
config = json.loads(os.environ['CONFIG'])
|
56 |
+
client_config = config.get("clients") or config
|
57 |
+
for name in client_config:
|
58 |
+
model_personas = client_config[name].get("personas", {})
|
|
|
|
|
59 |
client = Client(
|
60 |
+
api_url=os.environ.get(client_config[name]['api_url'],
|
61 |
+
client_config[name]['api_url']),
|
62 |
+
api_key=os.environ.get(client_config[name]['api_key'],
|
63 |
+
client_config[name]['api_key']),
|
64 |
personas=model_personas
|
65 |
)
|
66 |
clients[name] = client
|
67 |
+
llm_host_names = list(client_config.keys())
|
68 |
|
69 |
|
70 |
+
def get_allowed_models(user: User) -> List[str]:
    """
    Resolve which configured endpoints the given user may access. Each
    model's own permissions are used unless the demo configuration supplies
    an entry for that model under `permissions_override`.
    :param user: User to get permissions for
    :return: One entry per configured client, in configuration order: the
        client name when access is allowed, or an empty string when it is not
    """
    permission_overrides = config.get("permissions_override", {})
    results = []
    for name in clients:
        # The model-level permissions are read unconditionally so a missing
        # attribute surfaces even when an override is configured
        model_permission = clients[name].config.inference.permissions
        permission = permission_overrides.get(name, model_permission)

        # No permissions specified (None or empty dict) means the model is
        # public; otherwise the user's Google domain must be explicitly listed
        if not permission or (
                user.oauth == OAuthProvider.GOOGLE
                and user.permissions_id in permission.get("google_domains", [])):
            results.append(name)
            continue

        # Keep a placeholder so the result lines up with the client list
        results.append("")
        print(f"No permission to access {name}")
    return results
|
95 |
|
96 |
|
|
|
112 |
:param request: Gradio request to evaluate
|
113 |
:return: Button for either login or logout action
|
114 |
"""
|
115 |
+
user = get_user(request).username
|
116 |
print(f"Getting login button for {user}")
|
117 |
|
118 |
if user == "guest":
|
|
|
121 |
return gr.Button(f"Logout {user}", link="/logout")
|
122 |
|
123 |
|
124 |
+
def get_user(request: Request) -> User:
    """
    Build a `User` describing whoever made the specified request
    :param request: FastAPI Request object with user session data
    :return: `User` with `OAuthProvider.GOOGLE`, the account email and the
        hosted domain (empty string if the account has none) for a Google
        login; otherwise a guest `User` (`OAuthProvider.NONE`, "guest", "")
    """
    # Example of the Google session data stored under the "user" key:
    # {'iss': 'https://accounts.google.com',
    #  'azp': '***.apps.googleusercontent.com',
    #  'aud': '***.apps.googleusercontent.com',
    #  'sub': '###',
    #  'hd': 'neon.ai',
    #  'email': 'daniel@neon.ai',
    #  'email_verified': True,
    #  'at_hash': '***',
    #  'nonce': '***',
    #  'name': 'Daniel McKnight',
    #  'picture': 'https://lh3.googleusercontent.com/a/***',
    #  'given_name': '***',
    #  'family_name': '***',
    #  'iat': ###,
    #  'exp': ###}
    if not request:
        return User(OAuthProvider.NONE, "guest", "")

    # `or {}` also covers a session where "user" is present but set to None
    user_dict = request.session.get("user") or {}
    if user_dict.get("iss") == "https://accounts.google.com" and \
            user_dict.get("email"):
        # `hd` is only issued for accounts that belong to an organization;
        # personal accounts omit it, so fall back to an empty permissions ID
        # (which never matches an allowed domain) instead of raising KeyError
        user = User(OAuthProvider.GOOGLE, user_dict["email"],
                    user_dict.get("hd", ""))
    elif user_dict:
        # Unrecognized issuer, or Google data without an email address
        print(f"Unknown user session data: {user_dict}")
        user = User(OAuthProvider.NONE, "guest", "")
    else:
        user = User(OAuthProvider.NONE, "guest", "")
    print(user)
    return user
|
158 |
|
159 |
|
|
|
261 |
# `user` is a valid Google email address or 'guest'
|
262 |
user = get_user(request.request)
|
263 |
else:
|
264 |
+
user = User(OAuthProvider.NONE, "guest", "")
|
265 |
+
print(f"Getting models for {user.username}")
|
266 |
|
267 |
+
allowed_llm_host_names = get_allowed_models(user)
|
|
|
268 |
|
269 |
radio_infos = [f"{name} ({clients[name].vllm_model_name})"
|
270 |
+
if name in clients else "Not Authorized"
|
271 |
for name in allowed_llm_host_names]
|
272 |
# Components
|
273 |
+
radios = [gr.Radio(choices=clients[name].personas.keys() if name in clients else [],
|
274 |
value=None, label=info) for name, info
|
275 |
in zip(allowed_llm_host_names, radio_infos)]
|
276 |
|
277 |
# Select the first available option by default
|
278 |
radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
|
279 |
print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
|
280 |
+
# # Ensure we always have the same number of rows
|
281 |
+
# while len(radios) < len(llm_host_names):
|
282 |
+
# radios.append(gr.Radio(choices=[], value=None, label="Not Authorized"))
|
283 |
return radios
|
284 |
|
285 |
|
|
|
300 |
@gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
|
301 |
inputs=[radio_state, *radios], outputs=[radio_state, *radios])
|
302 |
def radio_click(state, *new_state):
|
303 |
+
"""
|
304 |
+
Handle any state changes that require re-rendering radio buttons
|
305 |
+
:param state: Previous radio state representation (before selection)
|
306 |
+
:param new_state: Current radio state (including selection)
|
307 |
+
:return: Desired new state (current option selected, previous option
|
308 |
+
deselected)
|
309 |
+
"""
|
310 |
+
# Login and model options are triggered on load. This sleep is just
|
311 |
+
# a hack to make sure those events run before this logic to select
|
312 |
+
# the default model
|
313 |
+
sleep(0.1)
|
314 |
try:
|
315 |
changed_index = next(i for i in range(len(state))
|
316 |
if state[i] != new_state[i])
|
|
|
366 |
init_config()
|
367 |
init_oauth()
|
368 |
blocks = init_gradio()
|
369 |
+
app = gr.mount_gradio_app(app, blocks, '/')
|
370 |
uvicorn.run(app, host='0.0.0.0', port=7860)
|
shared.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import yaml
|
2 |
|
3 |
from typing import Dict, Optional, List
|
@@ -8,6 +11,18 @@ from huggingface_hub.utils import EntryNotFoundError
|
|
8 |
from openai import OpenAI
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
class PileConfig(BaseModel):
|
12 |
file2persona: Dict[str, str]
|
13 |
file2prefix: Dict[str, str]
|
@@ -17,7 +32,7 @@ class PileConfig(BaseModel):
|
|
17 |
|
18 |
class InferenceConfig(BaseModel):
|
19 |
chat_template: str
|
20 |
-
|
21 |
|
22 |
|
23 |
class RepoConfig(BaseModel):
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from enum import IntEnum
|
3 |
+
|
4 |
import yaml
|
5 |
|
6 |
from typing import Dict, Optional, List
|
|
|
11 |
from openai import OpenAI
|
12 |
|
13 |
|
14 |
+
class OAuthProvider(IntEnum):
    """
    Identifies which OAuth provider (if any) authenticated a user.
    """
    # No OAuth login is associated with the user
    NONE = 0
    # User authenticated with Google
    GOOGLE = 1
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
class User:
    """
    Simple record describing a user of the application.
    """
    # Provider the user authenticated with
    oauth: OAuthProvider
    # Name identifying the user
    username: str
    # Identifier that permission checks are evaluated against
    permissions_id: str
|
24 |
+
|
25 |
+
|
26 |
class PileConfig(BaseModel):
|
27 |
file2persona: Dict[str, str]
|
28 |
file2prefix: Dict[str, str]
|
|
|
32 |
|
33 |
class InferenceConfig(BaseModel):
    """
    Inference-related settings for a model.
    """
    # Chat template string for the model
    chat_template: str
    # Maps a permission type to a list of allowed values; empty by default
    permissions: Dict[str, list] = {}
|
36 |
|
37 |
|
38 |
class RepoConfig(BaseModel):
|