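"""Gradio demo for RAG-based contextual question answering.

The app wires a hybrid retriever (BM25 + semantic search with a reranking step)
to a LiteLLM-backed question-answering bot and exposes the retrieval knobs
(semantic weight, candidate count, reranker top-N) through a simple web UI.
"""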
import gradio as gr

from src.rag.question_answerer import QuestionAnsweringBot
from src.rag.retriever import Retriever


def ask_question(
    question: str,
    model: str,
    api_key: str,
    semantic_usage: float,
    initial_top_n: int,
    final_top_n: int,
) -> tuple[str, str]:
    """
    Handles a user's question and returns the answer together with the relevant context chunks.

    Args:
        question (str): The user's question.
        model (str): The LiteLLM model name (e.g. "groq/llama3-8b-8192").
        api_key (str): The API key for LiteLLM.
        semantic_usage (float): Weight of semantic search in retrieval (0-1);
            the BM25 weight is 1 - semantic_usage.
        initial_top_n (int): Number of candidates returned by the BM25/semantic search.
        final_top_n (int): Number of chunks kept after reranking.

    Returns:
        tuple[str, str]: The answer and the formatted context chunks.
    """
    # Point the bot (created in the __main__ block below) at the selected model and API key
    qa_bot.model = model
    qa_bot.api_key = api_key

    # Get the answer and the supporting context chunks
    answer, contexts = qa_bot.answer_question(
        question=question,
        bm25_weight=(1 - semantic_usage),
        initial_top_n=initial_top_n,
        final_top_n=final_top_n,
    )

    # Format the context chunks as a single string
    formatted_contexts = "\n".join(
        [
            f"Context {i}:\n{chunk}" + "\n-----------------------------------\n"
            for i, chunk in enumerate(contexts, 1)
        ]
    )

    return answer, formatted_contexts
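
# A minimal sketch (not executed by the app) of calling ask_question directly, e.g. from
# a REPL; it assumes the retriever and qa_bot from the __main__ block below have already
# been created and that the API key is valid:
#
#   answer, contexts = ask_question(
#       question="Which survival instincts do prey have?",
#       model="groq/llama3-8b-8192",
#       api_key="<your LiteLLM API key>",
#       semantic_usage=0.5,
#       initial_top_n=50,
#       final_top_n=5,
#   )
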
if __name__ == "__main__":
    # Initialize the retriever and the question-answering bot
    retriever = Retriever(chunked_dir="data/chunked")
    qa_bot = QuestionAnsweringBot(retriever=retriever)
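    # NOTE: Retriever is assumed to read pre-chunked documents from data/chunked, and
    # QuestionAnsweringBot is assumed to call whichever LiteLLM-compatible model the user
    # selects in the UI; the exact behaviour depends on the src.rag implementations.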

    # Custom CSS: Montserrat font, outlined purple button, purple slider styling
    custom_css = """
    @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');

    * {
        font-family: 'Montserrat', sans-serif !important;
    }

    #custom-button {
        color: #9d00ff !important;
        background-color: transparent !important;
        border: 2px solid #9d00ff !important;
        border-radius: 5px !important;
        padding: 10px 20px !important;
        font-weight: bold !important;
        font-size: 16px !important;
        cursor: pointer !important;
    }

    #custom-button:hover {
        background-color: #9d00ff !important;
        color: white !important;
    }

    /* Slider track */
    input[type="range"]::-webkit-slider-runnable-track {
        background: #e4e4e4;
        border-radius: 3px;
        height: 8px;
    }

    /* Slider thumb */
    input[type="range"]::-webkit-slider-thumb {
        -webkit-appearance: none;
        appearance: none;
        width: 15px;
        height: 15px;
        background: #9d00ff;
        border-radius: 50%;
        cursor: pointer;
    }
    """

    # Gradio UI
    with gr.Blocks(css=custom_css) as demo:
        # Short description
        gr.Markdown(
            """
            <h1 style='text-align: center; color: #9d00ff;'>RAG Contextual Answering</h1>
            <p style='text-align: center;'>Ask a question and get a contextual answer, along with the
            relevant context chunks retrieved from the indexed files.</p>
            """
        )
        # Input Section
        with gr.Group():
            with gr.Row():
                model = gr.Textbox(
                    value="groq/llama3-8b-8192",
                    label="Model Name",
                    placeholder="Enter your model name here",
                )
                api_key = gr.Textbox(
                    label="API Key",
                    placeholder="Enter your API key here",
                    type="password",
                )
            question = gr.Textbox(
                value="Which survival instincts do prey have?",
                label="Question",
                placeholder="Type your question here...",
            )
            semantic_usage = gr.Slider(
                label="Semantic usage (0 - only key-phrase search, 1 - only semantic search)",
                minimum=0, maximum=1, value=0.5, step=0.001,
            )
            with gr.Row():
                initial_top_n = gr.Slider(
                    label="BM25 & Semantic Search Top N",
                    minimum=1, maximum=100, value=50, step=1,
                )
                final_top_n = gr.Slider(
                    label="Reranker Top N",
                    minimum=1, maximum=100, value=5, step=1,
                )
        # Submit Button
        submit_button = gr.Button("Get Answer", elem_id="custom-button")

        # Answer and Context Sections
        answer = gr.Textbox(label="Answer", interactive=False, lines=5)
        chunks_response = gr.Textbox(
            label="Context Chunks",
            interactive=False,
        )

        # Button Action
        submit_button.click(
            ask_question,
            inputs=[
                question,
                model,
                api_key,
                semantic_usage,
                initial_top_n,
                final_top_n,
            ],
            outputs=[answer, chunks_response],
        )
        # Footer
        gr.Markdown(
            """
            <h3 style='text-align: center; color: #9d00ff;'>Made by Maksym Batiuk and Olena Morozevych.</h3>
            """
        )

    # Launch the Gradio app
    demo.launch(server_name="0.0.0.0", server_port=7860)
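    # To try it locally: run this script (e.g. `python app.py`, if that is the file name) and
    # open http://localhost:7860; binding to 0.0.0.0 also makes the app reachable when it runs
    # inside a Docker container or a Hugging Face Space.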