# Hugging Face Space status when this file was captured: Runtime error
import base64 | |
import os | |
from huggingface_hub import login | |
import PIL.Image | |
from byaldi import RAGMultiModalModel | |
import PIL.Image as PILImage | |
import io | |
import textwrap | |
import google.generativeai as genai | |
import gradio as gr # Add Gradio for UI | |
from PIL import Image as PILImage | |
# Initialize Google API and model
import torch

device = torch.device("cpu")  # Force CPU
import os

# Gemini API key comes from the environment (e.g. a Hugging Face Space secret).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # Surface the misconfiguration at start-up rather than on the first query.
    print("Warning: GOOGLE_API_KEY is not set; Gemini requests will fail.")
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel('models/gemini-1.5-flash-latest')

# Load the ColPali retriever from the index saved during the first run.
# from_index() loads the underlying model itself, so the former
# RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2") call only caused a
# second, immediately discarded model load and has been removed.
# The index location can be overridden with RAG_INDEX_PATH so the app also runs
# outside the original author's machine.
index_path = os.getenv(
    "RAG_INDEX_PATH",
    "/home/mohammadaqib/Desktop/project/research/Multi-Modal-RAG/Colpali/BCC",
)
# byaldi defaults to device="cuda"; pass the CPU device chosen above explicitly.
# NOTE(review): confirm the installed byaldi version accepts `device` here.
RAG = RAGMultiModalModel.from_index(index_path, device=device.type)

# Initialize conversation history
conversation_history = []
def get_user_input(query):
    """Hand back the user's query exactly as received.

    This is a pass-through hook; no normalization is applied.
    """
    processed = query
    return processed
def process_image_from_results(results, max_images=3):
    """Decode the leading search results into images and merge them side by side.

    Each result is expected to carry a base64-encoded image in its ``base64``
    attribute. A result that lacks the attribute, has a non-string payload,
    holds invalid base64, or whose bytes are not a readable image is reported
    and skipped instead of aborting the whole request.

    Args:
        results: Iterable of search results, best match first. Any iterable
            is accepted (it does not need ``len()``).
        max_images: Maximum number of leading results to use (default 3).

    Returns:
        The merged RGB image, also saved to ``merged_image.jpg`` in the
        current working directory, or ``None`` when no result yielded an image.
    """
    import itertools

    image_list = []
    for i, result in enumerate(itertools.islice(results, max_images)):
        try:
            image_bytes = base64.b64decode(result.base64)
            image = PILImage.open(io.BytesIO(image_bytes))  # Open image directly from bytes
            # PIL decodes lazily; force decoding now so corrupt data is caught
            # here rather than surfacing later inside paste().
            image.load()
        except AttributeError:
            print(f"Result {i} does not contain a 'base64' attribute.")
        except (TypeError, ValueError, OSError) as exc:
            # TypeError: payload is None / not str-or-bytes.
            # ValueError: includes binascii.Error for malformed base64.
            # OSError: includes PIL.UnidentifiedImageError and truncated data.
            print(f"Result {i} could not be decoded as an image: {exc}")
        else:
            image_list.append(image)

    if not image_list:
        return None

    # Lay the images out left to right on a canvas tall enough for the tallest.
    total_width = sum(img.width for img in image_list)
    max_height = max(img.height for img in image_list)
    merged_image = PILImage.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in image_list:
        merged_image.paste(img, (x_offset, 0))
        x_offset += img.width
    # Save the merged image
    merged_image.save('merged_image.jpg')
    return merged_image
def generate_answer(query, image):
    """Ask Gemini to answer *query* using *image* as supporting evidence.

    The prompt also asks the model to cite where in the image the answer comes
    from (table number, statement number or other metadata).

    Args:
        query: The user's question.
        image: The merged page image built from the retrieved results.

    Returns:
        The model's complete answer text.
    """
    prompt = (
        "Answer to the question asked using the image. "
        "Also mention the reference from image to support your answer. "
        "Example, Table Number or Statement number or any metadata. "
        f"Question: {query}"
    )
    streamed = model.generate_content([prompt, image], stream=True)
    # The request streams; wait for every chunk before reading the text.
    streamed.resolve()
    answer = streamed.text
    return answer
