# RAGBOT / app.py
# Last commit 3ddea46 "Update app.py" by Rahatara (verified), 8.63 kB.
from typing import Any, List, Tuple
import gradio as gr
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import PyMuPDFLoader
import fitz
from PIL import Image
import os
import re
import openai
class MyApp:
    """Application state shared by the Gradio callbacks.

    Holds the OpenAI API key, the documents loaded from the current PDF, the
    conversational retrieval chain built over them, and the chat history that
    is fed back into that chain on every question.
    """

    def __init__(self) -> None:
        # OpenAI API key entered in the UI; None until set_api_key() is called.
        self.OPENAI_API_KEY: str | None = None
        # ConversationalRetrievalChain; None until build_chain() succeeds.
        self.chain = None
        # (question, answer) pairs passed to the chain as conversation memory.
        self.chat_history: list = []
        # Documents produced by PyMuPDFLoader for the current PDF.
        self.documents = None
        # Base name of the current PDF; used to derive the Chroma collection name.
        self.file_name = None

    def set_api_key(self, api_key: str):
        """Remember ``api_key`` and install it as the ``openai`` module's global key."""
        self.OPENAI_API_KEY = api_key
        openai.api_key = api_key

    def process_file(self, file) -> "Image.Image":
        """Load a PDF's pages as documents and render its first page.

        Args:
            file: Any object whose ``.name`` attribute is the PDF path (for
                example an upload from Gradio or an open file object).

        Returns:
            The first page rendered at 150 DPI as an RGB PIL image.
        """
        loader = PyMuPDFLoader(file.name)
        self.documents = loader.load()
        self.file_name = os.path.basename(file.name)
        # The context manager closes the PyMuPDF document once the page has
        # been copied into a PIL image.
        with fitz.open(file.name) as doc:
            pix = doc[0].get_pixmap(dpi=150)
            return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    @staticmethod
    def _collection_name(file_name) -> str:
        """Derive a valid Chroma collection name from a file name.

        Chroma accepts 3-63 characters from [A-Za-z0-9._-] that start and end
        with an alphanumeric character. Uploaded file names containing spaces,
        brackets or non-ASCII text would otherwise be rejected.
        Periods are replaced too, which avoids ".." and IPv4-like names.
        """
        name = re.sub(r"[^A-Za-z0-9_-]", "_", file_name or "")[:63]
        name = re.sub(r"^[^A-Za-z0-9]+|[^A-Za-z0-9]+$", "", name)
        if len(name) < 3:
            # Prefix keeps the name long enough and alphanumeric at both ends.
            name = "pdf" + name
        return name

    def build_chain(self, file) -> str:
        """Index the processed PDF in Chroma and build the conversational QA chain.

        Args:
            file: Unused. It is accepted so the method can be wired directly to
                a Gradio button whose input is the upload component. The
                documents come from the most recent ``process_file`` call.

        Returns:
            A status message for the UI.

        Raises:
            gr.Error: If no PDF has been processed yet, or it yielded no documents.
        """
        if not self.documents:
            raise gr.Error("Please process a PDF before building the vector database.")
        embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
        pdfsearch = Chroma.from_documents(
            self.documents,
            embeddings,
            collection_name=self._collection_name(self.file_name),
        )
        self.chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=0.0, openai_api_key=self.OPENAI_API_KEY),
            # k=1: only the single top-ranked chunk is retrieved per question.
            retriever=pdfsearch.as_retriever(search_kwargs={"k": 1}),
            # Callers read result["source_documents"] to show the cited pages.
            return_source_documents=True,
        )
        return "Vector database built successfully!"
