Spaces:
Runtime error
Runtime error
import json
import os
import secrets
from datetime import datetime
from typing import List, Tuple

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse

from shared import Client
# FastAPI application that serves the OAuth routes and the mounted Gradio UI.
app = FastAPI()
# Parsed JSON from the `CONFIG` envvar; filled in by `init_config()`.
config: dict = {}
# LLM host name -> `Client`; filled in by `init_config()`.
clients: dict = {}
# LLM host names in configuration order; filled in by `init_config()`.
llm_host_names: list = []
# Authlib OAuth registry; stays None until `init_oauth()` runs.
oauth = None
def init_oauth():
    """
    Register the Google OAuth client and install session middleware on `app`.

    Reads `GOOGLE_CLIENT_ID`, `GOOGLE_CLIENT_SECRET` and `SECRET_KEY` from the
    environment. `SECRET_KEY` is the key `SessionMiddleware` uses to sign the
    session cookie in which the logged-in user is stored. If it is unset, a
    random per-process key is generated, so session cookies cannot be forged
    with a publicly known key. Sessions then do not survive a restart.
    """
    global oauth
    google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
    google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
    # Never fall back to a hard-coded key: anyone who knows it could sign a
    # session claiming an arbitrary email and unlock domain-restricted models.
    secret_key = os.environ.get('SECRET_KEY') or secrets.token_urlsafe(32)
    starlette_config = Config(environ={"GOOGLE_CLIENT_ID": google_client_id,
                                       "GOOGLE_CLIENT_SECRET": google_client_secret})
    oauth = OAuth(starlette_config)
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'}
    )
    app.add_middleware(SessionMiddleware, secret_key=secret_key)
def init_config():
    """
    Initialize configuration from the JSON in the `CONFIG` envvar.

    Every top-level entry whose value is an object describes one LLM host:
        {"<llm_host_name>": {"api_key": "<api_key>",
                             "api_url": "<api_url>",
                             "personas": {...}  # optional, passed to Client
                             }
        }
    A configured `api_url` or `api_key` may be an envvar reference OR a
    literal value. Top-level entries that are not objects are UI settings
    (e.g. the `title`, `version` and `accordian_info` values read by
    `init_gradio`) and are not treated as LLM hosts.
    """
    global config
    global clients
    global llm_host_names
    config = json.loads(os.environ['CONFIG'])

    def _resolve(value):
        # Use the envvar named `value` if it exists, else `value` itself.
        return os.environ.get(value, value)

    host_names = []
    for name, host_config in config.items():
        if not isinstance(host_config, dict):
            # UI setting (e.g. a title string), not an LLM host definition
            continue
        clients[name] = Client(
            api_url=_resolve(host_config['api_url']),
            api_key=_resolve(host_config['api_key']),
            personas=host_config.get("personas", {}),
        )
        host_names.append(name)
    # Radio position i maps to llm_host_names[i], so this list must contain
    # exactly the hosts that have clients, in configuration order.
    llm_host_names = host_names
def get_allowed_models(user_domain: str) -> List[str]:
    """
    List the LLM hosts that a user from `user_domain` may use.

    A host is public when its client's `config.inference.allowed_domains` is
    None; otherwise `user_domain` must be contained in that collection.
    :param user_domain: User domain (i.e. neon.ai, google.com, guest)
    :return: Allowed host names, in `clients` order
    """
    def _permits(host_client) -> bool:
        domains = host_client.config.inference.allowed_domains
        return domains is None or user_domain in domains

    return [host for host, host_client in clients.items()
            if _permits(host_client)]
def parse_radio_select(radio_select: tuple) -> Tuple[str, str]:
    """
    Parse radio selection to determine the requested model and persona.

    Radio position `i` corresponds to `llm_host_names[i]`; the first radio
    holding a non-None value is taken as the selection.
    :param radio_select: Radio selection states, one per LLM host
    :return: Selected model (LLM host name), persona
    :raises gr.Error: if no radio has a selected value
    """
    for value_index, persona in enumerate(radio_select):
        if persona is not None:
            return llm_host_names[value_index], persona
    # Previously `next()` leaked a bare StopIteration here; show an
    # actionable message in the chat UI instead.
    raise gr.Error("Please select a persona before sending a message")
def get_login_button(request: gr.Request) -> gr.Button:
    """
    Build the header button that matches the session's login state.
    :param request: Gradio request for the current session
    :return: A "Login" button linking to `/login` for guests, otherwise a
        "Logout <user>" button linking to `/logout`
    """
    user = get_user(request)
    print(f"Getting login button for {user}")
    is_guest = user == "guest"
    label, link = ("Login", "/login") if is_guest else (f"Logout {user}", "/logout")
    return gr.Button(label, link=link)
