# playground/main.py
# Author: Francesco. Commit 9c97588: "updated main with auth" (8.76 kB).
import hmac
import json
import logging
import os
from pathlib import Path
from queue import Empty, Queue
from threading import Thread
from typing import Dict, Iterator, List, Optional, Tuple

from dotenv import load_dotenv

# populate os.environ from a .env file before importing modules that may read it
load_dotenv()

import gradio as gr
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain.chat_models import ChatOpenAI
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AIMessage, BaseMessage, HumanMessage

from callback import QueueCallback
from db import (
    User,
    Chat,
    create_user,
    get_client,
    get_user_by_username,
    add_chat_by_uid,
)
from js import get_window_url_params
MODELS_NAMES = ["gpt-3.5-turbo", "gpt-4"]
DEFAULT_TEMPERATURE = 0.7
# chatbot history as shown by gr.Chatbot: one (user message, bot reply) pair per row
ChatHistory = List[Tuple[str, str]]
logging.basicConfig(
    format="[%(asctime)s %(levelname)s]: %(message)s", level=logging.INFO
)
# load redis client
client = get_client()
# load up our system prompt; it is formatted per patient with `patient=...`
system_message_prompt = SystemMessagePromptTemplate.from_template(
    Path("prompts/system.prompt").read_text(encoding="utf-8")
)
# for the human, we will just inject the text
human_message_prompt_template = HumanMessagePromptTemplate.from_template("{text}")
# patient records (each has at least a "name" key); the misspelled name
# `patiens` is kept because the rest of the module refers to it
with open("data/patients.json", encoding="utf-8") as f:
    patiens = json.load(f)
patients_names = [el["name"] for el in patiens]
def message_handler(
    chat: Optional[ChatOpenAI],
    message: str,
    chatbot_messages: ChatHistory,
    messages: List[BaseMessage],
) -> Iterator[Tuple[ChatOpenAI, str, ChatHistory, List[BaseMessage]]]:
    """Send ``message`` to the model and stream the reply into the chatbot.

    This is a generator used as a Gradio event handler. Every streamed token
    yields an updated ``(chat, textbox_value, chatbot_messages, messages)``
    tuple so the UI refreshes incrementally. One final tuple is yielded after
    the complete reply has been appended to ``messages``.

    Args:
        chat: the session's ``ChatOpenAI``. It is ``None`` on the first message
            of a session, in which case a default streaming chat is created.
        message: the text typed by the user (sent to the model as "Doctor:...").
        chatbot_messages: ``(user_text, reply)`` pairs displayed by the chatbot;
            mutated in place.
        messages: LangChain history sent to the model; mutated in place.

    Raises:
        Exception: whatever the model call raised in the worker thread. The
            turn added for this message is removed before re-raising.
    """
    if chat is None:
        # the QueueCallback pushes every streamed token into this queue
        queue = Queue()
        chat = ChatOpenAI(
            model_name=MODELS_NAMES[0],
            temperature=DEFAULT_TEMPERATURE,
            streaming=True,
            callbacks=[QueueCallback(queue)],
        )
    else:
        # hacky way to get the queue back: it lives on the chat's first callback
        queue = chat.callbacks[0].queue

    # unique sentinel the worker puts on the queue once generation is over
    job_done = object()
    errors: List[BaseException] = []

    logging.info("asking question to GPT")
    messages.append(HumanMessage(content=f"Doctor:{message}"))
    chatbot_messages.append((message, ""))

    def task() -> None:
        # The sentinel must be enqueued even if the API call fails; otherwise
        # the consumer loop below would wait forever.
        try:
            chat(messages)
        except Exception as err:  # re-raised in the caller's thread below
            errors.append(err)
        finally:
            queue.put(job_done)

    # run the blocking generation in a thread and consume its tokens here
    worker = Thread(target=task)
    worker.start()

    content = ""
    while True:
        try:
            next_token = queue.get(True, timeout=1)
        except Empty:
            continue
        if next_token is job_done:
            break
        content += next_token
        chatbot_messages[-1] = (message, content)
        yield chat, "", chatbot_messages, messages
    worker.join()

    if errors:
        # undo this turn so a retry does not send the question twice
        messages.pop()
        chatbot_messages.pop()
        raise errors[0]

    messages.append(AIMessage(content=content))
    logging.debug("reply = %s", content)
    logging.info("Done!")
    # Yield, not return: Gradio only consumes yielded values, so the final
    # state (including the AIMessage just appended) has to be yielded.
    yield chat, "", chatbot_messages, messages
