Spaces:
Sleeping
Sleeping
import streamlit as st | |
import extra_streamlit_components as stx | |
import requests | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModelForVision2Seq | |
from io import BytesIO | |
import replicate | |
from llama_index.llms.palm import PaLM | |
from llama_index import ServiceContext, VectorStoreIndex, Document | |
from llama_index.memory import ChatMemoryBuffer | |
import os | |
import datetime | |
# Set up the title of the application | |
#st.title("PaLM-Kosmos-Vision") | |
st.set_page_config(layout="wide") | |
st.write("My version of ChatGPT vision. You can upload an image and start chatting with the LLM about the image") | |
# Initialize the cookie manager | |
cookie_manager = stx.CookieManager() | |
# Function to get image caption via Kosmos2. | |
def get_image_caption(image_data): | |
input_data = { | |
"image": image_data, | |
"description_type": "Brief" | |
} | |
output = replicate.run( | |
"lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805", | |
input=input_data | |
) | |
# Split the output string on the newline character and take the first item | |
text_description = output.split('\n\n')[0] | |
return text_description | |
# Function to create the chat engine. | |
def create_chat_engine(img_desc, api_key): | |
llm = PaLM(api_key=api_key) | |
service_context = ServiceContext.from_defaults(llm=llm, embed_model="local") | |
doc = Document(text=img_desc) | |
index = VectorStoreIndex.from_documents([doc], service_context=service_context) | |
chatmemory = ChatMemoryBuffer.from_defaults(token_limit=1500) | |
chat_engine = index.as_chat_engine( | |
chat_mode="context", | |
system_prompt=( | |
f"You are a chatbot, able to have normal interactions, as well as talk. " | |
"You always answer in great detail and are polite. Your responses always descriptive. " | |
"Your job is to talk about an image the user has uploaded. Image description: {img_desc}." | |
), | |
verbose=True, | |
memory=chatmemory | |
) | |
return chat_engine | |
# Clear chat function | |
def clear_chat(): | |
if "messages" in st.session_state: | |
del st.session_state.messages | |
if "image_file" in st.session_state: | |
del st.session_state.image_file | |
# Callback function to clear the chat when a new image is uploaded | |
def on_image_upload(): | |
clear_chat() | |
# Retrieve the message count from cookies | |
message_count = cookie_manager.get(cookie='message_count') | |
if message_count is None: | |
message_count = 0 | |
else: | |
message_count = int(message_count) | |
# If the message limit has been reached, disable the inputs | |
if 0: | |
st.error("Notice: The maximum message limit for this demo version has been reached.") | |
# Disabling the uploader and input by not displaying them | |
image_uploader_placeholder = st.empty() # Placeholder for the uploader | |
chat_input_placeholder = st.empty() # Placeholder for the chat input | |
else: | |
# Add a clear chat button | |
if st.button("Clear Chat"): | |
clear_chat() | |
# Image upload section. | |
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], key="uploaded_image", on_change=on_image_upload) | |
if image_file: | |
# Display the uploaded image at a standard width. | |
st.image(image_file, caption='Uploaded Image.', width=200) | |
# Process the uploaded image to get a caption. | |
image_data = BytesIO(image_file.getvalue()) | |
img_desc = get_image_caption(image_data) | |
st.write("Image Uploaded Successfully. Ask me anything about it.") | |
# Initialize the chat engine with the image description. | |
chat_engine = create_chat_engine(img_desc, st.secrets['GOOGLE_API_KEY']) | |
# Initialize session state for messages if it doesn't exist | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
# Display previous messages | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# Handle new user input | |
user_input = st.chat_input("Ask me about the image:", key="chat_input") | |
if user_input: | |
# Append user message to the session state | |
st.session_state.messages.append({"role": "user", "content": user_input}) | |
# Display user message immediately | |
with st.chat_message("user"): | |
st.markdown(user_input) | |
# Call the chat engine to get the response if an image has been uploaded | |
if image_file and user_input: | |
try: | |
with st.spinner('Waiting for the chat engine to respond...'): | |
# Get the response from your chat engine | |
response = chat_engine.chat(user_input) | |
# Append assistant message to the session state | |
st.session_state.messages.append({"role": "assistant", "content": response}) | |
# Display the assistant message | |
with st.chat_message("assistant"): | |
st.markdown(response) | |
except Exception as e: | |
st.error(f'An error occurred.') | |
# Optionally, you can choose to break the flow here if a critical error happens | |
# return | |
# Increment the message count and update the cookie | |
message_count += 1 | |
cookie_manager.set('message_count', str(message_count), expires_at=datetime.datetime.now() + datetime.timedelta(days=30)) | |