buster / app.py
Louis-François Bouchard
improved layout
fe6af19
raw
history blame
No virus
7.16 kB
import logging
import os
from typing import Optional
import gradio as gr
import pandas as pd
from buster.completers import Completion
from gradio.themes.utils import (
colors,
fonts,
get_matching_version,
get_theme_assets,
sizes,
)
import cfg
from cfg import setup_buster
buster = setup_buster(cfg.buster_cfg)

# httpx emits a log line per request; raise its threshold so the app's own
# logs stay readable.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Passed to demo.queue(concurrency_count=...); override via the environment.
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))

# (label shown in the UI dropdown, internal source id passed to buster).
# Both public lists below are derived from this one table so their positions
# always line up (the rest of the module zips them together).
_SOURCE_TABLE = (
    ("Towards AI", "towards_ai"),
    ("HuggingFace", "hf_transformers"),
    ("Wikipedia", "wikipedia"),
    ("Gen AI 360: LangChain", "langchain_course"),
    ("Gen AI 360: LLMs", "llm_course"),
)
AVAILABLE_SOURCES_UI = [label for label, _ in _SOURCE_TABLE]
AVAILABLE_SOURCES = [source_id for _, source_id in _SOURCE_TABLE]
def log_likes(completion: Completion, like_data: gr.LikeData):
    """Store a user's like/dislike of an answer in MongoDB (best effort).

    Args:
        completion: The Completion held in the app's ``completion`` gr.State.
            It can be ``None`` when the state has not been populated yet.
        like_data: Payload of the ``Chatbot.like`` event; its ``liked`` flag is
            saved alongside the serialized completion.
    """
    collection = "liked_data-test"
    logger.info(f"User reported {like_data.liked=}")

    if completion is None:
        # Nothing has been answered yet, so there is nothing to attach it to.
        logger.info("No completion available; reaction not logged.")
        return

    try:
        # Heavy/vector columns are excluded from the stored document.
        # NOTE(review): assumes to_json returns a dict-like (it is indexed below).
        completion_json = completion.to_json(
            columns_to_ignore=["embedding", "similarity", "similarity_to_answer"]
        )
        completion_json["liked"] = like_data.liked
        cfg.mongo_db[collection].insert_one(completion_json)
    except Exception:
        # Feedback logging must never break the chat UI; record the traceback.
        logger.exception("Something went wrong logging")
def log_emails(email: str):
    """Store an email address submitted through the updates form (best effort).

    Args:
        email: Current value of the email textbox (a string; may be empty).

    Returns:
        Always ``""``; it is wired as the textbox's output, so the box is
        cleared after submission.
    """
    collection = "email_data-test"
    if not email or not email.strip():
        # Blank submission: nothing worth storing, just clear the box.
        return ""

    email = email.strip()
    logger.info(f"User reported {email=}")
    email_document = {"email": email}
    try:
        cfg.mongo_db[collection].insert_one(email_document)
    except Exception:
        # Best effort: a storage failure must not surface in the UI.
        logger.exception("Something went wrong logging")
    return ""
def format_sources(
    matched_documents: pd.DataFrame,
    source_display_names: Optional[dict[str, str]] = None,
) -> str:
    """Render the documents used for an answer as a markdown chat message.

    Documents are ranked by ``similarity_to_answer`` (shown as a percentage)
    and only the best-scoring row per ``title`` is kept.

    Args:
        matched_documents: Rows with ``source``, ``title``, ``url`` and
            ``similarity_to_answer`` columns. It is NOT modified.
        source_display_names: Mapping from internal source id to the label
            shown to users. Defaults to pairing ``AVAILABLE_SOURCES`` with
            ``AVAILABLE_SOURCES_UI``.

    Returns:
        The formatted message, or ``""`` when there are no documents.
    """
    if matched_documents is None or len(matched_documents) == 0:
        return ""
    if source_display_names is None:
        source_display_names = dict(zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI))

    documents_answer_template: str = "📝 Here are the sources I used to answer your question:\n\n{documents}\n\n{footnote}"
    document_template: str = "[🔗 {document.source}: {document.title}]({document.url}), relevance: {document.similarity_to_answer:2.1f} %"

    # assign() builds new frames, so the caller's DataFrame (kept in gr.State
    # and reused elsewhere) is never rescaled in place.
    ranked = (
        matched_documents.assign(
            similarity_to_answer=matched_documents["similarity_to_answer"] * 100
        )
        .sort_values("similarity_to_answer", ascending=False)
        # After sorting, "first" is the highest-scoring chunk of each document.
        .drop_duplicates("title", keep="first")
    )
    ranked = ranked.assign(source=ranked["source"].replace(source_display_names))

    documents = "\n".join(
        document_template.format(document=document)
        for _, document in ranked.iterrows()
    )
    footnote: str = "I'm a bot 🤖 and not always perfect."
    return documents_answer_template.format(documents=documents, footnote=footnote)