def classify_system_question(query):
    """Return True if *query* asks about the assistant itself.

    Gemini is asked for a bare 'yes' or 'no'. Its reply is normalized
    (lower-cased, surrounding whitespace/punctuation/quotes/markdown removed)
    before comparison. Replies such as "Yes." or "**yes**" are common, and an
    exact ``== "yes"`` check would misread them as "no".

    Args:
        query: The user's question.

    Returns:
        True when the model's normalized reply is exactly "yes", else False.
    """
    response = model.generate_content(
        [
            "Determine if the question is about the system itself, like "
            "'Who are you?' or 'What can you do?' or 'Introduce yourself' . "
            f"Answer with 'yes' or 'no'. Question: {query}"
        ],
        stream=True,
    )
    response.resolve()
    answer = response.text.strip().lower().strip(" \t\r\n.!?\"'`*")
    return answer == "yes"
def classify_question(query):
    """Classify *query* as general or domain-specific using Gemini.

    The model is asked for the single word 'general' or 'domain-specific'. The
    reply is normalized by lower-casing it and removing surrounding whitespace,
    punctuation, quotes and markdown. Without this, decorated replies such as
    "General." or "**general**" would not match the caller's
    ``== "general"`` comparison. Inner characters, such as the hyphen in
    "domain-specific", are left untouched.

    Args:
        query: The user's question.

    Returns:
        The normalized label, normally "general" or "domain-specific".
        Any other model output is returned normalized but otherwise as-is.
    """
    response = model.generate_content(
        [
            "Classify this question as 'general' or 'domain-specific'. "
            "Give one word answer i.e general or domain-specific. "
            "General questions are greetings and questions involving general "
            "knowledge like the capital of France. General questions also "
            "involve politics, geography, history, economics, cosmology, "
            f"information about famous personalities, etc. Question: {query}"
        ],
        stream=True,
    )
    response.resolve()
    return response.text.strip().lower().strip(" \t\r\n.!?\"'`*")
# Phrases signalling that the image-grounded answer did not actually find the
# requested information. When one appears (case-insensitive substring match),
# chatbot() falls back to a general Gemini answer.
_FALLBACK_PHRASES = (
    "does not provide",
    "cannot answer",
    "does not contain",
    "no relevant answer",
    "not found",
    "information unavailable",
    "not in the document",
    "unable to provide",
    "no data",
    "missing information",
    "no match",
    "provided text does not describe",
    "are not explicitly listed",
    "are not explicitly mentioned",
    "no results",
    "not available",
    "query not found",
)

_SYSTEM_DESCRIPTION = (
    "I am an AI assistant capable of answering queries related to the National "
    "Building Code of Canada and general questions. I was developed by the "
    "research group [SITE] at the University of Alberta. How can I assist you further?"
)


def _generate_text(prompt):
    """Send one text prompt to Gemini and return the complete reply text."""
    response = model.generate_content([prompt], stream=True)
    response.resolve()
    return response.text


def _answer_domain_question(query):
    """Answer an NBCC question from the RAG index, falling back to plain Gemini."""
    results = RAG.search(query, k=3)
    if not results:
        # Nothing retrieved: answer without document context.
        return "It is a general query. ANSWER: " + _generate_text(f"Answer this question: {query}")

    image = process_image_from_results(results)
    if image is None:
        # Results had no usable page images: answer without the image.
        return "It is a query from NBCC. ANSWER: " + _generate_text(f"Answer this question: {query}")

    answer = generate_answer(query, image)
    lowered = answer.lower()
    if any(phrase in lowered for phrase in _FALLBACK_PHRASES):
        # The retrieved pages did not contain the answer; use general knowledge.
        return "It is a general query. ANSWER: " + _generate_text(
            f"Answer this general question in concise manner: {query}"
        )
    return "It is a query from NBCC. ANSWER: " + answer


def chatbot(query, history):
    """Answer one user turn and return it together with the updated history.

    Routing:
      1. Questions about the assistant itself get a fixed self-description.
      2. Questions Gemini classifies as "general" are answered directly.
      3. Everything else is answered from the NBCC page index, with a fallback
         to a general Gemini answer when retrieval does not help.

    Args:
        query: The user's question.
        history: List of ``(user_message, model_message)`` pairs from the
            Gradio state. ``None`` is treated as empty; the list is not mutated.

    Returns:
        ``(text, history, history)``: the answer text, then the updated history
        twice (once for the state and once for the Chatbot display).
    """
    max_history_length = 50  # Number of recent exchanges to keep, including this one
    history = list(history or [])

    if not query or not query.strip():
        # Nothing to answer; avoid spending Gemini calls on an empty submission.
        return "", history, history

    # Step 1: questions about the system itself.
    if classify_system_question(query):
        text = _SYSTEM_DESCRIPTION
    # Step 2: general-knowledge questions are answered by Gemini directly.
    elif classify_question(query) == "general":
        text = _generate_text(
            f"Answer this general question: {query}. If it is a greeting respond "
            "accordingly and if it is not greeting add a prefix saying that it is "
            "a general query."
        )
    # Step 3: domain-specific questions go through the RAG index.
    else:
        text = _answer_domain_question(query)

    # Keep room for the new exchange so the total never exceeds the limit.
    truncated_history = history[-(max_history_length - 1):]
    truncated_history.append(("You: " + query, "Model: " + text))
    return text, truncated_history, truncated_history
