# Multimodal_v2 / user.py
# Source: Hugging Face Space "Multimodal_v2" by NEXAS
# (commit 32bc45c, verified; file size 6.59 kB — page metadata from the
# "raw / history / blame" view, preserved here as a comment header)
import chromadb
from PIL import Image as PILImage
import streamlit as st
import os
from utils.qa import chain
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
import base64
import io
# Initialize Chromadb client
# Persistent ChromaDB client backed by the local "mm_vdb2" directory.
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
# NOTE(review): collection naming is inconsistent ("image" vs "video_collection");
# get_collection raises if either collection does not already exist — confirm
# they are created by the ingestion step before this module is imported.
image_collection = client.get_collection(name="image")
video_collection = client.get_collection(name='video_collection')
# Set up memory storage for the chat
# Streamlit-backed chat history; the window memory keeps only the last k=3 turns.
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
# Function to get an answer from the chain
def get_answer(query):
    """Run *query* through the QA chain and return its "result" text.

    Falls back to the string "No result found." when the chain's response
    dict has no "result" key.
    """
    answer = chain.invoke(query)
    return answer.get("result", "No result found.")
# Function to display images in the UI
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """Query the image collection for *query_text* and render matches in a 3-column grid.

    Args:
        image_collection: Chroma collection whose query() returns image URIs.
        query_text: Free-text query embedded by the collection.
        max_distance: If given, results with a larger distance are skipped.
        debug: When True, write each URI/distance pair to the page.
            (Fix: the flag was previously accepted but never used, unlike
            its counterpart in display_videos_streamlit.)
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )
    uris = results['uris'][0]
    distances = results['distances'][0]
    # NOTE(review): sorting on x[0] orders the grid by file path, not by
    # relevance — confirm this is intentional before switching to x[1].
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
    cols = st.columns(3)
    for i, (uri, distance) in enumerate(sorted_results):
        if max_distance is None or distance <= max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                img = PILImage.open(uri)
                with cols[i % 3]:
                    st.image(img, use_container_width=True)
            except Exception as e:
                # Best-effort: a bad/missing file shows an error cell, not a crash.
                st.error(f"Error loading image: {e}")
        elif debug:
            st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
# Function to display videos in the UI
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """Query the video collection and embed each matching video at most once.

    De-duplicates on the metadata "video_uri" key so the same source video is
    never rendered twice; skips results whose distance exceeds *max_distance*,
    and writes per-result diagnostics when *debug* is True.
    """
    seen = set()
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas']
    )
    rows = zip(hits['uris'][0], hits['distances'][0], hits['metadatas'][0])
    for uri, distance, metadata in rows:
        video_uri = metadata['video_uri']
        within_cutoff = max_distance is None or distance <= max_distance
        if within_cutoff and video_uri not in seen:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
            st.video(video_uri)
            seen.add(video_uri)
        elif debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
# Function to format the inputs for image and video processing
def format_prompt_inputs(image_collection, video_collection, user_query):
    """Build the prompt-input dict for the multimodal chain.

    Returns a dict with keys:
        "query"        — the raw user query,
        "frame"        — best-matching video-frame URI, or "" if none,
        "image_data_1" — base64 JPEG of the best-matching image, or "".
    """
    # NOTE(review): frame_uris and image_uris are neither defined nor imported
    # in this file — confirm they come from elsewhere (likely a missing import).
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    inputs = {"query": user_query}
    # First candidate is taken as the best match; "" when none pass the cutoff.
    frame = frame_candidates[0] if frame_candidates else ""
    inputs["frame"] = frame
    if image_candidates:
        image = image_candidates[0]
        # Shrink to 1/6 linear size and convert to grayscale ("L"), then
        # re-encode as quality-60 JPEG to keep the base64 payload small.
        with PILImage.open(image) as img:
            img = img.resize((img.width // 6, img.height // 6))
            img = img.convert("L")
            with io.BytesIO() as output:
                img.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()
        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    else:
        inputs["image_data_1"] = ""
    return inputs
# Main function to initialize and run the UI
def home():
    """Render the Virtual Tutor chat page.

    Flow: configure the page, show the banner, replay stored chat history,
    then for each new user question display the answer plus related images
    and a matching video clip.
    """
    # Set up the page layout
    st.set_page_config(layout='wide', page_title="Virtual Tutor")
    # Header
    st.header("Welcome to Virtual Tutor - CHAT")
    # SVG Banner for UI branding
    st.markdown("""
        <svg width="600" height="100">
            <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
             stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
            </text>
        </svg>
    """, unsafe_allow_html=True)
    # Initialize the chat session if not already initialized
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi! How may I assist you today?"}]
    # Styling for the chat input container
    st.markdown("""
    <style>
        .stChatInputContainer > div {
        background-color: #000000;
        }
    </style>
    """, unsafe_allow_html=True)
    # Display previous chat messages
    # NOTE(review): history is rendered twice — from st.session_state here and
    # from memory_storage below — confirm the duplication is intended.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    # Display chat messages from memory
    # Roles are inferred purely by position (even index = user); this assumes
    # strict user/assistant alternation in memory_storage.
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)
    # Handle user input and generate response
    if user_input := st.chat_input("Enter your question here..."):
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response
                st.markdown(answer)
        # Save user and assistant messages to session state
        # NOTE(review): only st.session_state is appended here; memory_storage
        # is never updated — verify whether the chain writes to it itself.
        message = {"role": "assistant", "content": answer}
        message_u = {"role": "user", "content": user_input}
        st.session_state.messages.append(message_u)
        st.session_state.messages.append(message)
        # Process inputs for image/video
        inputs = format_prompt_inputs(image_collection, video_collection, user_input)
        # Display images
        st.markdown("### Images")
        display_images(image_collection, user_input, max_distance=1.55, debug=False)
        # Display videos based on frames
        st.markdown("### Videos")
        frame = inputs["frame"]
        if frame:
            # Derive the source video from the frame path; assumes URIs look
            # like "<root>/<video_name>/<frame_file>" — TODO confirm layout.
            directory_name = frame.split('/')[1]
            video_path = f"videos_flattened/{directory_name}.mp4"
            if os.path.exists(video_path):
                st.video(video_path)
            else:
                st.error("Video file not found.")
# Call the home function to run the app
# Standard script entry point: launch the Streamlit page when run directly.
if __name__ == "__main__":
    home()