import os
import tempfile

import gradio as gr
import numpy as np
import replicate
from PIL import Image
from llama_index import ServiceContext, VectorStoreIndex, Document
from llama_index.llms.palm import PaLM
from llama_index.memory import ChatMemoryBuffer
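# Both external services need credentials at runtime: the Replicate client reads
# REPLICATE_API_TOKEN from the environment, and PaLM is constructed from
# GOOGLE_API_KEY further down. A minimal startup check (sketch only; adjust the
# variable names if your Space stores its secrets differently):
for required_var in ("REPLICATE_API_TOKEN", "GOOGLE_API_KEY"):
    if required_var not in os.environ:
        raise RuntimeError(f"Missing required environment variable: {required_var}")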
# Get an image caption from Kosmos-2 running on Replicate
def get_image_caption(image_array):
    # Convert the numpy array coming from Gradio into a PIL Image
    image = Image.fromarray(image_array.astype("uint8"), "RGB")

    # Save the image to a temporary JPEG so it can be streamed to the API
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpeg") as tmp_file:
        image.save(tmp_file, format="JPEG")
        tmp_file_path = tmp_file.name

    # Run the Kosmos-2 captioning model, then clean up the temporary file
    with open(tmp_file_path, "rb") as image_file:
        output = replicate.run(
            "lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805",
            input={
                "image": image_file,
                "description_type": "Brief",
            },
        )
    os.remove(tmp_file_path)

    # The model returns the caption followed by grounding data; keep only the caption
    text_description = output.split("\n\n")[0]
    return text_description
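# Example usage (hypothetical local file; assumes REPLICATE_API_TOKEN is set):
#   caption = get_image_caption(np.array(Image.open("photo.jpg").convert("RGB")))
#   print(caption)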
# Build a LlamaIndex chat engine grounded in the image description
def create_chat_engine(img_desc, api_key):
    llm = PaLM(api_key=api_key)
    service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")

    # Index the image description so the engine can retrieve it as context
    doc = Document(text=img_desc)
    index = VectorStoreIndex.from_documents([doc], service_context=service_context)
    chatmemory = ChatMemoryBuffer.from_defaults(token_limit=1500)

    chat_engine = index.as_chat_engine(
        chat_mode="context",
        system_prompt=(
            "You are a chatbot, able to have normal interactions, as well as talk. "
            "You always answer in great detail and are polite. Your responses are always descriptive. "
            f"Your job is to talk about an image the user has uploaded. Image description: {img_desc}."
        ),
        verbose=True,
        memory=chatmemory,
    )
    return chat_engine
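# Note: process_image_and_chat below rebuilds the index and chat engine on every
# request, so chat memory does not carry over between questions. A minimal cache
# sketch (hypothetical helper, not wired into the app) that reuses the engine built
# for a given caption:
_engine_cache = {}

def get_or_create_chat_engine(img_desc, api_key):
    # Reuse the engine built for this caption so follow-up turns share memory
    if img_desc not in _engine_cache:
        _engine_cache[img_desc] = create_chat_engine(img_desc, api_key)
    return _engine_cache[img_desc]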
# Handle one interaction: caption the captured image, then answer the user's question
def process_image_and_chat(image_array, user_input):
    if image_array is None:
        return "Please capture an image."

    img_desc = get_image_caption(image_array)
    chat_engine = create_chat_engine(img_desc, os.environ["GOOGLE_API_KEY"])

    if user_input:
        try:
            response = chat_engine.chat(user_input)
            return str(response)
        except Exception as e:
            return f"An error occurred: {str(e)}"
    else:
        return "Ask me anything about the uploaded image."
# Define the Gradio interface
image_input = gr.Image(sources=["webcam"], type="numpy")
text_input = gr.Textbox(label="Ask me about the image:")
output_text = gr.Textbox(label="Response")

iface = gr.Interface(
    fn=process_image_and_chat,
    inputs=[image_input, text_input],
    outputs=output_text,
    title="My version of ChatGPT vision",
    description="You can capture an image using your webcam and start chatting with the LLM about the image.",
    allow_flagging="never",
)

# Launch the app
iface.launch()
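# For local testing outside a Space you could run, for example:
#   iface.launch(share=True)  # share=True creates a temporary public Gradio link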