Spaces:
Runtime error
Runtime error
import asyncio
import base64
import io
import json
import os
import re
import time
import traceback
from time import sleep

import boto3
import gradio as gr
import requests
from bson import ObjectId
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI
from PIL import Image
from pymongo import MongoClient
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas

# Final stage of the movie RAG chain: turns the chat model's reply into a plain str.
output_parser = StrOutputParser()
# OpenAI client used for embeddings and chat completions; constructed without
# arguments, so the SDK takes its API key from the environment.
openai_client = OpenAI()
def fetch_url_data(url, timeout=10):
    """Download *url* and return its body as text.

    Used to show each demo's source code in a "Code" tab. Failures are not
    raised; an ``"Error: ..."`` string is returned instead, so the UI still
    builds when the URL is unreachable.

    Args:
        url: Address to fetch with an HTTP GET.
        timeout: Seconds to wait for connect/read before giving up. Without a
            timeout ``requests`` may block indefinitely. This function is
            called while the Gradio layout is being built, so a stalled host
            would otherwise hang the whole app at start-up.

    Returns:
        The response text, or ``"Error: <exception>"`` on any request failure
        (including non-2xx status codes).
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Turn 4xx/5xx responses into an HTTPError (a RequestException).
        response.raise_for_status()
    except requests.RequestException as e:
        return f"Error: {e}"
    return response.text
# --- Shared configuration ---------------------------------------------------
# Atlas connection string, read from the environment.
uri = os.environ.get('MONGODB_ATLAS_URI')

# Sample address (functions that need an email take it as a parameter) and
# the regex used to sanity-check participant emails.
email = "example@example.com"
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"

# AWS Bedrock client setup; credentials come from AWS_ACCESS_KEY / AWS_SECRET_KEY.
bedrock_runtime = boto3.client(
    'bedrock-runtime',
    region_name="us-east-1",
    aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
    aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'),
)

# Movie collection backing the Chat RAG demo.
db_name, collection_name = 'sample_mflix', 'embedded_movies'
chatClient = MongoClient(uri)
collection = chatClient[db_name][collection_name]
## Chat RAG Functions
# Build the movie RAG pipeline once at import time. If construction fails
# (e.g. the OpenAI key is missing or wrong), the app still starts:
# vector_store/chain are set to None and get_movies() falls back to its
# setup-instructions message.
try:
    vector_store = MongoDBAtlasVectorSearch(
        embedding=OpenAIEmbeddings(),
        collection=collection,
        index_name='vector_index',
        text_key='plot',
        embedding_key='plot_embedding',
    )
    llm = ChatOpenAI(temperature=0)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a movie recommendation engine which post a concise and short summary on relevant movies."),
        ("user", "List of movies: {input}"),
    ])
    chain = prompt | llm | output_parser
except Exception:
    # If open ai key is wrong. Capture the traceback so the real cause shows
    # up in the Space logs.
    error_message = traceback.format_exc()
    print('Open AI key is wrong')
    vector_store = None
    chain = None
    print("An error occurred: \n" + error_message)
def get_movies(message, history):
    """Stream an LLM summary of the movies whose plots best match *message*.

    This is the Gradio ``ChatInterface`` handler. It runs an Atlas Vector
    Search (k=3) over movie plots and asks the chat chain to summarise the
    hits. The summary is then yielded one extra character at a time, with a
    short sleep between steps, to give a typing effect.

    Args:
        message: The user's chat message, used as the similarity query.
        history: Chat history supplied by Gradio (not used).

    Yields:
        Progressively longer strings made of the ``"Found: "`` prefix, a
        blank line, and a growing prefix of the summary. If anything fails,
        a single message is yielded instead, telling the user which secrets
        to configure.
    """
    try:
        movies = vector_store.similarity_search(query=message, k=3, embedding_key='plot_embedding')
        return_text = ''.join(
            'Title : ' + movie.metadata['title'] + '\n------------\n' + 'Plot: ' + movie.page_content + '\n\n'
            for movie in movies
        )
        print_llm_text = chain.invoke({"input": return_text})
        for end in range(1, len(print_llm_text) + 1):
            time.sleep(0.05)
            yield "Found: " + "\n\n" + print_llm_text[:end]
    except Exception:
        error_message = traceback.format_exc()
        print("An error occurred: \n" + error_message)
        # The secret name must match the variable read at start-up (MONGODB_ATLAS_URI).
        yield "Please clone the repo and add your open ai key as well as your MongoDB Atlas URI in the Secret Section of you Space\n OPENAI_API_KEY (your Open AI key) and MONGODB_ATLAS_URI (0.0.0.0/0 whitelisted instance with Vector index created) \n\n For more information : https://mongodb.com/products/platform/atlas-vector-search"
## Restaurant Advisor RAG Functions
def _restaurant_chart_iframe(restaurant_filter):
    """Return the embedded MongoDB Charts map HTML, filtered on ``restaurant_id``.

    Args:
        restaurant_filter: Charts filter expression for ``restaurant_id``,
            e.g. ``"''"`` (used when there is nothing to highlight) or
            ``"{$in:['1', '2']}"``.
    """
    src = (
        "https://charts.mongodb.com/charts-paveldev-wiumf/embed/charts"
        "?id=65c24b0c-2215-4e6f-829c-f484dfd8a90c"
        "&filter={'restaurant_id':" + restaurant_filter + "}"
        "&maxDataAge=3600&theme=light&autoRefresh=true"
    )
    return (
        '<iframe style="background: #FFFFFF;border: none;border-radius: 2px;'
        'box-shadow: 0 2px 10px 0 rgba(70, 76, 79, .2);" width="640" height="480" '
        'src="' + src + '"></iframe>'
    )


def _cleanup_trip(client, trips_collection, trip_id):
    """Best-effort removal of a trip's temporary documents, then close *client*."""
    try:
        if trips_collection is not None and trip_id is not None:
            trips_collection.delete_many({"searchTrip": trip_id})
    except Exception as e:
        # Cleanup must never mask the result already being returned.
        print(e)
    finally:
        if client is not None:
            client.close()