def on_clear_click() -> Tuple[str, List, List]:
    """Produce cleared values for the (textbox, chatbot, messages) outputs.

    Fresh list objects are created on every call so no state is shared.
    """
    cleared_text = ""
    cleared_chatbot = list()
    cleared_messages = list()
    return cleared_text, cleared_chatbot, cleared_messages
def on_done_click(
    chatbot_messages: ChatHistory, patient: Dict, user: User
) -> Tuple[str, List, List]:
    """Save the finished consultation for ``user`` and start a fresh one.

    Args:
        chatbot_messages: ``(user_text, reply)`` pairs shown in the chatbot.
        patient: the currently selected patient record (the ``patient`` state,
            an entry of ``data/patients.json``).
        user: the user record loaded by ``on_demo_load``; ``None`` until that
            has completed.

    Returns:
        ``(textbox, chatbot, messages)``: an empty textbox, an empty chatbot,
        and a new LLM history holding only the system prompt for ``patient``.

    Raises:
        gr.Error: when no user is loaded yet. The conversation is left on
            screen instead of failing with an opaque ``TypeError``.
    """
    if user is None:
        raise gr.Error(
            "No user loaded yet, so this chat cannot be saved. Please reload the page."
        )
    logging.info(f"Saving chat for user={user}")
    add_chat_by_uid(
        client, Chat(patient=patient, messages=chatbot_messages), user["uid"]
    )
    # Re-seed the history with the system prompt (as on_drop_down_change does)
    # instead of [], so the next consultation keeps the patient context.
    return "", [], [system_message_prompt.format(patient=patient)]
def on_apply_settings_click(model_name: str, temperature: float):
    """Create a new streaming chat with the chosen settings and clear the UI.

    Args:
        model_name: one of ``MODELS_NAMES``, chosen in the settings dropdown.
        temperature: sampling temperature from the settings slider.

    Returns:
        ``(chat, textbox, chatbot, messages)``. ``chat`` is the new
        ``ChatOpenAI``; the other three are the cleared values from
        ``on_clear_click``.
    """
    logging.info(
        f"Applying settings: model_name={model_name}, temperature={temperature}"
    )
    # The callback gets a brand-new Queue, so there are no stale tokens to drain.
    chat = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        streaming=True,
        callbacks=[QueueCallback(Queue())],
    )
    # NOTE(review): ``messages`` comes back as [], i.e. without the system
    # prompt; callers that need the patient prompt must re-seed it themselves.
    return chat, *on_clear_click()
def on_drop_down_change(selected_item, messages):
    """Switch to the patient named ``selected_item`` and restart the conversation.

    Args:
        selected_item: the patient name chosen in the dropdown (an element of
            ``patients_names``).
        messages: the current LLM history. It is ignored, because the history
            is rebuilt from scratch for the new patient.

    Returns:
        ``(patient_card, patient, chatbot, messages)``: the patient record
        (shown in the JSON card and stored in state), an empty chatbot, and a
        new history holding only the system prompt for that patient.

    Raises:
        ValueError: if ``selected_item`` is not a known patient name.
    """
    index = patients_names.index(selected_item)
    patient = patiens[index]
    logging.info(f"You selected: {selected_item} (index={index})")
    return patient, patient, [], [system_message_prompt.format(patient=patient)]
def on_demo_load(url_params, request: gr.Request):
    """Resolve the current user when the page loads and build the greeting.

    The username comes from the authenticated Gradio session. If that is
    empty, it falls back to the ``username`` URL parameter and finally to
    ``"test"``.

    Args:
        url_params: URL query parameters collected in the browser by
            ``get_window_url_params`` (presumably a dict; ``None`` is
            treated as empty).
        request: the incoming Gradio request.

    Returns:
        ``(user, welcome_markdown)``: the user record from the database and a
        greeting string.
    """
    params = url_params or {}
    username = request.username or params.get("username", "test")
    logging.info(f"Getting user for username={username}")
    # NOTE(review): runs on every page load; assumes create_user is safe to
    # call for an already existing user -- confirm in db.py
    create_user(client, User(username=username, uid=None))
    user = get_user_by_username(client, username)
    logging.info(f"User {user}")
    logging.info(f"got url_params: {params}")
    return user, f"Nice to see you {user['username']} 👋"
url_params = gr.JSON({}, visible=False, label="URL Params")