def add_text(history: List[Tuple[str, str]], text: str) -> List[Tuple[str, str]]:
    """Append the user's message to the chatbot history with an empty reply slot.

    The reply slot is filled in by the response handler chained after this one.
    ``history`` is modified in place and also returned.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # A whitespace-only message would otherwise be sent to the LLM as a query.
    if not text or not text.strip():
        raise gr.Error("Enter text")
    history.append((text, ""))
    return history
def get_response(history, query):
    """Answer ``query`` with the QA chain and fill in the pending chatbot reply.

    Args:
        history: Chatbot history whose last entry is the ``(query, "")`` pair
            appended by ``add_text``.
        query: The user's question.

    Returns:
        A pair ``(history, source_text)``. ``history`` has its last reply
        filled in. ``source_text`` holds the retrieved passages, one
        ``Page N: ...`` entry each. If the chain call fails, it holds the
        fallback message followed by the error instead.

    Raises:
        gr.Error: If the chain has not been built yet.
    """
    if app.chain is None:
        raise gr.Error("The chain has not been built yet. Please ensure the vector database is built before querying.")
    fallback = "I have no information about it. Feed me knowledge, please!"
    # Only the chain call is guarded, so each query is recorded in
    # app.chat_history exactly once, whichever path is taken.
    try:
        result = app.chain.invoke(
            {"question": query, "chat_history": app.chat_history}
        )
        answer = result["answer"]
        source_docs = result["source_documents"]
    except Exception as e:  # UI boundary: report the failure instead of crashing
        app.chat_history.append((query, fallback))
        # Show the fallback in the chat rather than leaving an empty reply.
        history[-1] = (history[-1][0], fallback)
        return history, f"{fallback} Error: {str(e)}"

    app.chat_history.append((query, answer))
    source_texts = []
    for doc in source_docs:
        page = doc.metadata.get("page")
        # Page metadata is 0-based; show it 1-based, tolerating a missing value.
        label = f"Page {page + 1}" if isinstance(page, int) else "Page ?"
        source_texts.append(f"{label}: {doc.page_content}")
    history[-1] = (history[-1][0], answer)
    return history, "\n\n".join(source_texts)
def get_response_current(history, query):
    """Answer ``query`` for the "Current RAG" tab.

    This tab queries the same ``app.chain`` and shares ``app.chat_history``
    with the "Inst RAG" tab. It therefore delegates to ``get_response`` rather
    than keeping a second copy of the same logic.

    Returns:
        The ``(history, source_text)`` pair produced by ``get_response``.
    """
    return get_response(history, query)
def render_file(file) -> Image.Image:
    """Render the first page of a PDF as an image.

    Args:
        file: Object whose ``.name`` attribute is the PDF path.

    Returns:
        The first page at 150 DPI as an RGB PIL image.
    """
    # The context manager closes the PyMuPDF document after rasterizing.
    with fitz.open(file.name) as doc:
        pix = doc[0].get_pixmap(dpi=150)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
def purge_chat_and_render_first(file) -> Tuple[Image.Image, list]:
    """Reset the conversation memory and render the uploaded PDF's first page.

    Args:
        file: Object whose ``.name`` attribute is the path of the uploaded PDF.

    Returns:
        ``(image, [])``: the first page at 150 DPI as an RGB PIL image, and an
        empty list intended to clear the chat display.
    """
    app.chat_history = []
    # The context manager closes the PyMuPDF document after rasterizing.
    with fitz.open(file.name) as doc:
        pix = doc[0].get_pixmap(dpi=150)
        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    return image, []
def refresh_chat():
    """Forget the conversation memory that is fed back into the chain.

    Returns:
        A fresh empty list, which clears the chatbot component it is wired to.
    """
    app.chat_history = list()
    empty_history = list()
    return empty_history
# Module-level application state read and written by the callback functions
# above and by the Gradio handlers below (chain, chat history, loaded PDF).
# NOTE(review): as a module-level singleton this is presumably shared by every
# browser session served by this process - confirm that is intended.
app = MyApp()
def set_api_key(api_key):
    """Store the OpenAI API key, then index the bundled PDF so chatting works at once.

    Args:
        api_key: Key entered in the "OpenAI API Key" textbox. Surrounding
            whitespace, e.g. from copy/paste, is stripped.

    Returns:
        A status message that shows only a masked form of the key.

    Raises:
        gr.Error: If the key is empty.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API key.")
    app.set_api_key(api_key)
    # Pre-process the bundled PDF so a chain exists without any upload.
    saved_file_path = "track_training.pdf"
    # process_file/build_chain only need an object exposing .name; opening the
    # file also fails fast if the bundled PDF is missing.
    with open(saved_file_path, "rb") as saved_file:
        app.process_file(saved_file)
        app.build_chain(saved_file)
    # Show only the ends of long keys; short keys are fully masked so the
    # status box never echoes the whole secret.
    masked = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "****"
    return f"API Key set to {masked} and vector database built successfully!"
