# Author: AI-ANK
# Commit: 65efe6b ("Update app.py")
import streamlit as st
import extra_streamlit_components as stx
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from io import BytesIO
import replicate
from llama_index.llms.palm import PaLM
from llama_index import ServiceContext, VectorStoreIndex, Document
from llama_index.memory import ChatMemoryBuffer
import os
import datetime
# Page layout: this is the first Streamlit call in the script, so the wide
# layout is applied before anything else is rendered.
st.set_page_config(layout="wide")
st.write("My version of ChatGPT vision. You can upload an image and start chatting with the LLM about the image")

# Cookie manager used further down to persist the per-browser message counter.
cookie_manager = stx.CookieManager()
# Function to get an image caption via Kosmos-2 (hosted on Replicate).
@st.cache_data
def get_image_caption(image_data):
    """Return a brief text description of an image produced by Kosmos-2.

    Results are memoized with ``st.cache_data``, keyed on ``image_data``.

    Args:
        image_data: the image contents, passed unchanged to Replicate as the
            model's ``image`` input (the caller in this file passes an
            ``io.BytesIO``).

    Returns:
        The first paragraph of the model output, i.e. the text before the
        first blank line, with surrounding whitespace removed.
    """
    input_data = {
        "image": image_data,
        "description_type": "Brief",
    }
    output = replicate.run(
        "lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805",
        input=input_data,
    )
    # Some Replicate models return an iterable of string chunks rather than a
    # single string; normalise to one string so the split below always works.
    if not isinstance(output, str):
        output = "".join(str(chunk) for chunk in output)
    # Keep only the first paragraph (everything before the first blank line).
    # NOTE(review): later paragraphs are presumably entity/grounding details —
    # confirm against the model's output schema.
    # Strip first so leading blank lines cannot produce an empty caption.
    return output.strip().split("\n\n", 1)[0].strip()
# Function to create the chat engine.
@st.cache_resource
def create_chat_engine(img_desc, api_key, memory_token_limit=1500):
    """Build a PaLM-backed "context" chat engine grounded in an image description.

    Args:
        img_desc: text description of the uploaded image. It is indexed as the
            engine's only document and is also embedded in the system prompt.
        api_key: Google API key handed to ``PaLM``.
        memory_token_limit: token budget for the ``ChatMemoryBuffer``
            (defaults to 1500, the previously hard-coded value).

    Returns:
        The chat engine from ``index.as_chat_engine(chat_mode="context")``.

    NOTE(review): ``st.cache_resource`` objects are shared across sessions, so
    users calling this with the same description and key share one engine
    (including its chat memory) — confirm this is intended.
    """
    llm = PaLM(api_key=api_key)
    service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")
    doc = Document(text=img_desc)
    index = VectorStoreIndex.from_documents([doc], service_context=service_context)
    chatmemory = ChatMemoryBuffer.from_defaults(token_limit=memory_token_limit)
    chat_engine = index.as_chat_engine(
        chat_mode="context",
        system_prompt=(
            "You are a chatbot, able to have normal interactions, as well as talk. "
            "You always answer in great detail and are polite. Your responses always descriptive. "
            # In implicit string concatenation the f-prefix applies per literal,
            # so it must sit on the literal that holds the placeholder;
            # otherwise "{img_desc}" reaches the model verbatim.
            f"Your job is to talk about an image the user has uploaded. Image description: {img_desc}."
        ),
        verbose=True,
        memory=chatmemory,
    )
    return chat_engine
# Clear chat function
def clear_chat():
    """Remove the chat history ("messages") and any "image_file" entry from session state.

    Keys that are absent are skipped. NOTE(review): no visible code stores an
    "image_file" key (the uploader's key is "uploaded_image"), so that part is
    currently a no-op — confirm the intended key.
    """
    state = st.session_state
    for stale_key in ("messages", "image_file"):
        if stale_key in state:
            del state[stale_key]
# Uploader callback: a newly uploaded image starts a fresh conversation.
def on_image_upload():
    """``on_change`` handler for the image uploader; resets state via ``clear_chat``."""
    return clear_chat()
# Per-browser message cap, tracked in the 'message_count' cookie.
# None disables the cap (the current behaviour); set an int to enforce it.
MAX_MESSAGES = None

# Retrieve the message count from cookies. The cookie value comes from the
# client, so anything missing or non-numeric is treated as 0 instead of crashing.
message_count = cookie_manager.get(cookie='message_count')
try:
    message_count = 0 if message_count is None else int(message_count)
except (TypeError, ValueError):
    message_count = 0

# If the message limit has been reached, disable the inputs
if MAX_MESSAGES is not None and message_count >= MAX_MESSAGES:
    st.error("Notice: The maximum message limit for this demo version has been reached.")
    # Disabling the uploader and input by not displaying them
    image_uploader_placeholder = st.empty()  # Placeholder for the uploader
    chat_input_placeholder = st.empty()  # Placeholder for the chat input
else:
    # Add a clear chat button
    if st.button("Clear Chat"):
        clear_chat()

    # Image upload section; uploading a new image resets the chat via on_image_upload.
    image_file = st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        key="uploaded_image",
        on_change=on_image_upload,
    )
    if image_file:
        # Display the uploaded image at a standard width.
        st.image(image_file, caption='Uploaded Image.', width=200)
        # Process the uploaded image to get a caption.
        image_data = BytesIO(image_file.getvalue())
        img_desc = get_image_caption(image_data)
        st.write("Image Uploaded Successfully. Ask me anything about it.")
        # Initialize the chat engine with the image description.
        chat_engine = create_chat_engine(img_desc, st.secrets['GOOGLE_API_KEY'])

    # Initialize session state for messages if it doesn't exist
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Re-render the conversation so far (Streamlit reruns the script on each interaction).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Handle new user input
    user_input = st.chat_input("Ask me about the image:", key="chat_input")
    if user_input:
        # Append user message to the session state and show it immediately.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        # chat_engine only exists when an image has been uploaded on this run.
        if image_file:
            try:
                with st.spinner('Waiting for the chat engine to respond...'):
                    response = chat_engine.chat(user_input)
            except Exception as e:
                # Top-level UI boundary: report the failure instead of hiding its cause.
                st.error(f'An error occurred: {e}')
            else:
                # Store plain text so the history re-renders cleanly on later reruns.
                response_text = str(response)
                st.session_state.messages.append({"role": "assistant", "content": response_text})
                with st.chat_message("assistant"):
                    st.markdown(response_text)

            # Increment the message count and update the cookie
            message_count += 1
            cookie_manager.set(
                'message_count',
                str(message_count),
                expires_at=datetime.datetime.now() + datetime.timedelta(days=30),
            )