import gradio as gr

# Custom CSS to beautify the interface. It is passed to gr.Blocks(css=...),
# the documented way to attach page styles, rather than set as an attribute
# after construction.
CUSTOM_CSS = """
.gradio-container {
    background-color: #f9f9f9;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-chatbox {
    background-color: #f0f0f0;
    border-radius: 10px;
    padding: 10px;
    max-height: 1000px;
    overflow-y: scroll;
    margin-bottom: 10px;
}
.gr-textbox input {
    border-radius: 10px;
    padding: 12px;
    font-size: 16px;
    border: 1px solid #ccc;
    width: 100%;
    margin-top: 10px;
    box-sizing: border-box;
}
.gr-textbox input:focus {
    border-color: #4CAF50;
    outline: none;
}
.gr-button {
    background-color: #4CAF50;
    color: white;
    padding: 12px;
    border-radius: 10px;
    font-size: 16px;
    border: none;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #45a049;
}
.gr-chatbot {
    font-family: "Arial", sans-serif;
    font-size: 14px;
}
.gr-chatbot .gr-chatbot-user {
    background-color: #e1f5fe;
    border-radius: 10px;
    padding: 8px;
    margin-bottom: 10px;
    max-width: 80%;
}
.gr-chatbot .gr-chatbot-model {
    background-color: #ffffff;
    border-radius: 10px;
    padding: 8px;
    margin-bottom: 10px;
    max-width: 80%;
}
.gr-chatbot .gr-chatbot-user p,
.gr-chatbot .gr-chatbot-model p {
    margin: 0;
}
#input_box {
    position: fixed;
    bottom: 20px;
    width: 95%;
    padding: 10px;
    border-radius: 10px;
    box-shadow: 0 0 5px rgba(0, 0, 0, 0.2);
}
"""

# Banner image shown at the top of the page. The path can be overridden with
# SITE_LOGO_PATH. The banner is skipped when the file is missing (e.g. on a
# deployed Space) instead of failing start-up on the author's local path.
logo_path = os.getenv("SITE_LOGO_PATH", "/home/mohammadaqib/Pictures/Screenshots/site.png")

# Define Gradio interface
with gr.Blocks(css=CUSTOM_CSS) as iface:
    # Conversation state: list of (user, model) message pairs, initially empty.
    state = gr.State([])

    # Banner image and title at the top of the page.
    with gr.Column():
        if os.path.isfile(logo_path):
            gr.Image(logo_path, height=300)
        gr.Markdown(
            "# Question Answering System Over National Building Code of Canada"
        )

    # Chatbot UI
    with gr.Row():
        chat_history = gr.Chatbot(label="Chat History", height=250)

    # Place input at the bottom
    with gr.Row():
        query = gr.Textbox(
            label="Ask a Question",
            placeholder="Enter your question here...",
            lines=1,
            interactive=True,
            elem_id="input_box",  # Custom ID for styling
        )

    # Output for the response (hidden; the chat history displays the answer).
    output_text = gr.Textbox(label="Answer", interactive=False, visible=False)

    # Answer the question, then clear the input box.
    query.submit(
        chatbot,
        inputs=[query, state],
        outputs=[output_text, state, chat_history],
        show_progress=True,
    ).then(
        # With inputs=None Gradio calls this with no arguments, so it must take
        # none; the former `lambda _: ""` raised TypeError and never cleared it.
        lambda: "",
        inputs=None,
        outputs=query,
    )

    gr.Markdown("<p style='position: fixed; bottom:0; width: 100%; text-align: left; font-style: italic; margin-left: 15%; font-size: 18px;'>Developed by Mohammad Aqib, MSc Student at the University of Alberta, supervised by Dr. Qipei (Gavin) Mei.</p>", elem_id="footer")

# Launch the interface
iface.launch(share=True)