# Virtual Tutor — Streamlit multimodal chat app (text answers plus image/video retrieval)
import chromadb | |
from PIL import Image as PILImage | |
import streamlit as st | |
import os | |
from utils.qa import chain | |
from langchain.memory import ConversationBufferWindowMemory | |
from langchain_community.chat_message_histories import StreamlitChatMessageHistory | |
import base64 | |
import io | |
# Initialize Chromadb client
# Persistent on-disk Chroma store rooted at ./mm_vdb2; the collections must
# already exist (get_collection raises if they were never created).
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
# NOTE(review): one collection is fetched as "image" and the other as
# 'video_collection' — confirm these names match the ingestion pipeline.
image_collection = client.get_collection(name="image")
video_collection = client.get_collection(name='video_collection')
# Set up memory storage for the chat
# Streamlit-session-backed message history; the window memory keeps only the
# last k=3 exchanges for the LLM context.
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
def get_answer(query):
    """Run *query* through the QA chain and return its textual result.

    Falls back to the literal string "No result found." when the chain's
    response dict carries no "result" key.
    """
    return chain.invoke(query).get("result", "No result found.")
# Function to display images in the UI
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """Query *image_collection* by text and render matches in a 3-column grid.

    Args:
        image_collection: Chroma collection supporting .query() with text.
        query_text: the user's query string.
        max_distance: optional distance cutoff; results farther away are skipped.
        debug: when True, write each candidate's URI and distance to the page.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )
    uris = results['uris'][0]
    distances = results['distances'][0]
    # Sort by URI (x[0]) for a stable, deterministic display order.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
    cols = st.columns(3)
    # Count only images actually rendered, so filtered-out or unreadable
    # results don't leave gaps in the grid (the original indexed columns by
    # the raw enumerate position, skipping cells on every filtered result).
    shown = 0
    for uri, distance in sorted_results:
        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")
        if max_distance is None or distance <= max_distance:
            try:
                img = PILImage.open(uri)
                with cols[shown % 3]:
                    st.image(img, use_container_width=True)
                shown += 1
            except Exception as e:
                # Surface the failure in the UI but keep rendering the rest.
                st.error(f"Error loading image: {e}")
# Function to display videos in the UI
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """Query *video_collection* by text and embed each distinct matching video.

    A video is shown at most once even when several frames of it match; the
    optional *max_distance* cutoff filters weak matches. With *debug* on,
    every candidate (shown or filtered) is logged to the page.
    """
    seen = set()  # video URIs already embedded, to suppress duplicates
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )
    candidates = zip(hits['uris'][0], hits['distances'][0], hits['metadatas'][0])
    for uri, distance, metadata in candidates:
        video_uri = metadata['video_uri']
        close_enough = max_distance is None or distance <= max_distance
        if close_enough and video_uri not in seen:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
            st.video(video_uri)
            seen.add(video_uri)
        elif debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
# Function to format the inputs for image and video processing
def format_prompt_inputs(image_collection, video_collection, user_query):
    """Build the prompt-inputs dict for the chain: the query text, the best
    matching video-frame URI, and a base64-encoded, downscaled greyscale
    JPEG of the best matching image (empty strings when nothing matches).
    """
    # NOTE(review): frame_uris() and image_uris() are neither defined nor
    # imported in this file — confirm they exist elsewhere in the project,
    # otherwise this raises NameError at runtime.
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    inputs = {"query": user_query}
    # First (presumably best-ranked) frame candidate, or "" when none matched.
    frame = frame_candidates[0] if frame_candidates else ""
    inputs["frame"] = frame
    if image_candidates:
        image = image_candidates[0]
        with PILImage.open(image) as img:
            # Shrink to 1/6 of each dimension and convert to greyscale ("L")
            # to keep the base64 payload small.
            img = img.resize((img.width // 6, img.height // 6))
            img = img.convert("L")
            with io.BytesIO() as output:
                # Re-encode as JPEG at quality 60 and capture the bytes.
                img.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()
        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    else:
        inputs["image_data_1"] = ""
    return inputs
# Main function to initialize and run the UI
def home():
    """Render the Virtual Tutor chat page and process one chat turn.

    Draws the header/banner, replays prior messages, and on new input:
    answers via the QA chain, stores the exchange in session state, then
    shows related images and (when a matching frame exists) the source video.
    """
    # Set up the page layout (set_page_config must be the first st.* call)
    st.set_page_config(layout='wide', page_title="Virtual Tutor")
    # Header
    st.header("Welcome to Virtual Tutor - CHAT")
    # SVG Banner for UI branding
    st.markdown("""
        <svg width="600" height="100">
            <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
             stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
            </text>
        </svg>
    """, unsafe_allow_html=True)
    # Initialize the chat session if not already initialized
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi! How may I assist you today?"}]
    # Styling for the chat input container
    st.markdown("""
    <style>
        .stChatInputContainer > div {
        background-color: #000000;
        }
    </style>
    """, unsafe_allow_html=True)
    # Display previous chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    # Display chat messages from memory
    # NOTE(review): nothing in this file ever appends to memory_storage, so
    # this loop is likely empty; if the chain writes to it, messages would be
    # rendered twice (once here, once from session_state) — confirm.
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)
    # Handle user input and generate response
    if user_input := st.chat_input("Enter your question here..."):
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response
                st.markdown(answer)
        # Save user and assistant messages to session state
        message = {"role": "assistant", "content": answer}
        message_u = {"role": "user", "content": user_input}
        st.session_state.messages.append(message_u)
        st.session_state.messages.append(message)
        # Process inputs for image/video
        inputs = format_prompt_inputs(image_collection, video_collection, user_input)
        # Display images
        st.markdown("### Images")
        display_images(image_collection, user_input, max_distance=1.55, debug=False)
        # Display videos based on frames
        st.markdown("### Videos")
        frame = inputs["frame"]
        if frame:
            # Frame URIs are expected to look like "<root>/<video_name>/<frame>";
            # guard against paths without a directory component, which would
            # have raised IndexError in the original split('/')[1].
            parts = frame.split('/')
            if len(parts) > 1:
                directory_name = parts[1]
                video_path = f"videos_flattened/{directory_name}.mp4"
                if os.path.exists(video_path):
                    st.video(video_path)
                else:
                    st.error("Video file not found.")
            else:
                st.error("Video file not found.")
# Call the home function to run the app
# Entry guard so importing this module (e.g. from a multipage app) does not
# immediately render the page.
if __name__ == "__main__":
    home()