def get_restaurants(search, location, meters):
    """Recommend restaurants near *location* that match the free-text *search*.

    Steps:
      1. Copy the restaurants within *meters* of *location* into
         ``smart_trips``, tagged with a fresh trip id (``pre_aggregate_meters``).
      2. Embed *search* with OpenAI ``text-embedding-3-small`` (256 dims).
      3. Run ``$vectorSearch`` restricted to that trip (top 3).
      4. Ask ``gpt-3.5-turbo`` to compare the hits.

    The trip's temporary documents are deleted and the client is closed on
    every path, including errors.

    Args:
        search: What the user is looking for, e.g. a cuisine.
        location: GeoJSON point used as the ``$geoNear`` origin.
        meters: Search radius in meters.

    Returns:
        A 4-tuple ``(answer, map_iframe_html, pre_aggregate_pipeline_text,
        vector_query_text)``. On failure, the answer is a generic error
        message and the map highlights nothing. Any pipeline text not yet
        built at the point of failure is ``""``.
    """
    empty_map = _restaurant_chart_iframe("''")
    client = trips_collection = newTrip = None
    pre_agg = vectorQuery = None
    try:
        client = MongoClient(uri)
        db = client['whatscooking']
        restaurants_collection = db['restaurants']
        trips_collection = db['smart_trips']

        # Pre aggregate restaurants collection based on chosen location and
        # radius; the output is stored into trips_collection.
        newTrip, pre_agg = pre_aggregate_meters(restaurants_collection, location, meters)

        ## Get openai embeddings
        response = openai_client.embeddings.create(
            input=search,
            model="text-embedding-3-small",
            dimensions=256,
        )
        ## prepare the similarity search on current trip
        vectorQuery = {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": response.data[0].embedding,
                "path": "embedding",
                "numCandidates": 10,
                "limit": 3,
                "filter": {"searchTrip": newTrip},
            }
        }
        restaurant_docs = list(trips_collection.aggregate([
            vectorQuery,
            {"$project": {"_id": 0, "embedding": 0}},
        ]))
        if not restaurant_docs:
            # Nothing to recommend: skip the LLM call entirely.
            return "No restaurants found", empty_map, str(pre_agg), str(vectorQuery)

        ## Run the retrieved documents through a RAG system.
        chat_response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful restaurant assistant. You will get a context if the context is not relevat to the user query please address that and not provide by default the restaurants as is."},
                {"role": "user", "content": f"Find me the 2 best restaurant and why based on {search} and {restaurant_docs}. explain trades offs and why I should go to each one. You can mention the third option as a possible alternative."},
            ],
        )
        ## Build the map filter from however many restaurants came back (1-3).
        restaurant_string = ", ".join(f"'{doc['restaurant_id']}'" for doc in restaurant_docs)
        iframe = _restaurant_chart_iframe("{$in:[" + restaurant_string + "]}")
        return chat_response.choices[0].message.content, iframe, str(pre_agg), str(vectorQuery)
    except Exception as e:
        print(e)
        return (
            "Your query caused an error, please retry with allowed input only ...",
            empty_map,
            "" if pre_agg is None else str(pre_agg),
            "" if vectorQuery is None else str(vectorQuery),
        )
    finally:
        ## Remove the temporary documents and release the connection.
        _cleanup_trip(client, trips_collection, newTrip)
