|
import gradio as gr |
|
from src.rag.question_answerer import QuestionAnsweringBot |
|
from src.rag.retriever import Retriever |
|
|
|
|
|
def ask_question(
    question: str,
    model: str,
    api_key: str,
    semantic_usage: float,
    initial_top_n: int,
    final_top_n: int,
) -> tuple[str, str]:
    """
    Handle a question from the UI and return the answer plus its context chunks.

    Configures the module-level ``qa_bot`` with the requested model and API key,
    forwards the retrieval settings to ``qa_bot.answer_question`` and formats the
    returned chunks for display.

    Note: ``qa_bot`` is created in this file's ``__main__`` block, so calling this
    function from an importing module raises ``NameError`` unless that module sets
    ``qa_bot`` first.

    Args:
        question (str): User's question.
        model (str): Model name assigned to ``qa_bot.model``
            (the UI default is ``"groq/llama3-8b-8192"``).
        api_key (str): The API key for LiteLLM.
        semantic_usage (float): Weight for semantic search in retrieval (0-1).
            The BM25 weight passed to the bot is ``1 - semantic_usage``.
        initial_top_n (int): Number of candidates requested from BM25 & semantic search.
        final_top_n (int): Number of chunks to keep after reranking.

    Returns:
        tuple[str, str]: The answer, and the context chunks numbered from 1,
        each followed by a dashed separator line.
    """
    # NOTE(review): this mutates a bot shared by every session. If the Gradio
    # queue is ever configured to run this handler concurrently, requests could
    # observe each other's model / API key — confirm the concurrency setting.
    qa_bot.model = model
    qa_bot.api_key = api_key

    # Slider values arrive as plain JSON numbers; coerce them so the bot always
    # receives integer counts (safe as slice bounds / top-k values).
    initial_top_n = int(initial_top_n)
    final_top_n = int(final_top_n)

    answer, contexts = qa_bot.answer_question(
        question=question,
        bm25_weight=(1 - semantic_usage),
        initial_top_n=initial_top_n,
        final_top_n=final_top_n,
    )

    separator = "\n-----------------------------------\n"
    formatted_contexts = "\n".join(
        f"Context {i}:\n{chunk}{separator}" for i, chunk in enumerate(contexts, 1)
    )

    return answer, formatted_contexts
|
|
|
|
|
if __name__ == "__main__":
    # Build the retrieval pipeline once at startup. `qa_bot` is assigned at
    # module level (not inside a function) because `ask_question` reads it as
    # a global.
    retriever = Retriever(chunked_dir="data/chunked")
    qa_bot = QuestionAnsweringBot(retriever=retriever)

    # Custom styling: Montserrat font everywhere, a purple outlined submit
    # button (#custom-button), and a restyled range-slider track and thumb.
    # The ::-webkit-* selectors only affect WebKit/Blink-based browsers.
    custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');

* {
    font-family: 'Montserrat', sans-serif !important;
}

#custom-button {
    color: #9d00ff !important;
    background-color: transparent !important;
    border: 2px solid #9d00ff !important;
    border-radius: 5px !important;
    padding: 10px 20px !important;
    font-weight: bold !important;
    font-size: 16px !important;
    cursor: pointer !important;
}

#custom-button:hover {
    background-color: #9d00ff !important;
    color: white !important;
}

/* Slider bar background */
input[type="range"]::-webkit-slider-runnable-track{
    background: #e4e4e4;
    border-radius: 3px;
    height: 8px;
}

/* Circle */
input[type="range"]::-webkit-slider-thumb {
    -webkit-appearance: none;
    appearance: none;
    width: 15px;
    height: 15px;
    background: #9d00ff;
    border-radius: 50%;
    cursor: pointer;
}
"""

    # Pass the stylesheet through the documented `css` constructor argument
    # rather than assigning `demo.css` after the Blocks object has been built.
    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown(
            """
            <h1 style='text-align: center; color: #9d00ff;'>RAG Contextual Answering</h1>
            <p style='text-align: center;'>This tool allows you to ask questions and receive contextual answers
            with relevant information from the files.</p>
            """
        )

        # Inputs: LLM settings, the question and the retrieval knobs.
        with gr.Group():
            with gr.Row():
                model = gr.Textbox(
                    value="groq/llama3-8b-8192",
                    label="Model Name",
                    placeholder="Enter your model name here",
                )
                api_key = gr.Textbox(
                    label="API Key",
                    placeholder="Enter your API key here",
                    type="password",
                )

            question = gr.Textbox(
                value="Which survival instincts prey have?",
                label="Question",
                placeholder="Type your question here...",
            )
            semantic_usage = gr.Slider(
                label="Semantic usage (0 - only key phrases search, 1 - only semantic search)",
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.001,
            )
            with gr.Row():
                initial_top_n = gr.Slider(
                    label="BM25 & Semantic Search Top N",
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                )
                final_top_n = gr.Slider(
                    label="Reranker Top N",
                    minimum=1,
                    maximum=100,
                    value=5,
                    step=1,
                )

        # elem_id ties this button to the #custom-button rules in custom_css.
        submit_button = gr.Button("Get Answer", elem_id="custom-button")

        # Read-only outputs filled by ask_question.
        answer = gr.Textbox(label="Answer", interactive=False, lines=5)
        chunks_response = gr.Textbox(
            label="Context Chunks",
            interactive=False,
        )

        # Input order must match ask_question's positional parameters.
        submit_button.click(
            ask_question,
            inputs=[
                question,
                model,
                api_key,
                semantic_usage,
                initial_top_n,
                final_top_n,
            ],
            outputs=[answer, chunks_response],
        )

        gr.Markdown(
            """
            <h3 style='text-align: center; color: #9d00ff;'>Made by Maksym Batiuk and Olena Morozevych.</h3>
            """
        )

    # server_name="0.0.0.0" listens on all network interfaces (e.g. inside a
    # container), which makes the app reachable by anyone who can reach the host.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|