def _apply_settings_keep_prompt(model_name: str, temperature: float, patient: Dict):
    """Apply new model settings and re-seed the LLM history for ``patient``.

    ``on_apply_settings_click`` returns an empty ``messages`` list, which would
    make the following turns run without the system prompt. This wrapper keeps
    its chat, textbox and chatbot outputs but replaces ``messages`` with the
    system prompt formatted for the currently selected patient.
    """
    new_chat, textbox_value, chatbot_value, _ = on_apply_settings_click(
        model_name, temperature
    )
    return (
        new_chat,
        textbox_value,
        chatbot_value,
        [system_message_prompt.format(patient=patient)],
    )


# some css why not, "borrowed" from https://huggingface.co/spaces/ysharma/Gradio-demo-streaming/blob/main/app.py
with gr.Blocks(
    css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}"""
) as demo:
    # per-session state, so multiple users can use the app at the same time
    messages = gr.State([system_message_prompt.format(patient=patiens[0])])
    # one chat per session, so each session streams through its own callback queue
    chat = gr.State(None)
    user = gr.State(None)
    patient = gr.State(patiens[0])
    # see here https://github.com/gradio-app/gradio/discussions/2949#discussioncomment-5278991
    url_params.render()
    with gr.Column(elem_id="col_container"):
        gr.Markdown("# Welcome to OscePal! 👨‍⚕️🧑‍⚕️")
        welcome_markdown = gr.Markdown("")
        # get_window_url_params runs in the browser and feeds the URL query
        # parameters into `url_params` before on_demo_load is called
        demo.load(
            fn=on_demo_load,
            inputs=[url_params],
            outputs=[user, welcome_markdown],
            _js=get_window_url_params,
        )
        # elem_id matches the `#chatbot` rule in the css above
        chatbot = gr.Chatbot(elem_id="chatbot")
        with gr.Column():
            message = gr.Textbox(label="chat input")
            submit = gr.Button("Send Message", variant="primary")
            # pressing Enter and clicking "Send Message" behave identically
            chat_io = [chat, message, chatbot, messages]
            message.submit(message_handler, chat_io, chat_io, queue=True)
            submit.click(message_handler, chat_io, chat_io, queue=True)
        with gr.Row():
            with gr.Column():
                done = gr.Button("Done", variant="stop")
                done.click(
                    on_done_click,
                    [chatbot, patient, user],
                    [message, chatbot, messages],
                )
                with gr.Accordion("Settings", open=False):
                    model_name = gr.Dropdown(
                        choices=MODELS_NAMES, value=MODELS_NAMES[0], label="model"
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_TEMPERATURE,
                        step=0.1,
                        label="temperature",
                        interactive=True,
                    )
                    apply_settings = gr.Button("Apply")
                    apply_settings.click(
                        _apply_settings_keep_prompt,
                        [model_name, temperature, patient],
                        [chat, message, chatbot, messages],
                    )
            with gr.Column():
                dropdown = gr.Dropdown(
                    choices=patients_names,
                    value=patients_names[0],
                    interactive=True,
                    label="Patient",
                )
                patient_card = gr.JSON(patient.value, visible=True, label="Patient card")
                dropdown.change(
                    fn=on_drop_down_change,
                    inputs=[dropdown, messages],
                    outputs=[patient_card, patient, chatbot, messages],
                )
# app = FastAPI()
# os.makedirs("static", exist_ok=True)
# app.mount("/static", StaticFiles(directory="static"), name="static")
# templates = Jinja2Templates(directory="templates")
# @app.get("/", response_class=HTMLResponse)
# async def home(request: Request):
# return templates.TemplateResponse(
# "home.html", {"request": request, "videos": []})
def auth_handler(username: str, password: str) -> bool:
    """Gradio login check: accept any username whose password equals GRADIO_PASSWORD.

    Fails closed. If ``GRADIO_PASSWORD`` is unset or empty, every login is
    rejected and an error is logged, instead of raising ``KeyError`` (unset)
    or letting an empty password in (empty). The comparison is constant-time,
    so response timing does not reveal how much of the password matched.

    Args:
        username: ignored; all users share the same password.
        password: the password submitted on the login form.

    Returns:
        ``True`` if the login is accepted, otherwise ``False``.
    """
    expected = os.environ.get("GRADIO_PASSWORD")
    if not expected:
        logging.error("GRADIO_PASSWORD is not set; rejecting login attempt")
        return False
    # compare bytes: hmac.compare_digest only accepts ASCII-only str arguments
    return hmac.compare_digest(password.encode("utf-8"), expected.encode("utf-8"))
if __name__ == "__main__":
    # Gradio needs the queue enabled for generator (streaming) handlers such
    # as message_handler; every visitor must pass auth_handler to log in.
    demo.queue()
    demo.launch(auth=auth_handler)
# gradio_app = gr.routes.App.create_app(demo)
# app.mount("/gradio", gradio_app)