def get_user(request: Request) -> str:
    """
    Resolve the email of the user logged in on this request's session.
    :param request: Object exposing `.session` (a FastAPI request, or the
        Gradio request passed by `get_login_button`)
    :return: The session user's email, or "guest" when there is no request
        or no logged-in user
    """
    if request:
        session_user = request.session.get('user', {})
        email = session_user.get('email')
        if email:
            return email
    return "guest"
@app.get('/logout')
async def logout(request: Request):
    """
    Remove the user session context and reload an un-authenticated session.

    Registered on `app` at `/logout`, which is the link target of the
    "Logout" button built by `get_login_button`.
    :param request: FastAPI Request object with user session data
    :return: Redirect to `/`
    """
    request.session.pop('user', None)
    return RedirectResponse(url='/')
@app.get('/login')
async def login(request: Request):
    """
    Start the OAuth flow for login with Google.

    Registered on `app` at `/login`, which is the link target of the "Login"
    button built by `get_login_button`. The OAuth callback URL is the URL of
    the route named `auth` (registered below), rewritten to use https.
    :param request: FastAPI Request object
    :return: Redirect response produced by `oauth.google.authorize_redirect`
    """
    from urllib.parse import urlparse, urlunparse
    redirect_uri = request.url_for('auth')
    # Force https. Presumably the app runs behind a TLS-terminating proxy,
    # so the request seen here is plain http -- TODO confirm deployment.
    redirect_uri = urlunparse(urlparse(str(redirect_uri))._replace(scheme='https'))
    return await oauth.google.authorize_redirect(request, redirect_uri)


@app.get('/auth')
async def auth(request: Request):
    """
    Callback endpoint for Google OAuth.

    On success, stores the token's `userinfo` in the session under `user`,
    which is where `get_user` reads the email from. If the OAuth exchange
    fails, or the token carries no `userinfo`, the session is left unchanged.
    In every case the browser is redirected to `/`.
    :param request: FastAPI Request object
    :return: Redirect to `/`
    """
    try:
        access_token = await oauth.google.authorize_access_token(request)
    except OAuthError:
        return RedirectResponse(url='/')
    userinfo = dict(access_token).get("userinfo")
    if userinfo:
        request.session['user'] = dict(userinfo)
    return RedirectResponse(url='/')
def respond(
    message: str,
    history: List[Tuple[str, str]],
    conversational: bool,
    max_tokens: int,
    *radio_select,
):
    """
    Send user input to the selected vLLM backend and return its reply.

    The backend and persona come from the radio selection. The prompt starts
    with the persona's system prompt (unless it is None). When
    `conversational` is set, the last two history exchanges follow, and the
    new user message comes last.
    :param message: String input from the user
    :param history: Optional list of chat history (<user message>,<llm message>)
    :param conversational: If true, include chat history
    :param max_tokens: Maximum tokens for the LLM to generate
    :param radio_select: List of radio selection args to parse
    :return: String LLM response
    :raises gr.Error: if the selected model does not define the persona
    """
    model, persona = parse_radio_select(radio_select)
    client = clients[model]
    try:
        system_prompt = client.personas[persona]
    except KeyError:
        supported_personas = list(client.personas.keys())
        raise gr.Error(f"Model '{model}' does not support persona '{persona}', only {supported_personas}")

    prompt = [] if system_prompt is None else [{"role": "system", "content": system_prompt}]
    if conversational:
        # Replay only the two most recent exchanges; empty turns are skipped.
        for exchange in history[-2:]:
            for role, text in (("user", exchange[0]), ("assistant", exchange[1])):
                if text:
                    prompt.append({"role": role, "content": text})
    prompt.append({"role": "user", "content": message})

    # Backend-specific generation options, forwarded via the OpenAI client's
    # `extra_body`.
    sampling_overrides = {
        "add_special_tokens": True,
        "repetition_penalty": 1.05,
        "use_beam_search": True,
        "best_of": 5,
    }
    completion = client.openai.chat.completions.create(
        model=client.vllm_model_name,
        messages=prompt,
        max_tokens=max_tokens,
        temperature=0,
        extra_body=sampling_overrides,
    )
    return completion.choices[0].message.content