def pre_aggregate_meters(restaurants_collection, location, meters, index_wait_seconds=3):
    """Copy the restaurants within *meters* of *location* into ``smart_trips``.

    Each copied document gets a fresh ``searchTrip`` ObjectId, plus a ``date``
    taken from that id's generation time. This lets a later ``$vectorSearch``
    be filtered to this trip only.

    Args:
        restaurants_collection: Source collection for ``$geoNear``
            (assumes it has the geospatial index ``$geoNear`` needs).
        location: GeoJSON point used as the ``$geoNear`` origin.
        meters: Maximum distance from *location*.
        index_wait_seconds: Seconds to pause after the merge before returning.

    Returns:
        ``(trip_id, pipeline)``: the ObjectId tagging the copied documents
        and the aggregation pipeline that was run (returned for display).
    """
    ## Do the geo location preaggregate and assign the search trip id.
    tripId = ObjectId()
    pre_aggregate_pipeline = [
        {
            "$geoNear": {
                "near": location,
                "distanceField": "distance",
                "maxDistance": meters,
                "spherical": True,
            },
        },
        {
            "$addFields": {
                "searchTrip": tripId,
                "date": tripId.generation_time,
            }
        },
        {
            "$merge": {
                "into": "smart_trips",
            }
        },
    ]
    # $merge writes the results server-side; the returned cursor is not used.
    restaurants_collection.aggregate(pre_aggregate_pipeline)
    # NOTE(review): presumably gives the Atlas vector index time to pick up the
    # freshly merged documents before they are queried — confirm.
    sleep(index_wait_seconds)
    return tripId, pre_aggregate_pipeline
## Celeb Matcher RAG Functions
def construct_bedrock_body(base64_string, text):
    """Serialize a Titan multimodal embedding request.

    Args:
        base64_string: Base64-encoded image.
        text: Optional text embedded together with the image; left out of
            the payload when falsy.

    Returns:
        The JSON request body as a string, requesting a 1024-length embedding.
    """
    payload = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": 1024},
    }
    if text:
        payload["inputText"] = text
    return json.dumps(payload)
