neondaniel committed on
Commit
d9f8b28
1 Parent(s): ee0a441

Implement Google OAuth

Browse files

Refactor `app.py` into separate functions
Add docstrings
Add `allowed_domains` Inference config with default value

Files changed (3) hide show
  1. app.py +262 -67
  2. requirements.txt +5 -1
  3. shared.py +5 -2
app.py CHANGED
@@ -1,49 +1,180 @@
1
  import os
2
  import json
3
- from typing import List, Tuple
4
- from collections import OrderedDict
5
-
6
  import gradio as gr
7
 
8
- from shared import Client
9
-
10
-
11
- config = json.loads(os.environ['CONFIG'])
12
-
 
 
 
13
 
 
14
 
 
 
15
  clients = {}
16
- for name in config:
17
- model_personas = config[name].get("personas", {})
18
- client = Client(
19
- api_url=os.environ[config[name]['api_url']],
20
- api_key=os.environ[config[name]['api_key']],
21
- personas=model_personas
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
- clients[name] = client
24
-
25
-
26
- model_names = list(config.keys())
27
- radio_infos = [f"{name} ({clients[name].vllm_model_name})" for name in model_names]
28
- accordion_info = "Persona and LLM Options - Choose one:"
29
-
30
-
31
-
32
- def parse_radio_select(radio_select):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  value_index = next(i for i in range(len(radio_select)) if radio_select[i] is not None)
34
- model = model_names[value_index]
35
  persona = radio_select[value_index]
36
  return model, persona
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def respond(
41
- message,
42
  history: List[Tuple[str, str]],
43
- conversational,
44
- max_tokens,
45
  *radio_select,
46
  ):
 
 
 
 
 
 
 
 
 
47
  model, persona = parse_radio_select(radio_select)
48
 
49
  client = clients[model]
@@ -83,45 +214,109 @@ def respond(
83
  return response
84
 
85
 
86
- # Components
87
- radios = [gr.Radio(choices=clients[name].personas.keys(), value=None, label=info) for name, info in zip(model_names, radio_infos)]
88
- radios[0].value = list(clients[model_names[0]].personas.keys())[0]
89
-
90
- conversational_checkbox = gr.Checkbox(value=True, label="conversational")
91
- max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max new tokens")
92
-
93
-
94
-
95
- with gr.Blocks() as blocks:
96
- # Events
97
- radio_state = gr.State([radio.value for radio in radios])
98
- @gr.on(triggers=[radio.input for radio in radios], inputs=[radio_state, *radios], outputs=[radio_state, *radios])
99
- def radio_click(state, *new_state):
100
- changed_index = next(i for i in range(len(state)) if state[i] != new_state[i])
101
- changed_value = new_state[changed_index]
102
- clean_state = [None if i != changed_index else changed_value for i in range(len(state))]
103
-
104
- return clean_state, *clean_state
105
-
106
- # Compile
107
- with gr.Accordion(label=accordion_info, open=True, render=False) as accordion:
108
- [radio.render() for radio in radios]
109
- conversational_checkbox.render()
110
- max_tokens_slider.render()
111
-
112
- demo = gr.ChatInterface(
113
- respond,
114
- additional_inputs=[
115
- conversational_checkbox,
116
- max_tokens_slider,
117
- *radios,
118
- ],
119
- additional_inputs_accordion=accordion,
120
- title="Neon AI BrainForge Personas and Large Language Models (v2024-07-24)",
121
- concurrency_limit=5,
122
- )
123
- accordion.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  if __name__ == "__main__":
127
- blocks.launch()
 
 
 
 
 
1
  import os
2
  import json
 
 
 
3
  import gradio as gr
4
 
5
+ import uvicorn
6
+ from datetime import datetime
7
+ from typing import List, Tuple
8
+ from starlette.config import Config
9
+ from starlette.middleware.sessions import SessionMiddleware
10
+ from starlette.responses import RedirectResponse
11
+ from authlib.integrations.starlette_client import OAuth, OAuthError
12
+ from fastapi import FastAPI, Request
13
 
14
+ from shared import Client
15
 
16
# Module-level state shared by the route handlers and Gradio callbacks below.
# `app` is the FastAPI application that the routes and the session middleware
# are attached to.
+ app = FastAPI()
17
# `config` holds the parsed `CONFIG` envvar; it is populated by `init_config()`.
+ config = {}
18
# `clients` maps each configured LLM host name to its `Client` instance.
  clients = {}
19
# `llm_host_names` lists the configured host names (the keys of `config`).
+ llm_host_names = []
20
# `oauth` is replaced with the configured OAuth registry by `init_oauth()`.
+ oauth = None
21
+
22
+
23
def init_oauth():
    """
    Configure Google OAuth and cookie-based sessions for the FastAPI app.

    Reads `GOOGLE_CLIENT_ID`, `GOOGLE_CLIENT_SECRET` and `SECRET_KEY` from the
    environment, registers a `google` client on the module-level `oauth`
    registry and installs `SessionMiddleware` on `app`, so that the
    authenticated user can be kept in a signed session cookie.
    """
    import secrets

    global oauth
    google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
    google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")

    secret_key = os.environ.get('SECRET_KEY')
    if not secret_key:
        # Never fall back to a publicly known key: the session cookie is only
        # signed, so anyone who knows the key can forge a session for an
        # arbitrary email address. A random key is safe, at the cost of
        # invalidating sessions whenever the process restarts.
        secret_key = secrets.token_urlsafe(32)
        print("SECRET_KEY is not set; using a random per-process session key")

    starlette_config = Config(environ={
        "GOOGLE_CLIENT_ID": google_client_id,
        "GOOGLE_CLIENT_SECRET": google_client_secret,
    })
    oauth = OAuth(starlette_config)
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'}
    )
    app.add_middleware(SessionMiddleware, secret_key=secret_key)
38
+
39
+
40
def init_config():
    """
    Load the `CONFIG` envvar and create one `Client` per configured LLM host.

    Each host entry provides an `api_url` and an `api_key`; either may be the
    name of an environment variable to read OR the literal value itself. An
    optional `personas` mapping is passed through to the client. Expected
    layout:
    {"<llm_host_name>": {"api_key": "<api_key>",
                         "api_url": "<api_url>"
                         }
    }
    """
    global config
    global clients
    global llm_host_names
    config = json.loads(os.environ['CONFIG'])
    for host_name in config:
        host_config = config[host_name]
        personas = host_config.get("personas", {})
        url_ref = host_config['api_url']
        key_ref = host_config['api_key']
        # Treat each value as an envvar name first; if no such variable is
        # set, the configured string is used as-is.
        clients[host_name] = Client(
            api_url=os.environ.get(url_ref, url_ref),
            api_key=os.environ.get(key_ref, key_ref),
            personas=personas,
        )
    llm_host_names = list(config)
65
+
66
+
67
def get_allowed_models(user_domain: str) -> List[str]:
    """
    List the configured endpoints that a user domain is permitted to use
    :param user_domain: User domain (i.e. neon.ai, google.com, guest)
    :return: Names of the allowed endpoints, in `clients` order
    """
    def _is_permitted(host_name: str) -> bool:
        domains = clients[host_name].config.inference.allowed_domains
        # No `allowed_domains` configured means the model is public;
        # otherwise the user's domain must be explicitly listed.
        return domains is None or user_domain in domains

    return [host_name for host_name in clients if _is_permitted(host_name)]
82
+
83
+
84
def parse_radio_select(radio_select: tuple) -> Tuple[str, str]:
    """
    Parse radio selection to determine the requested model and persona
    :param radio_select: Radio selection states, one per entry in
        `llm_host_names`; every entry except the selected one is None
    :return: Selected model, persona
    :raises ValueError: If no radio holds a selection
    """
    for index, persona in enumerate(radio_select):
        if persona is not None:
            # The radio position identifies the LLM host
            return llm_host_names[index], persona
    # Raise an explicit error rather than leaking `StopIteration` from a bare
    # `next()`, which is easily swallowed by surrounding iteration machinery.
    raise ValueError("No persona is selected")
94
 
95
 
96
def get_login_button(request: gr.Request) -> gr.Button:
    """
    Get a login/logout button based on current login status
    :param request: Gradio request to evaluate
    :return: Button for either login or logout action
    """
    # Hand `get_user` the wrapped FastAPI request, exactly as
    # `get_model_options` does, so both callbacks resolve the user from the
    # same session object.
    user = get_user(request.request if request else None)
    print(f"Getting login button for {user}")

    if user == "guest":
        return gr.Button("Login", link="/login")
    return gr.Button(f"Logout {user}", link="/logout")
109
+
110
+
111
def get_user(request: "Request") -> str:
    """
    Get a unique user email address for the specified request
    :param request: FastAPI Request object with user session data; may be
        None or otherwise falsy, in which case the user is a guest
    :return: String user email address or "guest"
    """
    if not request:
        return "guest"
    # `or {}` also covers a session whose `user` entry exists but is None,
    # which would otherwise fail on the chained `.get`.
    user_info = request.session.get('user') or {}
    return user_info.get('email') or "guest"
121
+
122
+
123
@app.route('/logout')
async def logout(request: Request):
    """
    Drop the stored user from the session and send the browser back to the
    landing page, which then loads as an un-authenticated session
    :param request: FastAPI Request object with user session data
    :return: Redirect to `/`
    """
    session = request.session
    # A default is supplied so that logging out twice is not an error
    session.pop('user', None)
    home = RedirectResponse(url='/')
    return home
132
+
133
+
134
@app.route('/login')
async def login(request: Request):
    """
    Begin the Google OAuth login flow
    :param request: FastAPI Request object
    :return: Response produced by the OAuth client's `authorize_redirect`
    """
    from urllib.parse import urlparse, urlunparse

    callback = urlparse(str(request.url_for('auth')))
    # The callback is always advertised with the https scheme, whatever
    # scheme this request arrived with.
    # NOTE(review): presumably because the app sits behind a TLS-terminating
    # proxy - confirm.
    https_callback = urlunparse(callback._replace(scheme='https'))
    return await oauth.google.authorize_redirect(request, https_callback)
146
+
147
+
148
@app.route('/auth')
async def auth(request: Request):
    """
    Callback endpoint for Google oauth. Stores the returned user info in the
    session when the token exchange succeeds.
    :param request: FastAPI Request object
    :return: Redirect to `/` (whether or not the login succeeded)
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError:
        # Failed or cancelled login: continue as an un-authenticated session
        return RedirectResponse(url='/')
    # `userinfo` may be absent from the token response; only store a user
    # when it is actually present instead of failing with a KeyError.
    user_info = dict(token).get("userinfo")
    if user_info:
        request.session['user'] = dict(user_info)
    return RedirectResponse(url='/')
160
+
161
 
162
  def respond(
163
+ message: str,
164
  history: List[Tuple[str, str]],
165
+ conversational: bool,
166
+ max_tokens: int,
167
  *radio_select,
168
  ):
169
+ """
170
+ Send user input to a vLLM backend and return the generated response
171
+ :param message: String input from the user
172
+ :param history: Optional list of chat history (<user message>,<llm message>)
173
+ :param conversational: If true, include chat history
174
+ :param max_tokens: Maximum tokens for the LLM to generate
175
+ :param radio_select: List of radio selection args to parse
176
+ :return: String LLM response
177
+ """
178
  model, persona = parse_radio_select(radio_select)
179
 
180
  client = clients[model]
 
214
  return response
215
 
216
 
217
def get_model_options(request: gr.Request) -> List[gr.Radio]:
    """
    Get allowed models for the specified session.

    Exactly one Radio is returned per entry in `llm_host_names`, in the same
    order, because `parse_radio_select` maps a radio's position back to a
    host name by index. Hosts the user may not access are represented by an
    empty "Not Authorized" placeholder in their own position.
    :param request: Gradio request object to get user from
    :return: List of Radio objects for available models
    """
    if request:
        # `user` is a valid Google email address or 'guest'
        user = get_user(request.request)
    else:
        user = "guest"
    print(f"Getting models for {user}")

    # Take the text after the last '@' so a malformed address cannot raise
    domain = "guest" if user == "guest" else user.rsplit('@', 1)[-1]
    allowed_llm_host_names = set(get_allowed_models(domain))

    # NOTE(review): this only controls what is displayed; requests are not
    # re-checked against the allowed list anywhere visible here.
    radios = []
    default_selected = False
    for name in llm_host_names:
        if name not in allowed_llm_host_names:
            # Keep the row so radio positions stay aligned with host names
            radios.append(gr.Radio(choices=[], value=None,
                                   label="Not Authorized"))
            continue
        personas = list(clients[name].personas.keys())
        value = None
        if not default_selected and personas:
            # Select the first available option by default
            value = personas[0]
            default_selected = True
            print(f"Set default persona to {value} for {name}")
        radios.append(gr.Radio(
            choices=personas, value=value,
            label=f"{name} ({clients[name].vllm_model_name})"))
    return radios
247
+
248
+
249
def _resolve_radio_selection(state, new_state):
    """
    Collapse the radio values so that at most one radio holds a selection
    :param state: Radio values stored before this event
    :param new_state: Radio values reported by this event
    :return: List as long as `state` that keeps only the changed value or,
        when nothing changed, the first value that is currently selected
    """
    changed_index = next(
        (index for index, (old, new) in enumerate(zip(state, new_state))
         if old != new), None)
    if changed_index is None:
        # Nothing differs from the stored state (seen when a selected option
        # fails to render); fall back to the current selection, if any.
        changed_index = next(
            (index for index, value in enumerate(new_state)
             if value is not None), None)
    if changed_index is None:
        # No radio is selected at all, so there is nothing to keep
        return [None] * len(state)
    changed_value = new_state[changed_index]
    return [changed_value if index == changed_index else None
            for index in range(len(state))]


def init_gradio() -> gr.Blocks:
    """
    Initialize a Gradio demo
    :return: Blocks containing the chat interface, the persona/model options
        and the login/logout button
    """
    conversational_checkbox = gr.Checkbox(value=True, label="conversational")
    max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64,
                                  label="Max new tokens")
    # Built for a guest first; re-evaluated per session on page load below
    radios = get_model_options(None)

    with gr.Blocks() as blocks:
        # Events
        radio_state = gr.State([radio.value for radio in radios])

        @gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
               inputs=[radio_state, *radios], outputs=[radio_state, *radios])
        def radio_click(state, *new_state):
            # Selecting a persona in one radio clears every other radio
            clean_state = _resolve_radio_selection(state, new_state)
            return clean_state, *clean_state

        # Compile
        # TODO: Define a configuration structure for this information
        # The misspelled `accordian_info` key is still honoured for
        # backwards-compatibility with existing configuration.
        accordion_info = config.get("accordion_info") or \
            config.get("accordian_info") or \
            "Persona and LLM Options - Choose one:"
        version = config.get("version") or \
            f"v{datetime.now().strftime('%Y-%m-%d')}"
        title = config.get("title") or \
            f"Neon AI BrainForge Personas and Large Language Models ({version})"

        with gr.Accordion(label=accordion_info, open=True,
                          render=False) as accordion:
            for radio in radios:
                radio.render()
            conversational_checkbox.render()
            max_tokens_slider.render()

        _ = gr.ChatInterface(
            respond,
            additional_inputs=[
                conversational_checkbox,
                max_tokens_slider,
                *radios,
            ],
            additional_inputs_accordion=accordion,
            title=title,
            concurrency_limit=5,
        )

        # Render login/logout button
        login_button = gr.Button("Log In")
        blocks.load(get_login_button, None, login_button)

        accordion.render()
        blocks.load(get_model_options, None, radios)

    return blocks
315
 
316
 
if __name__ == "__main__":
    # Configuration comes first: `init_gradio` builds its components from the
    # clients that `init_config` creates.
    init_config()
    init_oauth()
    blocks = init_gradio()
    # Serve the Gradio UI from the FastAPI app at the root path, resolving
    # the current user through `get_user`.
    app = gr.mount_gradio_app(app, blocks, path='/',
                              auth_dependency=get_user)
    uvicorn.run(app, host='0.0.0.0', port=7860)
requirements.txt CHANGED
@@ -1,2 +1,6 @@
1
  huggingface_hub==0.22.2
2
- openai~=1.0
 
 
 
 
 
1
  huggingface_hub==0.22.2
2
+ openai~=1.0
3
+ fastapi
4
+ authlib
5
+ uvicorn
6
+ starlette
shared.py CHANGED
@@ -1,6 +1,6 @@
1
  import yaml
2
 
3
- from typing import Dict
4
  from pydantic import BaseModel, ValidationError
5
  from huggingface_hub import hf_hub_download
6
  from huggingface_hub.utils import EntryNotFoundError
@@ -8,20 +8,23 @@ from huggingface_hub.utils import EntryNotFoundError
8
  from openai import OpenAI
9
 
10
 
11
-
12
  class PileConfig(BaseModel):
13
  file2persona: Dict[str, str]
14
  file2prefix: Dict[str, str]
15
  persona2system: Dict[str, str]
16
  prompt: str
17
 
 
18
  class InferenceConfig(BaseModel):
19
  chat_template: str
 
 
20
 
21
  class RepoConfig(BaseModel):
22
  name: str
23
  tag: str
24
 
 
25
  class ModelConfig(BaseModel):
26
  pile: PileConfig
27
  inference: InferenceConfig
 
1
  import yaml
2
 
3
+ from typing import Dict, Optional, List
4
  from pydantic import BaseModel, ValidationError
5
  from huggingface_hub import hf_hub_download
6
  from huggingface_hub.utils import EntryNotFoundError
 
8
  from openai import OpenAI
9
 
10
 
 
11
  # Schema for the `pile` section of a model's configuration (see
  # `ModelConfig.pile`). Three string-to-string mappings plus a prompt string.
  # NOTE(review): the field names suggest file -> persona, file -> prefix and
  # persona -> system prompt mappings; confirm against the consuming code.
  class PileConfig(BaseModel):
12
  file2persona: Dict[str, str]
13
  file2prefix: Dict[str, str]
14
  persona2system: Dict[str, str]
15
  prompt: str
16
 
17
+
18
  # Schema for the `inference` section of a model's configuration (see
  # `ModelConfig.inference`).
  class InferenceConfig(BaseModel):
19
  chat_template: str
20
  # Optional list of user domains permitted to use the model. The default
  # `None` means no restriction; `get_allowed_models` in `app.py` treats such
  # a model as public.
+ allowed_domains: Optional[List[str]] = None
21
+
22
 
23
  # Schema identifying a repository by `name` and `tag`.
  # NOTE(review): presumably a Hugging Face Hub repository and revision -
  # confirm against the code that consumes it.
  class RepoConfig(BaseModel):
23
  name: str
24
  tag: str
26
 
27
+
28
  class ModelConfig(BaseModel):
29
  pile: PileConfig
30
  inference: InferenceConfig