def _process_uploaded_pdf(file):
    """Handle "Process PDF": load the uploaded PDF and return its first page and a status."""
    if file is None:
        raise gr.Error("Please upload a PDF first.")
    return app.process_file(file), "Processing complete!"


# Gradio interface
with gr.Blocks() as demo:
    # API key entry; set_api_key also indexes the bundled PDF.
    api_key_input = gr.Textbox(label="OpenAI API Key", type="password", placeholder="Enter your OpenAI API Key")
    api_key_btn = gr.Button("Set API Key")
    api_key_status = gr.Textbox(value="API Key status", interactive=False)
    api_key_btn.click(
        fn=set_api_key,
        inputs=[api_key_input],
        outputs=[api_key_status],
    )

    # Tab 1: upload -> process -> build the vector database -> chat.
    with gr.Tab("Inst RAG"):
        with gr.Column():
            with gr.Row():
                btn = gr.UploadButton("📁 Upload a PDF", file_types=[".pdf"])
                show_img = gr.Image(label="Uploaded PDF")
            with gr.Row():
                process_btn = gr.Button("Process PDF")
                show_img_processed = gr.Image(label="Processed PDF")
                process_status = gr.Textbox(label="Processing Status", interactive=False)
            process_btn.click(
                fn=_process_uploaded_pdf,
                inputs=[btn],
                outputs=[show_img_processed, process_status],
            )
            with gr.Row():
                build_vector_btn = gr.Button("Build Vector Database")
                status_text = gr.Textbox(label="Status", value="", interactive=False)
            build_vector_btn.click(
                fn=app.build_chain,
                inputs=[btn],
                outputs=[status_text],
            )
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chatbot")
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press submit",
                    scale=2,
                )
                submit_btn = gr.Button("Submit", scale=1)
                refresh_btn = gr.Button("Refresh Chat", scale=1)
            source_texts_output = gr.Textbox(label="Source Texts", interactive=False)
            # Registered only once `chatbot` exists: purge_chat_and_render_first
            # returns (image, []). With a single output component, Gradio hands
            # the whole tuple to gr.Image, so the list needs its own target.
            btn.upload(
                fn=purge_chat_and_render_first,
                inputs=[btn],
                outputs=[show_img, chatbot],
            )
            # add_text runs unqueued to echo the message at once; the model
            # call only runs if add_text succeeded (non-empty input).
            submit_btn.click(
                fn=add_text,
                inputs=[chatbot, txt],
                outputs=[chatbot],
                queue=False,
            ).success(
                fn=get_response, inputs=[chatbot, txt], outputs=[chatbot, source_texts_output]
            )
            refresh_btn.click(
                fn=refresh_chat,
                inputs=[],
                outputs=[chatbot],
            )

    # Tab 2: chat against whatever chain is currently built.
    with gr.Tab("Current RAG"):
        with gr.Column():
            chatbot_current = gr.Chatbot(elem_id="chatbot_current")
            txt_current = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press submit",
                scale=2,
            )
            submit_btn_current = gr.Button("Submit", scale=1)
            refresh_btn_current = gr.Button("Refresh Chat", scale=1)
            source_texts_output_current = gr.Textbox(label="Source Texts", interactive=False)
            submit_btn_current.click(
                fn=add_text,
                inputs=[chatbot_current, txt_current],
                outputs=[chatbot_current],
                queue=False,
            ).success(
                fn=get_response_current, inputs=[chatbot_current, txt_current], outputs=[chatbot_current, source_texts_output_current]
            )
            refresh_btn_current.click(
                fn=refresh_chat,
                inputs=[],
                outputs=[chatbot_current],
            )
demo.queue()
demo.launch()