# Function to get the embedding from Bedrock model
def get_embedding_from_titan_multimodal(body):
    """Invoke Amazon Titan Multimodal Embeddings on Bedrock and return the vector.

    Args:
        body: JSON request string, e.g. built by ``construct_bedrock_body``.

    Returns:
        The ``embedding`` field of the model's JSON response.
    """
    request = dict(
        body=body,
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json",
    )
    raw_payload = bedrock_runtime.invoke_model(**request).get("body").read()
    return json.loads(raw_payload)["embedding"]
# MongoDB setup
# The celeb matcher uses its own database, reached through the same
# MONGODB_ATLAS_URI connection string.
uri = os.environ.get('MONGODB_ATLAS_URI')
client = MongoClient(uri)
db_name, collection_name = 'celebrity_1000_embeddings', 'celeb_images'
celeb_images, participants_db = (
    client[db_name][collection_name],
    client[db_name]['participants'],
)
# Function to record participant details
def record_participant(email, company, description, images):
    """Validate the participant form, store it, and build the results PDF.

    Args:
        email: Participant email; must be non-empty and fully match
            ``email_pattern`` (surrounding whitespace is ignored).
        company: Participant company; must be non-empty.
        description: AI-generated comparison text to include in the PDF.
        images: Gallery value holding the matched images.

    Returns:
        Path of the generated PDF (from ``create_pdf``).

    Raises:
        gr.Error: If email/company are missing, the email is malformed, or
            no search has been run yet.
    """
    email = (email or "").strip()
    company = (company or "").strip()
    # Presence check first, then format check: each case gets its own message.
    if not email or not company:
        raise gr.Error("Please enter your email and company name to record the participant details.")
    ## regex to validate email; fullmatch so a trailing newline cannot slip past `$`
    if not re.fullmatch(email_pattern, email):
        raise gr.Error("Please enter a valid email address")
    if not images:
        raise gr.Error("Please search for an image first before recording the participant.")
    participants_db.insert_one({'email': email, 'company': company})
    # Create PDF after recording participant
    return create_pdf(images, description, email, company)
def create_pdf(images, description, email, company):
    """Write a PDF with the Claude summary and matched images, and return its path.

    Layout:
      - a greeting line and a heading;
      - the description, wrapped at 10 words per line;
      - the images stacked below the text at 200x200 pt each.

    A new page is started whenever the next text line or image would not fit
    above the bottom margin.

    Args:
        images: Gallery entries. The first element of each entry is opened
            with PIL (presumably a file path — confirm for the Gradio version).
        description: Text to print; ``None`` is treated as empty.
        email: Shown in the greeting and, sanitized, used in the filename.
        company: Not used in the PDF.

    Returns:
        Filename of the written PDF, relative to the working directory.
    """
    # Keep only filename-safe characters so user input cannot inject path
    # separators into the output path.
    safe_email = re.sub(r"[^A-Za-z0-9._@+-]", "_", email or "")
    filename = f"image_search_results_{safe_email}.pdf"
    c = canvas.Canvas(filename, pagesize=letter)
    _, height = letter
    left = 50
    top = height - 50
    bottom = 50
    line_height = 15
    image_size = 200
    image_gap = 10

    c.drawString(left, height - 30, f"Thanks for participating, {email}! Here are your celeb match results:")
    c.drawString(left, height - 70, "Claude 3 summary of the MongoDB celeb comparison:")

    # Group the description into lines of at most 10 words.
    words = (description or "").split()
    lines = [" ".join(words[i:i + 10]) for i in range(0, len(words), 10)]

    y = height - 90
    for line in lines:
        if y < bottom:
            c.showPage()
            y = top
        c.drawString(left, y, line)
        y -= line_height

    # A single cursor `y` drives both text and images, so images start below
    # the text and are correctly positioned after each page break.
    for image in images or []:
        if y - image_size < bottom:
            c.showPage()
            y = top
        y -= image_size
        buffered = io.BytesIO()
        # JPEG cannot store alpha/palette modes, so normalize to RGB first.
        with Image.open(image[0]) as pil_image:
            pil_image.convert("RGB").save(buffered, format='JPEG')
        buffered.seek(0)
        c.drawImage(ImageReader(buffered), left, y, width=image_size, height=image_size)
        y -= image_gap

    c.save()
    return filename