def get_model_options(request: gr.Request) -> List[gr.Radio]:
    """
    Get the persona radios for the specified session.

    Returns one radio per entry of `llm_host_names`, in the same order,
    because `parse_radio_select` maps radio position `i` to
    `llm_host_names[i]`. Packing the allowed hosts to the front would send a
    selection to the wrong, possibly restricted, host. Hosts the user may not
    access get an empty "Not Authorized" radio at their own position. The
    first persona of the first accessible host is selected by default.
    :param request: Gradio request object to get user from (falsy -> guest)
    :return: List of Radio objects, one per configured LLM host
    """
    if request:
        # `user` is a valid Google email address or 'guest'
        user = get_user(request.request)
    else:
        user = "guest"
    print(f"Getting models for {user}")
    domain = "guest" if user == "guest" else user.split('@')[1]
    allowed = set(get_allowed_models(domain))

    radios = []
    default_selected = False
    for name in llm_host_names:
        if name not in allowed:
            radios.append(gr.Radio(choices=[], value=None, label="Not Authorized"))
            continue
        personas = list(clients[name].personas.keys())
        value = None
        # Pre-select exactly one persona overall; skip hosts with none defined
        if not default_selected and personas:
            value = personas[0]
            default_selected = True
            print(f"Set default persona to {value} for {name}")
        radios.append(gr.Radio(choices=personas, value=value,
                               label=f"{name} ({clients[name].vllm_model_name})"))
    return radios
def init_gradio() -> gr.Blocks:
    """
    Initialize a Gradio demo.

    Builds a chat UI around `respond`. Its additional inputs are a
    "conversational" checkbox, a max-new-tokens slider and the persona radios
    from `get_model_options(None)` (the guest view). On page load, the login
    button and the radios are replaced with session-specific versions.
    Title, version and accordion label come from `config`, with defaults.
    :return: The assembled Blocks (not launched or mounted here)
    """
    conversational_checkbox = gr.Checkbox(value=True, label="conversational")
    max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64,
                                  label="Max new tokens")
    # Created before the Blocks context; rendered explicitly in the accordion
    radios = get_model_options(None)
    with gr.Blocks() as blocks:
        # Events
        # NOTE(review): `radio_state` and `radio_click` are never attached to
        # an event listener, so choosing a persona in one radio does not
        # clear the others. `parse_radio_select` then uses the first non-None
        # radio. TODO confirm whether this wiring was intended.
        radio_state = gr.State([radio.value for radio in radios])
        def radio_click(state, *new_state):
            # Keep only the radio whose value differs from `state` and clear
            # all others to None. Returns the new state followed by one value
            # per radio.
            try:
                changed_index = next(i for i in range(len(state))
                                     if state[i] != new_state[i])
                changed_value = new_state[changed_index]
            except StopIteration:
                # TODO: This is the result of some error in rendering a selected
                # option.
                # Changed to current selection
                changed_value = [i for i in new_state if i is not None][0]
                changed_index = new_state.index(changed_value)
            clean_state = [None if i != changed_index else changed_value
                           for i in range(len(state))]
            return clean_state, *clean_state
        # Compile
        # TODO: Define a configuration structure for this information
        # NOTE: the config key is spelled "accordian_info"; configs must match.
        accordion_info = config.get("accordian_info") or \
            "Persona and LLM Options - Choose one:"
        version = config.get("version") or \
            f"v{datetime.now().strftime('%Y-%m-%d')}"
        title = config.get("title") or \
            f"Neon AI BrainForge Personas and Large Language Models ({version})"
        # render=False: the accordion is placed later by `accordion.render()`.
        # NOTE(review): confirm how ChatInterface treats an accordion passed
        # via `additional_inputs_accordion` that was created with render=False.
        with gr.Accordion(label=accordion_info, open=True,
                          render=False) as accordion:
            # List comprehension used only for its render side effect
            [radio.render() for radio in radios]
            conversational_checkbox.render()
            max_tokens_slider.render()
        _ = gr.ChatInterface(
            respond,
            # Order must match respond(message, history, conversational,
            # max_tokens, *radio_select)
            additional_inputs=[
                conversational_checkbox,
                max_tokens_slider,
                *radios,
            ],
            additional_inputs_accordion=accordion,
            title=title,
            concurrency_limit=5,
        )
        # Render login/logout button
        # Placeholder; replaced on page load by get_login_button's result
        login_button = gr.Button("Log In")
        blocks.load(get_login_button, None, login_button)
        accordion.render()
        # On page load, swap the guest radios for the session user's options
        blocks.load(get_model_options, None, radios)
    return blocks
if __name__ == "__main__":
    # Order matters. `init_gradio` builds radios from the clients created by
    # `init_config`. `init_oauth` installs the session middleware that
    # `get_user` reads from.
    for initialize in (init_config, init_oauth):
        initialize()
    blocks = init_gradio()
    # Serve the Gradio UI at `/` on the FastAPI app, using `get_user` as the
    # auth dependency.
    app = gr.mount_gradio_app(app, blocks, '/', auth_dependency=get_user)
    uvicorn.run(app, host='0.0.0.0', port=7860)