Spaces:
Sleeping
Sleeping
import os | |
from authlib.integrations.starlette_client import OAuth, OAuthError | |
from fastapi import FastAPI, Depends, Request | |
from starlette.config import Config | |
from starlette.responses import RedirectResponse | |
from starlette.middleware.sessions import SessionMiddleware | |
import uvicorn | |
import gradio as gr | |
from huggingface_hub import InferenceClient | |
# Identifier of the hosted model that answers the chat requests.
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"

# Inference client created once at import time; `respond` streams through it.
client = InferenceClient(MODEL_ID)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply for ``message``, yielding the text accumulated so far.

    Args:
        message: The new user message to answer.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs; an
            empty/falsy side of a pair is left out of the prompt.
        system_message: Content of the leading system message.
        max_tokens: Forwarded to ``client.chat_completion`` as ``max_tokens``.
        temperature: Forwarded to ``client.chat_completion`` as ``temperature``.
        top_p: Forwarded to ``client.chat_completion`` as ``top_p``.

    Yields:
        The reply text built up so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    # Use a dedicated loop name so the `message` parameter is not rebound
    # while the stream is being consumed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A chunk's delta may carry no text (content is None); treat that as
        # an empty string instead of failing on `str + None`.
        token = chunk.choices[0].delta.content or ""
        response += token
        yield response
app = FastAPI()

# Replace these with your own OAuth settings
GOOGLE_CLIENT_ID = "..."
GOOGLE_CLIENT_SECRET = "..."

# The Google credentials are handed to Authlib through a Starlette Config
# backed by this in-memory mapping rather than by the process environment.
config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Key passed to SessionMiddleware. It is defined exactly once: the SECRET_KEY
# environment variable wins, otherwise the development fallback below is used.
# NOTE(review): the fallback is a hard-coded, publicly visible string; set
# SECRET_KEY in the environment for any real deployment.
SECRET_KEY = os.environ.get('SECRET_KEY') or "a_very_secret_key"
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
# Dependency to get the current user
def get_user(request: Request):
    """Return the session user's ``name``, or ``None`` when nobody is logged in.

    The user record is read from ``request.session['user']``; a missing or
    empty record counts as "not logged in".
    """
    session_user = request.session.get('user')
    return session_user['name'] if session_user else None
# Registered as the site root: `logout` and `auth` both redirect to '/', so
# this handler must be reachable there.
@app.get('/')
def public(user: dict = Depends(get_user)):
    """Send the visitor to the chat UI when logged in, else to the login page.

    ``user`` is whatever ``get_user`` returns for the current request (in this
    file: the user's name, or ``None``), so only its truthiness is used.
    """
    destination = '/gradio' if user else '/login-demo'
    return RedirectResponse(url=destination)
# Registered at /logout, the target of the "Logout" button link.
@app.get('/logout')
async def logout(request: Request):
    """Forget the logged-in user and return to the site root.

    Removing the ``'user'`` key is a no-op when nobody is logged in.
    """
    request.session.pop('user', None)
    return RedirectResponse(url='/')
# Registered at /login, the target of the "Login" button link.
@app.get('/login')
async def login(request: Request):
    """Start the Google OAuth flow, asking Google to call back the ``auth`` route."""
    redirect_uri = request.url_for('auth')
    # If your app is running on https, you should ensure that the
    # `redirect_uri` is https, e.g. uncomment the following lines:
    #
    # from urllib.parse import urlparse, urlunparse
    # redirect_uri = urlunparse(urlparse(str(redirect_uri))._replace(scheme='https'))
    return await oauth.google.authorize_redirect(request, redirect_uri)
# Registered at /auth; the route name defaults to the function name, which is
# what `request.url_for('auth')` looks up when building the callback URL.
@app.get('/auth')
async def auth(request: Request):
    """Finish the OAuth flow and remember the user in the session.

    On an ``OAuthError`` the visitor is sent back to the site root without
    being logged in. Otherwise the ``"userinfo"`` entry of the token response
    is stored under ``request.session['user']`` before redirecting to '/'.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError:
        return RedirectResponse(url='/')
    request.session['user'] = dict(token)["userinfo"]
    return RedirectResponse(url='/')
# Page for signed-out visitors: a single button linking to /login.
login_demo = gr.Blocks()
with login_demo:
    gr.Button("Login", link="/login")

# Serve the login page under /login-demo on the FastAPI app.
app = gr.mount_gradio_app(app, login_demo, path="/login-demo")
def greet(request: gr.Request):
    """Build the welcome line shown to the user, using ``request.username``."""
    username = request.username
    return f"Welcome to Gradio, {username}"
# Main page: a greeting, a logout button and the chat interface.
with gr.Blocks() as main_demo:
    m = gr.Markdown("Welcome to Gradio!")
    gr.Button("Logout", link="/logout")
    demo = gr.ChatInterface(
        respond,
        # Extra controls whose values are passed to `respond` after the
        # message and history, in this order.
        additional_inputs=[
            gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
    # On page load, replace the Markdown text with the personalised greeting.
    # Only fn / inputs / outputs are passed: a fourth positional argument
    # would be taken as `api_name`, which is not a component.
    main_demo.load(greet, None, m)

# Serve the main page under /gradio; `get_user` is the auth dependency for it.
app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user)
if __name__ == "__main__":
    # Run the combined FastAPI + Gradio app; no host/port overrides are passed.
    uvicorn.run(app)