neondaniel commited on
Commit
d098bff
1 Parent(s): ee0a441

Add OAuth Support (#1)

Browse files

- Implement Google oauth (d9f8b28560a57546599305dca4baf52b29d2c7ec)
- Add `huggingface_text` and `allowed_domains_override` configuration (f1a3e74f534b544673c1dd3489ddf7a059bdb770)
- Update to print disallowed endpoints in-place in the model list (640efa738283fbf227e32c024345e4be34c3ec0b)

Files changed (3) hide show
  1. app.py +309 -66
  2. requirements.txt +5 -1
  3. shared.py +19 -1
app.py CHANGED
@@ -1,49 +1,217 @@
1
  import os
2
  import json
3
- from typing import List, Tuple
4
- from collections import OrderedDict
5
 
6
  import gradio as gr
7
 
8
- from shared import Client
9
-
10
-
11
- config = json.loads(os.environ['CONFIG'])
12
-
 
 
 
13
 
 
14
 
 
 
15
  clients = {}
16
- for name in config:
17
- model_personas = config[name].get("personas", {})
18
- client = Client(
19
- api_url=os.environ[config[name]['api_url']],
20
- api_key=os.environ[config[name]['api_key']],
21
- personas=model_personas
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
- clients[name] = client
24
-
25
-
26
- model_names = list(config.keys())
27
- radio_infos = [f"{name} ({clients[name].vllm_model_name})" for name in model_names]
28
- accordion_info = "Persona and LLM Options - Choose one:"
29
-
30
-
31
-
32
- def parse_radio_select(radio_select):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  value_index = next(i for i in range(len(radio_select)) if radio_select[i] is not None)
34
- model = model_names[value_index]
35
  persona = radio_select[value_index]
36
  return model, persona
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def respond(
41
- message,
42
  history: List[Tuple[str, str]],
43
- conversational,
44
- max_tokens,
45
  *radio_select,
46
  ):
 
 
 
 
 
 
 
 
 
47
  model, persona = parse_radio_select(radio_select)
48
 
49
  client = clients[model]
@@ -83,45 +251,120 @@ def respond(
83
  return response
84
 
85
 
86
- # Components
87
- radios = [gr.Radio(choices=clients[name].personas.keys(), value=None, label=info) for name, info in zip(model_names, radio_infos)]
88
- radios[0].value = list(clients[model_names[0]].personas.keys())[0]
89
-
90
- conversational_checkbox = gr.Checkbox(value=True, label="conversational")
91
- max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max new tokens")
92
-
93
-
94
-
95
- with gr.Blocks() as blocks:
96
- # Events
97
- radio_state = gr.State([radio.value for radio in radios])
98
- @gr.on(triggers=[radio.input for radio in radios], inputs=[radio_state, *radios], outputs=[radio_state, *radios])
99
- def radio_click(state, *new_state):
100
- changed_index = next(i for i in range(len(state)) if state[i] != new_state[i])
101
- changed_value = new_state[changed_index]
102
- clean_state = [None if i != changed_index else changed_value for i in range(len(state))]
103
-
104
- return clean_state, *clean_state
105
-
106
- # Compile
107
- with gr.Accordion(label=accordion_info, open=True, render=False) as accordion:
108
- [radio.render() for radio in radios]
109
- conversational_checkbox.render()
110
- max_tokens_slider.render()
111
-
112
- demo = gr.ChatInterface(
113
- respond,
114
- additional_inputs=[
115
- conversational_checkbox,
116
- max_tokens_slider,
117
- *radios,
118
- ],
119
- additional_inputs_accordion=accordion,
120
- title="Neon AI BrainForge Personas and Large Language Models (v2024-07-24)",
121
- concurrency_limit=5,
122
- )
123
- accordion.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  if __name__ == "__main__":
127
- blocks.launch()
 
 
 
 
 
1
  import os
2
  import json
3
+ from time import sleep
 
4
 
5
  import gradio as gr
6
 
7
+ import uvicorn
8
+ from datetime import datetime
9
+ from typing import List, Tuple
10
+ from starlette.config import Config
11
+ from starlette.middleware.sessions import SessionMiddleware
12
+ from starlette.responses import RedirectResponse
13
+ from authlib.integrations.starlette_client import OAuth, OAuthError
14
+ from fastapi import FastAPI, Request
15
 
16
+ from shared import Client, User, OAuthProvider
17
 
18
+ app = FastAPI()
19
+ config = {}
20
  clients = {}
21
+ llm_host_names = []
22
+ oauth = None
23
+
24
+
25
+ def init_oauth():
26
+ global oauth
27
+ google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
28
+ google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
29
+ secret_key = os.environ.get('SECRET_KEY') or "a_very_secret_key"
30
+
31
+ starlette_config = Config(environ={"GOOGLE_CLIENT_ID": google_client_id,
32
+ "GOOGLE_CLIENT_SECRET": google_client_secret})
33
+ oauth = OAuth(starlette_config)
34
+ oauth.register(
35
+ name='google',
36
+ server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
37
+ client_kwargs={'scope': 'openid email profile'}
38
  )
39
+ app.add_middleware(SessionMiddleware, secret_key=secret_key)
40
+
41
+
42
+ def init_config():
43
+ """
44
+ Initialize configuration. A configured `api_url` or `api_key` may be an
45
+ envvar reference OR a literal value. Configuration should follow the
46
+ format:
47
+ {"<llm_host_name>": {"api_key": "<api_key>",
48
+ "api_url": "<api_url>"
49
+ }
50
+ }
51
+ """
52
+ global config
53
+ global clients
54
+ global llm_host_names
55
+ config = json.loads(os.environ['CONFIG'])
56
+ client_config = config.get("clients") or config
57
+ for name in client_config:
58
+ model_personas = client_config[name].get("personas", {})
59
+ client = Client(
60
+ api_url=os.environ.get(client_config[name]['api_url'],
61
+ client_config[name]['api_url']),
62
+ api_key=os.environ.get(client_config[name]['api_key'],
63
+ client_config[name]['api_key']),
64
+ personas=model_personas
65
+ )
66
+ clients[name] = client
67
+ llm_host_names = list(client_config.keys())
68
+
69
+
70
+ def get_allowed_models(user: User) -> List[str]:
71
+ """
72
+ Get a list of allowed endpoints for a specified user domain. Allowed domains
73
+ are configured in each model's configuration and may optionally be overridden
74
+ in the Gradio demo configuration.
75
+ :param user: User to get permissions for
76
+ :return: List of allowed endpoints from configuration (including empty
77
+ strings for disallowed endpoints)
78
+ """
79
+ overrides = config.get("permissions_override", {})
80
+ allowed_endpoints = []
81
+ for client in clients:
82
+ permission = overrides.get(client,
83
+ clients[client].config.inference.permissions)
84
+ if not permission:
85
+ # Permissions not specified (None or empty dict); model is public
86
+ allowed_endpoints.append(client)
87
+ elif user.oauth == OAuthProvider.GOOGLE and user.permissions_id in \
88
+ permission.get("google_domains", []):
89
+ # Google oauth domain is in the allowed domain list
90
+ allowed_endpoints.append(client)
91
+ else:
92
+ allowed_endpoints.append("")
93
+ print(f"No permission to access {client}")
94
+ return allowed_endpoints
95
+
96
+
97
+ def parse_radio_select(radio_select: tuple) -> (str, str):
98
+ """
99
+ Parse radio selection to determine the requested model and persona
100
+ :param radio_select: List of radio selection states
101
+ :return: Selected model, persona
102
+ """
103
  value_index = next(i for i in range(len(radio_select)) if radio_select[i] is not None)
104
+ model = llm_host_names[value_index]
105
  persona = radio_select[value_index]
106
  return model, persona
107
 
108
 
109
+ def get_login_button(request: gr.Request) -> gr.Button:
110
+ """
111
+ Get a login/logout button based on current login status
112
+ :param request: Gradio request to evaluate
113
+ :return: Button for either login or logout action
114
+ """
115
+ user = get_user(request).username
116
+ print(f"Getting login button for {user}")
117
+
118
+ if user == "guest":
119
+ return gr.Button("Login", link="/login")
120
+ else:
121
+ return gr.Button(f"Logout {user}", link="/logout")
122
+
123
+
124
+ def get_user(request: Request) -> User:
125
+ """
126
+ Get a unique user email address for the specified request
127
+ :param request: FastAPI Request object with user session data
128
+ :return: String user email address or "guest"
129
+ """
130
+ # {'iss': 'https://accounts.google.com',
131
+ # 'azp': '***.apps.googleusercontent.com',
132
+ # 'aud': '***.apps.googleusercontent.com',
133
+ # 'sub': '###',
134
+ # 'hd': 'neon.ai',
135
+ # 'email': 'daniel@neon.ai',
136
+ # 'email_verified': True,
137
+ # 'at_hash': '***',
138
+ # 'nonce': '***',
139
+ # 'name': 'Daniel McKnight',
140
+ # 'picture': 'https://lh3.googleusercontent.com/a/***',
141
+ # 'given_name': '***',
142
+ # 'family_name': '***',
143
+ # 'iat': ###,
144
+ # 'exp': ###}
145
+ if not request:
146
+ return User(OAuthProvider.NONE, "guest", "")
147
+
148
+ user_dict = request.session.get("user", {})
149
+ if user_dict.get("iss") == "https://accounts.google.com":
150
+ user = User(OAuthProvider.GOOGLE, user_dict["email"], user_dict["hd"])
151
+ elif user_dict:
152
+ print(f"Unknown user session data: {user_dict}")
153
+ user = User(OAuthProvider.NONE, "guest", "")
154
+ else:
155
+ user = User(OAuthProvider.NONE, "guest", "")
156
+ print(user)
157
+ return user
158
+
159
+
160
+ @app.route('/logout')
161
+ async def logout(request: Request):
162
+ """
163
+ Remove the user session context and reload an un-authenticated session
164
+ :param request: FastAPI Request object with user session data
165
+ :return: Redirect to `/`
166
+ """
167
+ request.session.pop('user', None)
168
+ return RedirectResponse(url='/')
169
+
170
+
171
+ @app.route('/login')
172
+ async def login(request: Request):
173
+ """
174
+ Start oauth flow for login with Google
175
+ :param request: FastAPI Request object
176
+ """
177
+ redirect_uri = request.url_for('auth')
178
+ # Ensure that the `redirect_uri` is https
179
+ from urllib.parse import urlparse, urlunparse
180
+ redirect_uri = urlunparse(urlparse(str(redirect_uri))._replace(scheme='https'))
181
+
182
+ return await oauth.google.authorize_redirect(request, redirect_uri)
183
+
184
+
185
+ @app.route('/auth')
186
+ async def auth(request: Request):
187
+ """
188
+ Callback endpoint for Google oauth
189
+ :param request: FastAPI Request object
190
+ """
191
+ try:
192
+ access_token = await oauth.google.authorize_access_token(request)
193
+ except OAuthError:
194
+ return RedirectResponse(url='/')
195
+ request.session['user'] = dict(access_token)["userinfo"]
196
+ return RedirectResponse(url='/')
197
+
198
 
199
  def respond(
200
+ message: str,
201
  history: List[Tuple[str, str]],
202
+ conversational: bool,
203
+ max_tokens: int,
204
  *radio_select,
205
  ):
206
+ """
207
+ Send user input to a vLLM backend and return the generated response
208
+ :param message: String input from the user
209
+ :param history: Optional list of chat history (<user message>,<llm message>)
210
+ :param conversational: If true, include chat history
211
+ :param max_tokens: Maximum tokens for the LLM to generate
212
+ :param radio_select: List of radio selection args to parse
213
+ :return: String LLM response
214
+ """
215
  model, persona = parse_radio_select(radio_select)
216
 
217
  client = clients[model]
 
251
  return response
252
 
253
 
254
+ def get_model_options(request: gr.Request) -> List[gr.Radio]:
255
+ """
256
+ Get allowed models for the specified session.
257
+ :param request: Gradio request object to get user from
258
+ :return: List of Radio objects for available models
259
+ """
260
+ if request:
261
+ # `user` is a valid Google email address or 'guest'
262
+ user = get_user(request.request)
263
+ else:
264
+ user = User(OAuthProvider.NONE, "guest", "")
265
+ print(f"Getting models for {user.username}")
266
+
267
+ allowed_llm_host_names = get_allowed_models(user)
268
+
269
+ radio_infos = [f"{name} ({clients[name].vllm_model_name})"
270
+ if name in clients else "Not Authorized"
271
+ for name in allowed_llm_host_names]
272
+ # Components
273
+ radios = [gr.Radio(choices=clients[name].personas.keys() if name in clients else [],
274
+ value=None, label=info) for name, info
275
+ in zip(allowed_llm_host_names, radio_infos)]
276
+
277
+ # Select the first available option by default
278
+ radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
279
+ print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
280
+ # # Ensure we always have the same number of rows
281
+ # while len(radios) < len(llm_host_names):
282
+ # radios.append(gr.Radio(choices=[], value=None, label="Not Authorized"))
283
+ return radios
284
+
285
+
286
+ def init_gradio() -> gr.Blocks:
287
+ """
288
+ Initialize a Gradio demo
289
+ :return:
290
+ """
291
+ conversational_checkbox = gr.Checkbox(value=True, label="conversational")
292
+ max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64,
293
+ label="Max new tokens")
294
+ radios = get_model_options(None)
295
+
296
+ with gr.Blocks() as blocks:
297
+ # Events
298
+ radio_state = gr.State([radio.value for radio in radios])
299
+
300
+ @gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
301
+ inputs=[radio_state, *radios], outputs=[radio_state, *radios])
302
+ def radio_click(state, *new_state):
303
+ """
304
+ Handle any state changes that require re-rendering radio buttons
305
+ :param state: Previous radio state representation (before selection)
306
+ :param new_state: Current radio state (including selection)
307
+ :return: Desired new state (current option selected, previous option
308
+ deselected)
309
+ """
310
+ # Login and model options are triggered on load. This sleep is just
311
+ # a hack to make sure those events run before this logic to select
312
+ # the default model
313
+ sleep(0.1)
314
+ try:
315
+ changed_index = next(i for i in range(len(state))
316
+ if state[i] != new_state[i])
317
+ changed_value = new_state[changed_index]
318
+ except StopIteration:
319
+ # TODO: This is the result of some error in rendering a selected
320
+ # option.
321
+ # Changed to current selection
322
+ changed_value = [i for i in new_state if i is not None][0]
323
+ changed_index = new_state.index(changed_value)
324
+ clean_state = [None if i != changed_index else changed_value
325
+ for i in range(len(state))]
326
+ return clean_state, *clean_state
327
+
328
+ # Compile
329
+ hf_config = config.get("huggingface_text") or dict()
330
+ accordion_info = hf_config.get("accordian_info") or \
331
+ "Persona and LLM Options - Choose one:"
332
+ version = hf_config.get("version") or \
333
+ f"v{datetime.now().strftime('%Y-%m-%d')}"
334
+ title = hf_config.get("title") or \
335
+ f"Neon AI BrainForge Personas and Large Language Models ({version})"
336
+
337
+ with gr.Accordion(label=accordion_info, open=True,
338
+ render=False) as accordion:
339
+ [radio.render() for radio in radios]
340
+ conversational_checkbox.render()
341
+ max_tokens_slider.render()
342
+
343
+ _ = gr.ChatInterface(
344
+ respond,
345
+ additional_inputs=[
346
+ conversational_checkbox,
347
+ max_tokens_slider,
348
+ *radios,
349
+ ],
350
+ additional_inputs_accordion=accordion,
351
+ title=title,
352
+ concurrency_limit=5,
353
+ )
354
+
355
+ # Render login/logout button
356
+ login_button = gr.Button("Log In")
357
+ blocks.load(get_login_button, None, login_button)
358
+
359
+ accordion.render()
360
+ blocks.load(get_model_options, None, radios)
361
+
362
+ return blocks
363
 
364
 
365
  if __name__ == "__main__":
366
+ init_config()
367
+ init_oauth()
368
+ blocks = init_gradio()
369
+ app = gr.mount_gradio_app(app, blocks, '/')
370
+ uvicorn.run(app, host='0.0.0.0', port=7860)
requirements.txt CHANGED
@@ -1,2 +1,6 @@
1
  huggingface_hub==0.22.2
2
- openai~=1.0
 
 
 
 
 
1
  huggingface_hub==0.22.2
2
+ openai~=1.0
3
+ fastapi
4
+ authlib
5
+ uvicorn
6
+ starlette
shared.py CHANGED
@@ -1,6 +1,9 @@
 
 
 
1
  import yaml
2
 
3
- from typing import Dict
4
  from pydantic import BaseModel, ValidationError
5
  from huggingface_hub import hf_hub_download
6
  from huggingface_hub.utils import EntryNotFoundError
@@ -8,6 +11,17 @@ from huggingface_hub.utils import EntryNotFoundError
8
  from openai import OpenAI
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class PileConfig(BaseModel):
13
  file2persona: Dict[str, str]
@@ -15,13 +29,17 @@ class PileConfig(BaseModel):
15
  persona2system: Dict[str, str]
16
  prompt: str
17
 
 
18
  class InferenceConfig(BaseModel):
19
  chat_template: str
 
 
20
 
21
  class RepoConfig(BaseModel):
22
  name: str
23
  tag: str
24
 
 
25
  class ModelConfig(BaseModel):
26
  pile: PileConfig
27
  inference: InferenceConfig
 
1
+ from dataclasses import dataclass
2
+ from enum import IntEnum
3
+
4
  import yaml
5
 
6
+ from typing import Dict, Optional, List
7
  from pydantic import BaseModel, ValidationError
8
  from huggingface_hub import hf_hub_download
9
  from huggingface_hub.utils import EntryNotFoundError
 
11
  from openai import OpenAI
12
 
13
 
14
+ class OAuthProvider(IntEnum):
15
+ NONE = 0
16
+ GOOGLE = 1
17
+
18
+
19
+ @dataclass
20
+ class User:
21
+ oauth: OAuthProvider
22
+ username: str
23
+ permissions_id: str
24
+
25
 
26
  class PileConfig(BaseModel):
27
  file2persona: Dict[str, str]
 
29
  persona2system: Dict[str, str]
30
  prompt: str
31
 
32
+
33
  class InferenceConfig(BaseModel):
34
  chat_template: str
35
+ permissions: Dict[str, list] = {}
36
+
37
 
38
  class RepoConfig(BaseModel):
39
  name: str
40
  tag: str
41
 
42
+
43
  class ModelConfig(BaseModel):
44
  pile: PileConfig
45
  inference: InferenceConfig