# Initialize the Google Generative AI client and the Gemini model
import base64
import io
import os

import google.generativeai as genai
import gradio as gr  # Gradio for the web UI
from byaldi import RAGMultiModalModel
from PIL import Image as PILImage

# The API key must be set in the environment (e.g. as a Space secret)
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel('models/gemini-1.5-flash-latest')
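# Note: the '-latest' suffix is an alias that tracks the newest 1.5 Flash
# release; pin an explicit model version if reproducible output matters.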
# Load the ColPali retriever from the index saved during the first run.
# from_index restores both the index and the underlying model
# (vidore/colpali-v1.2), so a separate from_pretrained load is unnecessary.
index_path = "/home/mohammadaqib/Desktop/project/research/Multi-Modal-RAG/Colpali/BCC"
RAG = RAGMultiModalModel.from_index(index_path, device="cpu")  # force CPU inference
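# For reference, the index would have been built on the first run roughly as
# follows (hypothetical document path; store_collection_with_index=True is
# what makes results[i].base64 available during search):
#   RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", verbose=1)
#   RAG.index(input_path="nbcc.pdf", index_name="BCC",
#             store_collection_with_index=True, overwrite=True)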
# Conversation history lives in a Gradio State object (created in the UI
# below), so no module-level history list is needed.
def get_user_input(query):
"""Process user input."""
return query
def process_image_from_results(results):
"""Process images from the search results and merge them."""
image_list = []
for i in range(min(3, len(results))):
try:
# Ensure the result has a base64 attribute
image_bytes = base64.b64decode(results[i].base64)
image = PILImage.open(io.BytesIO(image_bytes)) # Open image directly from bytes
image_list.append(image)
except AttributeError:
print(f"Result {i} does not contain a 'base64' attribute.")
# Merge images if any
if image_list:
total_width = sum(img.width for img in image_list)
max_height = max(img.height for img in image_list)
merged_image = PILImage.new('RGB', (total_width, max_height))
x_offset = 0
for img in image_list:
merged_image.paste(img, (x_offset, 0))
x_offset += img.width
# Save the merged image
merged_image.save('merged_image.jpg')
return merged_image
else:
return None
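# Example of the expected retrieval flow (hypothetical query; assumes the
# index above is populated):
#   results = RAG.search("minimum stair width", k=3)
#   merged = process_image_from_results(results)  # PIL image, or None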
def generate_answer(query, image):
"""Generate an answer using the Gemini model and the merged image."""
    response = model.generate_content([f'Answer the question using the image, and cite the supporting reference from the image (e.g. a table number, statement number, or other metadata). Question: {query}', image], stream=True)
response.resolve()
return response.text
def classify_system_question(query):
"""Check if the question is related to the system itself."""
response = model.generate_content([f"Determine if the question is about the system itself, like 'Who are you?' or 'What can you do?' or 'Introduce yourself' . Answer with 'yes' or 'no'. Question: {query}"], stream=True)
response.resolve()
return response.text.strip().lower() == "yes"
def classify_question(query):
"""Classify whether the question is general or domain-specific using Gemini."""
response = model.generate_content([f"Classify this question as 'general' or 'domain-specific'. Give one word answer i.e general or domain-specific. General questions are greetings and questions involving general knowledge like the capital of France. General questions also involve politics, geography, history, economics, cosmology, information about famous personalities, etc. Question: {query}"], stream=True)
response.resolve()
return response.text.strip().lower() # Assuming the response is either 'general' or 'domain-specific'
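# Expected classifier behavior (actual output depends on the LLM):
#   classify_question("What is the capital of France?")    -> "general"
#   classify_question("What is the required fire rating?") -> "domain-specific"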
def chatbot(query, history):
max_history_length = 50 # Number of recent exchanges to keep
# Truncate the history to the last `max_history_length` exchanges
truncated_history = history[-max_history_length:]
# Add user input to the history
truncated_history.append(("You: " + query, "Model:"))
# Step 1: Check if the question is about the system
if classify_system_question(query):
text = "I am an AI assistant capable of answering queries related to the National Building Code of Canada and general questions. I was developed by the research group [SITE] at the University of Alberta. How can I assist you further?"
else:
# Step 2: Classify the question as general or domain-specific
question_type = classify_question(query)
# If the question is general, use Gemini to directly answer it
if question_type == "general":
text = model.generate_content([f"Answer this general question: {query}. If it is a greeting respond accordingly and if it is not greeting add a prefix saying that it is a general query."], stream=True)
text.resolve()
text = text.text
else:
# Step 3: Query the RAG model for domain-specific answers
results = RAG.search(query, k=3)
# Check if RAG found any results
if not results:
text = model.generate_content([f"Answer this question: {query}"], stream=True)
text.resolve()
text = text.text
text = "It is a general query. ANSWER:" + text
else:
# Process images from the results
image = process_image_from_results(results)
# Generate the answer using the Gemini model if an image is found
if image:
text = generate_answer(query, image)
text = "It is a query from NBCC. ANSWER:" + text
# Check if the answer is a fallback message (indicating no relevant answer)
if any(keyword in text.lower() for keyword in [
"does not provide",
"cannot answer",
"does not contain",
"no relevant answer",
"not found",
"information unavailable",
"not in the document",
"unable to provide",
"no data",
"missing information",
"no match",
"provided text does not describe",
"are not explicitly listed",
"are not explicitly mentioned",
"no results",
"not available",
"query not found"
]):
# Fallback to Gemini for answering
text = model.generate_content([f"Answer this general question in concise manner: {query}"], stream=True)
text.resolve()
text = text.text
text = "It is a general query. ANSWER: " + text
else:
text = model.generate_content([f"Answer this question: {query}"], stream=True)
text.resolve()
text = text.text
text = "It is a query from NBCC. ANSWER: " + text
# Add the model's response to the truncated history
truncated_history[-1] = (truncated_history[-1][0], "Model: " + text) # Update the most recent message with model's answer
# Return the output text, updated state, and chat history (as tuple pairs)
return text, truncated_history, truncated_history # Ensure all three outputs are returned in the correct order
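# Example call (hypothetical question; assumes the index and API key are set):
#   answer, new_state, new_history = chatbot("What is the minimum ceiling height?", [])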
# Define Gradio interface
# Custom CSS to style the interface. It is defined up front and passed to
# gr.Blocks(), which is how Gradio expects custom CSS to be applied.
custom_css = """
.gradio-container {
    background-color: #f9f9f9;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-chatbox {
    background-color: #f0f0f0;
    border-radius: 10px;
    padding: 10px;
    max-height: 1000px;
    overflow-y: scroll;
    margin-bottom: 10px;
}
.gr-textbox input {
    border-radius: 10px;
    padding: 12px;
    font-size: 16px;
    border: 1px solid #ccc;
    width: 100%;
    margin-top: 10px;
    box-sizing: border-box;
}
.gr-textbox input:focus {
    border-color: #4CAF50;
    outline: none;
}
.gr-button {
    background-color: #4CAF50;
    color: white;
    padding: 12px;
    border-radius: 10px;
    font-size: 16px;
    border: none;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #45a049;
}
.gr-chatbot {
    font-family: "Arial", sans-serif;
    font-size: 14px;
}
.gr-chatbot .gr-chatbot-user {
    background-color: #e1f5fe;
    border-radius: 10px;
    padding: 8px;
    margin-bottom: 10px;
    max-width: 80%;
}
.gr-chatbot .gr-chatbot-model {
    background-color: #ffffff;
    border-radius: 10px;
    padding: 8px;
    margin-bottom: 10px;
    max-width: 80%;
}
.gr-chatbot .gr-chatbot-user p,
.gr-chatbot .gr-chatbot-model p {
    margin: 0;
}
#input_box {
    position: fixed;
    bottom: 20px;
    width: 95%;
    padding: 10px;
    border-radius: 10px;
    box-shadow: 0 0 5px rgba(0, 0, 0, 0.2);
}
"""

# Define the Gradio interface
with gr.Blocks(css=custom_css) as iface:
    # Conversation state starts as an empty list
    state = gr.State([])
    # Header image and page title (local screenshot path from the development machine)
    with gr.Column():
        gr.Image("/home/mohammadaqib/Pictures/Screenshots/site.png", height=300)
        gr.Markdown(
            "# Question Answering System Over National Building Code of Canada"
        )
# Chatbot UI
with gr.Row():
chat_history = gr.Chatbot(label="Chat History", height=250)
# Place input at the bottom
with gr.Row():
query = gr.Textbox(
label="Ask a Question",
placeholder="Enter your question here...",
lines=1,
interactive=True,
elem_id="input_box" # Custom ID for styling
)
# Output for the response
output_text = gr.Textbox(label="Answer", interactive=False, visible=False) # Optional to hide
# Define the interaction behavior
query.submit(
chatbot,
inputs=[query, state],
outputs=[output_text, state, chat_history],
show_progress=True
    ).then(
        lambda: "",  # Clear the input box after submission; inputs=None, so the callback takes no arguments
        inputs=None,
        outputs=query
    )
gr.Markdown("<p style='position: fixed; bottom:0; width: 100%; text-align: left; font-style: italic; margin-left: 15%; font-size: 18px;'>Developed by Mohammad Aqib, MSc Student at the University of Alberta, supervised by Dr. Qipei (Gavin) Mei.</p>", elem_id="footer")
# Launch the interface
iface.launch(share=True)