buster / app.py
jerpint's picture
move files around (#17)
2b4f517 unverified
raw history blame
No virus
3.99 kB
import logging
import os
from typing import Optional
import gradio as gr
import pandas as pd
from buster.completers import Completion
import cfg
from cfg import setup_buster
# Configure logging before building buster so that log records emitted while
# it is being set up are already handled at INFO level.
logging.basicConfig(level=logging.INFO)
# Suppress httpx logs: they are spammy and uninformative.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

buster = setup_buster(cfg.buster_cfg)

# Source names offered in the dropdown; the UI pre-selects all of them.
AVAILABLE_SOURCES = ["towardsai", "wikipedia", "langchain_course"]
def format_sources(matched_documents: pd.DataFrame) -> str:
    """Render the documents behind an answer as a markdown list of links.

    Each distinct document title is listed once, keeping its highest-scoring
    row, ordered from most to least relevant.

    Args:
        matched_documents: Matched documents with at least ``title``, ``url``
            and ``similarity_to_answer`` columns. The score is multiplied by 100
            and shown as a percentage, so it presumably is a 0-1 fraction.

    Returns:
        The formatted markdown message, or ``""`` when there are no documents
        (``None`` or an empty DataFrame). The input DataFrame is not modified.
    """
    if matched_documents is None or matched_documents.empty:
        return ""

    documents_answer_template: str = "πŸ“ Here are the sources I used to answer your question:\n\n{documents}\n\n{footnote}"
    document_template: str = "[πŸ”— {document.title}]({document.url}), relevance: {document.similarity_to_answer:2.1f} %"

    # Scale on a copy: assigning back into the caller's DataFrame would compound
    # the x100 factor every time the same completion is formatted.
    scored = matched_documents.assign(
        similarity_to_answer=matched_documents["similarity_to_answer"] * 100
    )

    # Drop duplicate titles, keeping the highest-ranked occurrence of each.
    scored = scored.sort_values("similarity_to_answer", ascending=False).drop_duplicates(
        "title", keep="first"
    )

    documents = "\n".join(
        document_template.format(document=document) for _, document in scored.iterrows()
    )
    footnote: str = "I'm a bot πŸ€– and not always perfect."
    return documents_answer_template.format(documents=documents, footnote=footnote)
def add_sources(history, completion):
    """Append a chat message listing the sources behind the latest answer.

    Args:
        history: Chat history as a list of ``[user_message, bot_message]`` pairs.
        completion: The completion stored for the latest question, or ``None``
            when no completion has been stored in the UI state yet.

    Returns:
        The same ``history`` list. A ``[None, sources]`` entry is appended only
        when the answer was judged relevant and the formatted sources are
        non-empty.
    """
    # A missing completion means there is nothing to cite; don't crash on it.
    if completion is None or not completion.answer_relevant:
        return history

    formatted_sources = format_sources(completion.matched_documents)
    # format_sources returns "" when there are no documents; skip the empty bubble.
    if formatted_sources:
        history.append([None, formatted_sources])
    return history
def user(user_input, history):
    """Echo the user's question into the chat right away.

    Args:
        user_input: Text typed into the question box.
        history: Current chat history, a list of ``[user_message, bot_message]`` pairs.

    Returns:
        ``("", new_history)``: the empty string clears the textbox, and
        ``new_history`` is a new list ending with ``[user_input, None]``. The
        bot slot stays ``None`` until the answer is streamed in. ``history``
        itself is left unmodified.
    """
    pending_turn = [user_input, None]
    new_history = history + [pending_turn]
    return "", new_history
def get_empty_source_completion(user_input):
    """Build a placeholder completion used when no source is selected.

    Args:
        user_input: The question the user asked.

    Returns:
        A ``Completion`` whose answer text asks the user to pick at least one
        source. Its ``matched_documents`` is an empty DataFrame and ``error`` is
        ``False``.
    """
    no_source_message = "You have to select at least one source from the dropdown menu."
    no_documents = pd.DataFrame()
    return Completion(
        user_input=user_input,
        answer_text=no_source_message,
        matched_documents=no_documents,
        error=False,
    )
def get_answer(history, sources: Optional[list[str]] = None):
    """Stream the bot's answer to the latest question into the chat history.

    Args:
        history: Chat history whose last entry is ``[user_question, None]``, as
            left by ``user``.
        sources: Names of the sources to search. ``None`` or an empty list means
            no source was selected, and a placeholder completion is used instead.

    Yields:
        ``(history, completion)`` tuples. ``history`` carries the partial answer
        in its last entry; ``completion`` is the object being streamed.
    """
    user_input = history[-1][0]
    # `not sources` also covers the default None, which len() would reject.
    if not sources:
        completion = get_empty_source_completion(user_input)
    else:
        completion = buster.process_input(user_input, sources=sources)

    history[-1][1] = ""
    # Yield once before streaming so the completion always reaches the UI
    # state, even if the answer generator produces no tokens; otherwise the
    # downstream add_sources step could receive a stale or missing completion.
    yield history, completion
    for token in completion.answer_generator:
        history[-1][1] += token
        yield history, completion
block = gr.Blocks()

with block:
    with gr.Row():
        gr.Markdown(
            "<h3><center>Buster πŸ€–: A Question-Answering Bot for your documentation</center></h3>"
        )

    source_selection = gr.Dropdown(
        choices=AVAILABLE_SOURCES,
        label="Select Sources",
        value=AVAILABLE_SOURCES,
        multiselect=True,
    )

    chatbot = gr.Chatbot()

    with gr.Row():
        question = gr.Textbox(
            label="What's your question?",
            placeholder="Ask a question to AI stackoverflow here...",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary")

    examples = gr.Examples(
        examples=cfg.example_questions,
        inputs=question,
    )

    gr.Markdown(
        "This application uses ChatGPT to search the docs for relevant info and answer questions. "
        # Absolute URL: without a scheme the markdown link resolves relative to the app.
        "\n\n### Powered by [Buster πŸ€–](https://github.com/jerpint/buster)"
    )

    # Holds the latest completion emitted by get_answer so add_sources can read it.
    response = gr.State()

    # Clicking "Send" and submitting the textbox run the same pipeline:
    # echo the question, stream the answer, then append the sources.
    for trigger in (submit.click, question.submit):
        trigger(user, [question, chatbot], [question, chatbot], queue=False).then(
            get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, response]
        ).then(add_sources, inputs=[chatbot, response], outputs=[chatbot])

block.queue(concurrency_count=16)
block.launch(debug=True, share=False)