def add_sources(history, completion):
    """Append a bot message listing the sources behind the latest answer.

    Args:
        history: Chat history as a list of ``[user, bot]`` pairs; appended to
            in place.
        completion: The Completion produced by ``get_answer`` (may be ``None``
            if the answer step never yielded).

    Returns:
        The same ``history`` list, with a ``[None, sources]`` entry added only
        when the answer was deemed relevant and there is something to list.
    """
    if completion is None or not completion.answer_relevant:
        return history
    formatted_sources = format_sources(completion.matched_documents)
    if formatted_sources:
        # Skip empty strings so no blank chat bubble is rendered.
        history.append([None, formatted_sources])
    return history
def user(user_input, history):
    """Add the user's question to the chat right away.

    Returns a pair: ``""`` (clears the question textbox) and a NEW history
    list ending with ``[user_input, None]``. The ``None`` bot slot is filled
    in afterwards by ``get_answer``.
    """
    pending_turn = [user_input, None]
    updated_history = history + [pending_turn]
    return "", updated_history
def get_empty_source_completion(user_input):
    """Build a canned Completion asking the user to select a source.

    No retrieval is performed, so ``matched_documents`` is an empty DataFrame
    and ``error`` is False.
    """
    no_source_message = "You have to select at least one source from the dropdown menu."
    return Completion(
        error=False,
        matched_documents=pd.DataFrame(),
        answer_text=no_source_message,
        user_input=user_input,
    )
def get_answer(history, sources: Optional[list[str]] = None):
    """Stream the answer to the most recent user message into ``history``.

    Args:
        history: Chat history; its last entry is ``[question, None]`` as
            produced by ``user``. The bot slot is filled in place.
        sources: Display labels picked in the dropdown. ``None`` or empty
            yields the "select at least one source" message.

    Yields:
        ``(history, completion)`` after every token, so Gradio re-renders the
        chat and stores the completion in state.
    """
    user_input = history[-1][0]
    # `not sources` also covers None (the declared default), on which the
    # previous `len(sources)` check raised TypeError.
    if not sources:
        completion = get_empty_source_completion(user_input)
    else:
        # The dropdown shows display labels; buster expects internal ids.
        display_ui_to_source = dict(zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES))
        sources_renamed = [display_ui_to_source[disp] for disp in sources]
        completion = buster.process_input(user_input, sources=sources_renamed)

    history[-1][1] = ""
    for token in completion.answer_generator:
        history[-1][1] += token
        yield history, completion
# App-wide theme: gradio's Soft theme with blue accents and Google fonts.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="blue",
    font=fonts.GoogleFont("Source Sans Pro"),
    font_mono=fonts.GoogleFont("IBM Plex Mono"),
)

with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h3><center>Towards AI 🤖: A Question-Answering Bot for anything AI-related</center></h3>"
            "<h6><center><i>Powered by Activeloop and 4th Generation Intel® Xeon® Scalable Processors</i></center></h6>"
        )

    # Multiselect of display labels (all selected by default); get_answer
    # maps them back to internal source ids.
    source_selection = gr.Dropdown(
        choices=AVAILABLE_SOURCES_UI,
        label="Select Sources",
        value=AVAILABLE_SOURCES_UI,
        multiselect=True,
    )

    chatbot = gr.Chatbot(elem_id="chatbot", show_copy_button=True)

    with gr.Row():
        question = gr.Textbox(
            label="What's your question?",
            placeholder="Ask a question to our AI tutor here...",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary")

    with gr.Row():
        examples = gr.Examples(
            examples=cfg.example_questions,
            inputs=question,
        )

    with gr.Row():
        email = gr.Textbox(
            label="Want to receive updates about our AI tutor?",
            placeholder="Enter your email here...",
            lines=1,
            scale=3,
        )
        submit_email = gr.Button(value="Submit", variant="secondary", scale=0)

    gr.Markdown(
        "This application uses ChatGPT to search the docs for relevant information and answer questions."
        "\n\n### Built on top of the open-source [Buster 🤖](https://www.github.com/jerpint/buster) project. Huge thanks to them."
    )

    # Latest Completion yielded by get_answer; read by add_sources and log_likes.
    completion = gr.State()

    # "Send" and Enter run the same pipeline: echo the question into the chat
    # (queue=False so it appears immediately), stream the answer, then append
    # the list of sources.
    for trigger in (submit.click, question.submit):
        trigger(user, [question, chatbot], [question, chatbot], queue=False).then(
            get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
        ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])

    chatbot.like(log_likes, completion)

    # log_emails returns "", which clears the email box after submission.
    for trigger in (submit_email.click, email.submit):
        trigger(log_emails, email, email)
# Configure the event queue with the env-tunable concurrency, then serve the app.
demo.queue(concurrency_count=CONCURRENCY_COUNT).launch(debug=True, share=False)