Pash1986's picture
Update app.py
2251e58 verified
raw
history blame
No virus
5.09 kB
import gradio as gr
from pymongo import MongoClient
from PIL import Image
import base64
import os
import io
import boto3
import json
# AWS Bedrock runtime client, shared by the Titan embedding call and the Claude
# description call below. Credentials come from the AWS_ACCESS_KEY and
# AWS_SECRET_KEY environment variables; an unset variable is passed as None.
bedrock_runtime = boto3.client(
    'bedrock-runtime',
    region_name="us-east-1",
    aws_access_key_id=os.getenv('AWS_ACCESS_KEY'),
    aws_secret_access_key=os.getenv('AWS_SECRET_KEY'),
)
def construct_bedrock_body(base64_string, text, output_embedding_length=1024):
    """Build the JSON request body for the Titan multimodal embedding model.

    Args:
        base64_string: Base64-encoded image, sent as ``inputImage``.
        text: Optional text sent as ``inputText`` together with the image.
            It is omitted from the body when falsy (``None`` or ``""``).
        output_embedding_length: Value for
            ``embeddingConfig.outputEmbeddingLength`` (default 1024).
            NOTE(review): this should presumably match the dimension of the
            stored vectors queried by the vector search — confirm before
            changing it.

    Returns:
        The request body serialised as a JSON string. Key order is
        ``inputImage``, ``embeddingConfig``, then ``inputText`` if present.
    """
    body = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": output_embedding_length},
    }
    if text:
        body["inputText"] = text
    return json.dumps(body)
def get_embedding_from_titan_multimodal(body):
    """Invoke the Titan multimodal embedding model and return its embedding.

    Args:
        body: JSON request body string, e.g. from ``construct_bedrock_body``.

    Returns:
        The ``"embedding"`` field of the decoded JSON response.
    """
    request = dict(
        modelId="amazon.titan-embed-image-v1",
        contentType="application/json",
        accept="application/json",
        body=body,
    )
    raw_payload = bedrock_runtime.invoke_model(**request).get("body").read()
    return json.loads(raw_payload)["embedding"]
# MongoDB Atlas connection. `celeb_images` is the collection that
# start_image_search runs its $vectorSearch aggregation against.
db_name = 'celebrity_1000_embeddings'
collection_name = 'celeb_images'
# NOTE(review): when MONGODB_ATLAS_URI is unset, MongoClient receives None —
# confirm what pymongo connects to in that case and whether that is intended.
uri = os.getenv('MONGODB_ATLAS_URI')
client = MongoClient(uri)
celeb_images = client.get_database(db_name).get_collection(collection_name)
def generate_image_description_with_claude(images_base64_strs, image_base64):
    """Ask Claude 3 Sonnet (via Bedrock) to compare the query image to matches.

    The query image is sent first, followed by every matched image, then a
    text prompt asking which match is most similar.

    Args:
        images_base64_strs: Base64-encoded images returned by the vector
            search. Any count is accepted; the prompt states how many there
            are. Assumed to be JPEG, matching the hard-coded media_type.
        image_base64: Base64-encoded JPEG of the user's query image.

    Returns:
        The concatenated text blocks of Claude's reply. Returns
        ``"No description available"`` when there are no images to compare
        against or when the reply contains no text.
    """
    if not images_base64_strs:
        # Nothing to compare with; skip the model call instead of failing.
        return "No description available"

    def image_block(data):
        return {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": data}}

    content = [image_block(image_base64)]
    content.extend(image_block(data) for data in images_base64_strs)
    content.append({
        "type": "text",
        # For exactly 3 matches this is the original prompt, byte for byte.
        "text": (
            "Please let the user know how his first image is similar to the "
            f"other {len(images_base64_strs)} and which one is the most similar?"
        ),
    })

    claude_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": "Please act as a face comparison analyzer.",
        "messages": [{"role": "user", "content": content}],
    })
    claude_response = bedrock_runtime.invoke_model(
        body=claude_body,
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(claude_response.get("body").read())

    # The reply's "content" is a list of blocks; collect the text ones
    # rather than assuming a first block exists.
    description = "".join(
        block.get("text", "")
        for block in response_body.get("content", [])
        if block.get("type") == "text"
    )
    return description or "No description available"
def _pil_to_jpeg_base64(pil_image, **save_kwargs):
    """Return *pil_image* encoded as JPEG and then base64, as a ``str``.

    Images in a mode other than RGB (e.g. RGBA or palette PNGs) are converted
    to an RGB copy first, because PIL cannot write every mode as JPEG. The
    caller's image object is not modified.

    Args:
        pil_image: The PIL image to encode.
        **save_kwargs: Extra options forwarded to ``Image.save`` (e.g. ``quality``).
    """
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    buffer = io.BytesIO()
    pil_image.save(buffer, format="JPEG", **save_kwargs)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def start_image_search(image, text):
    """Find the celebrity images most similar to *image* and describe them.

    The uploaded image is resized to 800x600 and JPEG-encoded. It is embedded
    with Titan, optionally together with *text*. The embedding drives a
    ``$vectorSearch`` over ``celeb_images`` (up to 3 results). The matches
    and the query image are then sent to Claude for a textual comparison.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        text: Optional text adjustment passed to the embedding model.

    Returns:
        A ``(images, description)`` tuple. ``images`` is the list of matched
        PIL images for the gallery; ``description`` is Claude's text.

    Raises:
        gr.Error: If no image was uploaded or the search returned no matches.
    """
    if image is None:
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    # NOTE(review): a fixed 800x600 resize does not preserve aspect ratio.
    query_base64 = _pil_to_jpeg_base64(image.resize((800, 600)), quality=85)
    embedding = get_embedding_from_titan_multimodal(construct_bedrock_body(query_base64, text))

    matches = list(celeb_images.aggregate([
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": 15,
                "limit": 3,
            }
        },
        {"$project": {"image": 1}},
    ]))
    if not matches:
        raise gr.Error("No similar images were found in the database.")

    images = []
    images_base64_strs = []
    for match in matches:
        pil_image = Image.open(io.BytesIO(base64.b64decode(match['image'])))
        images.append(pil_image)
        # Re-encode as JPEG so the media_type declared to Claude is accurate
        # whatever format the stored image uses.
        images_base64_strs.append(_pil_to_jpeg_base64(pil_image))

    description = generate_image_description_with_claude(images_base64_strs, query_base64)
    return images, description
# Gradio UI: a Blocks page with an intro header, plus an Interface that wires
# the image/text inputs to start_image_search and its gallery/text outputs.
_INTRO_MARKDOWN = """
# MongoDB's Vector Celeb Image Matcher
Upload an image and find the most similar celeb image from the database, along with an AI-generated description.
💪 Make a great pose to impact the search! 🤯
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    gr.Interface(
        fn=start_image_search,
        inputs=[
            gr.Image(type="pil", label="Upload an image"),
            gr.Textbox(label="Enter an adjustment to the image"),
        ],
        outputs=[
            gr.Gallery(
                label="Located images for AI-generated descriptions",
                show_label=False,
                elem_id="gallery",
                columns=[3],
                rows=[1],
                object_fit="contain",
                height="auto",
            ),
            gr.Textbox(label="AI Based vision description"),
        ],
    )

demo.launch()