# Function to generate image description using Claude 3 Sonnet
def generate_image_description_with_claude(images_base64_strs, image_base64, max_matches=3):
    """Ask Claude 3 Sonnet (via Bedrock) how the user's photo compares to the matches.

    Args:
        images_base64_strs: Base64 JPEGs of the matched celeb images; only the
            first *max_matches* are sent.
        image_base64: Base64 JPEG of the user's uploaded image (sent first).
        max_matches: Upper bound on how many matches are included.

    Returns:
        The text of Claude's reply, or ``"No description available"`` when
        there are no matches to compare or the reply contains no text.
    """
    matches = list(images_base64_strs or [])[:max_matches]
    if not matches:
        # Nothing to compare against; skip the model call.
        return "No description available"

    def jpeg_block(data):
        return {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": data}}

    # The user's image first, then every match, then the instruction text.
    content = [jpeg_block(image_base64)]
    content.extend(jpeg_block(match) for match in matches)
    content.append({
        "type": "text",
        "text": f"Please let the user know how his first image is similar to the other {len(matches)} and which one is the most similar?",
    })
    claude_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": "Please act as face comperison analyzer.",
        "messages": [{"role": "user", "content": content}],
    })
    claude_response = bedrock_runtime.invoke_model(
        body=claude_body,
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(claude_response.get("body").read())
    # Join every text block of the reply; fall back when there is none.
    texts = [part["text"] for part in response_body.get("content") or [] if part.get("text")]
    return "\n".join(texts) or "No description available"
# Main function to start image search
def _to_base64_jpeg(pil_image, **save_options):
    """Encode *pil_image* as JPEG (converted to RGB if needed) and return base64 text."""
    buffer = io.BytesIO()
    # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads).
    rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
    rgb_image.save(buffer, format="JPEG", **save_options)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def start_image_search(image, text):
    """Find the celeb images most similar to *image* and describe the match.

    Args:
        image: Uploaded PIL image (from ``gr.Image(type="pil")``), or ``None``.
        text: Optional adjustment text embedded together with the image.

    Returns:
        ``(images, description)``: the matched PIL images (up to 3, per the
        ``$vectorSearch`` limit) and Claude's comparison text.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")
    # Normalize the query image to 800x600 before embedding.
    img_base64_str = _to_base64_jpeg(image.resize((800, 600)), quality=85)

    body = construct_bedrock_body(img_base64_str, text)
    embedding = get_embedding_from_titan_multimodal(body)
    doc = list(celeb_images.aggregate([
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": 15,
                "limit": 3,
            }
        },
        {"$project": {"image": 1}},
    ]))

    images = []
    images_base64_strs = []
    for image_doc in doc:
        pil_image = Image.open(io.BytesIO(base64.b64decode(image_doc['image'])))
        # Re-encode as JPEG so every image sent to Claude matches "image/jpeg".
        images_base64_strs.append(_to_base64_jpeg(pil_image))
        images.append(pil_image)

    description = generate_image_description_with_claude(images_base64_strs, img_base64_str)
    return images, description
# Static text shown in the demos.
_CELEB_INTRO_MD = (
    "\n"
    "# MongoDB's Vector Celeb Image Matcher\n"
    "Upload an image and find the most similar celeb image from the database, along with an AI-generated description.\n"
    "💪 Make a great pose to impact the search! 🤯\n"
)
_RESTAURANT_INTRO_MD = (
    "\n"
    "# MongoDB's Vector Restaurant planner\n"
    "Start typing below to see the results. You can search a specific cuisine for you and choose 3 predefined locations.\n"
    "The radius specify the distance from the start search location. This space uses the dataset called [whatscooking.restaurants](https://huggingface.co/datasets/AIatMongoDB/whatscooking.restaurants)\n"
)
_MOVIE_CHAT_DESCRIPTION = "This small chat uses a similarity search to find relevant movies, it uses MongoDB Atlas Vector Search read more here: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-tutorial"

# (label, GeoJSON point) pairs offered as search origins in the restaurant demo.
_RESTAURANT_LOCATIONS = [
    ("Timesquare Manhattan", {"type": "Point", "coordinates": [-73.98527039999999, 40.7589099]}),
    ("Westside Manhattan", {"type": "Point", "coordinates": [-74.013686, 40.701975]}),
    ("Downtown Manhattan", {"type": "Point", "coordinates": [-74.000468, 40.720777]}),
]


def _code_tab(source_url):
    """Add a "Code" tab showing the Python source fetched from *source_url*."""
    with gr.Tab("Code"):
        gr.Code(label="Code", language="python", value=fetch_url_data(source_url))


with gr.Blocks() as demo:
    with gr.Tab("Celeb Matcher Demo"):
        with gr.Tab("Demo"):
            gr.Markdown(_CELEB_INTRO_MD)
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload an image")
                    text_input = gr.Textbox(label="Enter an adjustment to the image")
                    search_button = gr.Button("Search")
                with gr.Column():
                    output_gallery = gr.Gallery(
                        label="Located images",
                        show_label=False,
                        elem_id="gallery",
                        columns=[3],
                        rows=[1],
                        object_fit="contain",
                        height="auto",
                    )
                    output_description = gr.Textbox(label="AI Based vision description")
            gr.Markdown("\n")
            with gr.Row():
                email_input = gr.Textbox(label="Enter your email")
                company_input = gr.Textbox(label="Enter your company name")
                record_button = gr.Button("Record & Download PDF")
            search_button.click(
                fn=start_image_search,
                inputs=[image_input, text_input],
                outputs=[output_gallery, output_description],
            )
            pdf_download = gr.File(label="Download Search Results as PDF")
            record_button.click(
                fn=record_participant,
                inputs=[email_input, company_input, output_description, output_gallery],
                outputs=pdf_download,
            )
        _code_tab('https://huggingface.co/spaces/MongoDB/aws-bedrock-celeb-matcher/raw/main/app.py')

    with gr.Tab("Chat RAG Demo"):
        with gr.Tab("Demo"):
            movie_chat = gr.ChatInterface(
                get_movies,
                examples=["What movies are scary?", "Find me a comedy", "Movies for kids"],
                title="Movies Atlas Vector Search",
                description=_MOVIE_CHAT_DESCRIPTION,
                submit_btn="Search",
            )
            movie_chat.queue()
        _code_tab('https://huggingface.co/spaces/MongoDB/MongoDB-Movie-Search/raw/main/app.py')

    with gr.Tab("Restaruant advisor RAG Demo"):
        with gr.Tab("Demo"):
            gr.Markdown(_RESTAURANT_INTRO_MD)
            # Create the interface
            gr.Interface(
                get_restaurants,
                [
                    gr.Textbox(placeholder="What type of dinner are you looking for?"),
                    gr.Radio(choices=_RESTAURANT_LOCATIONS, label="Location", info="What location you need?"),
                    gr.Slider(minimum=500, maximum=10000, randomize=False, step=5, label="Radius in meters"),
                ],
                [
                    gr.Textbox(label="MongoDB Vector Recommendations", placeholder="Results will be displayed here"),
                    "html",
                    gr.Code(label="Pre-aggregate pipeline", language="json"),
                    gr.Code(label="Vector Query", language="json"),
                ],
            )
        _code_tab('https://huggingface.co/spaces/MongoDB/whatscooking-advisor/raw/main/app.py')

if __name__ == "__main__":
    demo.launch()