"""Gradio app: find the celebrity images most similar to an uploaded photo.

The uploaded image (plus an optional text adjustment) is embedded with the
Amazon Titan multimodal embedding model on AWS Bedrock. The nearest images are
retrieved from a MongoDB Atlas vector index. Claude 3 Sonnet (also via Bedrock)
then describes how the query image compares to the matches.
"""

import base64
import io
import json
import os

import boto3
import gradio as gr
from PIL import Image
from pymongo import MongoClient

# Embedding size requested from Titan. It presumably must equal the
# dimensionality of the Atlas "vector_index"; confirm before changing.
EMBEDDING_LENGTH = 1024
TITAN_MODEL_ID = "amazon.titan-embed-image-v1"
CLAUDE_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
# Number of matches returned by $vectorSearch, and the candidate pool it
# considers while finding them.
SEARCH_LIMIT = 3
SEARCH_NUM_CANDIDATES = 15

# AWS Bedrock client setup; credentials are read from the AWS_ACCESS_KEY /
# AWS_SECRET_KEY environment variables.
bedrock_runtime = boto3.client(
    "bedrock-runtime",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_KEY"),
    region_name="us-east-1",
)


def _to_jpeg_base64(image, **save_kwargs):
    """Return *image* encoded as JPEG bytes, then base64, as a UTF-8 string.

    JPEG cannot store an alpha channel or a palette, so images in any mode
    other than RGB or greyscale (e.g. RGBA PNG uploads) are converted to RGB
    first; otherwise PIL raises ``OSError`` when saving. Extra keyword
    arguments (such as ``quality``) are forwarded to ``Image.save``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", **save_kwargs)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def construct_bedrock_body(base64_string, text):
    """Build the JSON request body for the Titan multimodal embedding model.

    Args:
        base64_string: Base64-encoded image bytes.
        text: Optional text embedded together with the image. It is left out
            of the request when empty or whitespace-only.

    Returns:
        The request body as a JSON string.
    """
    body = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": EMBEDDING_LENGTH},
    }
    # A blank Textbox yields "" (or stray spaces); only send real text.
    if text and text.strip():
        body["inputText"] = text
    return json.dumps(body)


def get_embedding_from_titan_multimodal(body):
    """Invoke the Titan embedding model and return the "embedding" field.

    Args:
        body: JSON request string, as built by ``construct_bedrock_body``.

    Returns:
        The embedding vector taken from the model's JSON response.
    """
    response = bedrock_runtime.invoke_model(
        body=body,
        modelId=TITAN_MODEL_ID,
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body["embedding"]


# MongoDB setup
uri = os.environ.get("MONGODB_ATLAS_URI")
client = MongoClient(uri)
db_name = "celebrity_1000_embeddings"
collection_name = "celeb_images"
celeb_images = client[db_name][collection_name]


def _claude_image_block(data):
    """Wrap a base64 JPEG string as a Claude Messages API image content block."""
    return {
        "type": "image",
        "source": {"type": "base64", "media_type": "image/jpeg", "data": data},
    }


def generate_image_description_with_claude(images_base64_strs, image_base64):
    """Ask Claude 3 Sonnet how the query image compares to the matched images.

    Args:
        images_base64_strs: Base64 JPEG strings of the matched images, in the
            order returned by the vector search. Any number is accepted.
        image_base64: Base64 JPEG string of the user's query image; it is sent
            first so the prompt can refer to it as "his first image".

    Returns:
        The first text block of Claude's reply, or "No description available"
        when the reply contains no text.
    """
    content = [_claude_image_block(image_base64)]
    content.extend(_claude_image_block(data) for data in images_base64_strs)
    content.append({
        "type": "text",
        "text": (
            "Please let the user know how his first image is similar to the "
            f"other {len(images_base64_strs)} and which one is the most similar?"
        ),
    })

    claude_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": "Please act as face comparison analyzer.",
        "messages": [{"role": "user", "content": content}],
    })
    claude_response = bedrock_runtime.invoke_model(
        body=claude_body,
        modelId=CLAUDE_MODEL_ID,
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(claude_response.get("body").read())

    # Guard against an empty or non-text "content" list instead of blindly
    # indexing [0].
    for block in response_body.get("content") or []:
        if block.get("type") == "text" and block.get("text"):
            return block["text"]
    return "No description available"


def start_image_search(image, text):
    """Gradio handler: embed the upload, search Atlas, describe the matches.

    Args:
        image: The uploaded PIL image, or None if nothing was uploaded.
        text: Optional adjustment text from the Textbox.

    Returns:
        A ``(images, description)`` tuple. ``images`` is the list of matched
        PIL images for the gallery; ``description`` is Claude's comparison text.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    # NOTE(review): a fixed 800x600 resize ignores the aspect ratio. This
    # presumably mirrors how the stored images were embedded; confirm before
    # changing it, since it affects every similarity score.
    query_base64 = _to_jpeg_base64(image.resize((800, 600)), quality=85)
    embedding = get_embedding_from_titan_multimodal(
        construct_bedrock_body(query_base64, text)
    )

    matches = celeb_images.aggregate([
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": SEARCH_NUM_CANDIDATES,
                "limit": SEARCH_LIMIT,
            }
        },
        {"$project": {"image": 1}},
    ])

    images = []
    images_base64_strs = []
    for match in matches:
        pil_image = Image.open(io.BytesIO(base64.b64decode(match["image"])))
        images.append(pil_image)
        # Re-encode as JPEG so the bytes match the "image/jpeg" media type
        # declared to Claude, whatever format the stored image uses.
        images_base64_strs.append(_to_jpeg_base64(pil_image))

    if not images:
        return [], "No similar images were found in the database."

    description = generate_image_description_with_claude(images_base64_strs, query_base64)
    return images, description


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("""
# MongoDB's Vector Celeb Image Matcher

Upload an image and find the most similar celeb image from the database, along with an AI-generated description.

💪 Make a great pose to impact the search! 🤯
""")
    gr.Interface(
        fn=start_image_search,
        inputs=[
            gr.Image(type="pil", label="Upload an image"),
            gr.Textbox(label="Enter an adjustment to the image"),
        ],
        outputs=[
            gr.Gallery(
                label="Located images for AI-generated descriptions",
                show_label=False,
                elem_id="gallery",
                columns=[3],
                rows=[1],
                object_fit="contain",
                height="auto",
            ),
            gr.Textbox(label="AI Based vision description"),
        ],
    )

if __name__ == "__main__":
    demo.launch()