Spaces:
Running
Running
import gradio as gr | |
from PIL import Image | |
from datasets import load_dataset, Dataset | |
import random | |
import numpy as np | |
import time | |
# Dataset
def _index_rows_by_id(dataset):
    """Map every row's ``id`` value to the row itself (a later duplicate id wins)."""
    lookup = {}
    for record in dataset:
        lookup[record['id']] = record
    return lookup


ds = load_dataset("tonyassi/finesse1-embeddings", split='train')
# id -> full row; the click handler uses it to fetch a product's embedding and image.
id_to_row = _index_rows_by_id(ds)

# Mutable module-level state read and rebound by the Gradio callbacks below.
# Being plain globals, it is shared by every client served by this process.
remaining_ds = None  # products not yet shown; set on page load
preference_embedding = []  # preference vector; empty until the first click
################################################################################### | |
def get_random_images(dataset, num):
    """Draw ``num`` random rows from ``dataset`` without replacement.

    Args:
        dataset: A ``datasets.Dataset`` (anything supporting ``len`` and
            ``select(indices)`` works).
        num: Number of rows to draw. If it exceeds ``len(dataset)``, every
            row is drawn instead of raising, so the app keeps working once
            the pool of unseen products runs low.

    Returns:
        ``(random_images, new_dataset)``: the drawn rows, and the rows that
        were not drawn, in their original order.
    """
    size = len(dataset)
    # random.sample raises ValueError when asked for more items than exist.
    num = min(num, size)
    random_indices = random.sample(range(size), num)
    random_images = dataset.select(random_indices)
    # Set membership keeps building the complement O(n) instead of O(n * num).
    picked = set(random_indices)
    remaining_indices = [i for i in range(size) if i not in picked]
    new_dataset = dataset.select(remaining_indices)
    return random_images, new_dataset
def find_similar_images(dataset, num, embedding):
    """Return the ``num`` rows of ``dataset`` nearest to ``embedding`` and the rest.

    A temporary FAISS index is built over the ``embeddings`` column for this
    single query and is always dropped afterwards, even if the search raises,
    so the caller's dataset is not left with an attached index.

    Args:
        dataset: A ``datasets.Dataset`` with ``id`` and ``embeddings`` columns.
        num: How many nearest rows to retrieve (``k`` of the search).
        embedding: Query vector (list or array), same length as the stored
            embeddings.

    Returns:
        ``(retrieved_examples, new_dataset)``: ``retrieved_examples`` is a dict
        of column name -> list of values for the retrieved rows, and
        ``new_dataset`` is ``dataset`` without those rows (matched by ``id``).
        An empty ``dataset`` yields empty results instead of an error.
    """
    if len(dataset) == 0:
        # Nothing to index or return; building an index over zero rows fails.
        return dataset[:0], dataset

    # The query may be float64 (e.g. after average_embedding); FAISS works
    # in float32, so convert explicitly.
    query = np.asarray(embedding, dtype=np.float32)

    dataset.add_faiss_index(column='embeddings')
    try:
        _, retrieved_examples = dataset.get_nearest_examples('embeddings', query, k=num)
    finally:
        # Always detach the index; the next call builds a fresh one over
        # whatever rows remain by then.
        dataset.drop_index('embeddings')

    retrieved_ids = set(retrieved_examples['id'])
    remaining_indices = [
        i for i, row_id in enumerate(dataset['id']) if row_id not in retrieved_ids
    ]
    new_dataset = dataset.select(remaining_indices)
    return retrieved_examples, new_dataset
def average_embedding(embedding1, embedding2):
    """Return the element-wise mean of two equal-length embeddings.

    Inputs may be lists or NumPy arrays; the result is always a NumPy array.
    When applied repeatedly as ``pref = average_embedding(pref, new)``, each
    call halves the weight of everything folded in earlier.
    """
    total = np.add(np.array(embedding1), np.array(embedding2))
    return total / 2
################################################################################### | |
def load_images():
    """Start a fresh browsing round and return the opening product gallery.

    Resets the module-level ``remaining_ds`` (to the full dataset) and
    ``preference_embedding`` (to empty), then draws 10 random products and
    removes them from ``remaining_ds``.

    Returns:
        A list of ``(image, caption)`` tuples; each caption is the product's
        ``id`` rendered as a string, which the click handler parses back
        with ``int``.
    """
    global remaining_ds, preference_embedding
    print("ds", ds.num_rows)
    remaining_ds = ds
    preference_embedding = []

    shown, remaining_ds = get_random_images(ds, 10)
    images = shown['image']
    captions = [str(product_id) for product_id in shown['id']]
    return list(zip(images, captions))
def select_image(evt: gr.SelectData, gallery, preference_gallery):
    """Handle a click on a product and refresh both galleries.

    The clicked product's embedding is folded into the module-level
    ``preference_embedding``. The product gallery is then refilled with the 5
    products nearest to that embedding plus up to 5 random ones. All of them
    are removed from ``remaining_ds`` so they are not shown again.

    Args:
        evt: Gradio select event; ``evt.value['caption']`` holds the clicked
            product's id as a string (set by the gallery-building code).
        gallery: Current product-gallery value. Unused, but kept because the
            event wiring passes it.
        preference_gallery: Current preference-gallery value, or ``None``
            while it is still empty.

    Returns:
        ``(product_gallery_update, preference_gallery_value)``.

    NOTE(review): ``remaining_ds`` and ``preference_embedding`` are module
    globals, so every client connected to this process shares them.
    """
    global remaining_ds, preference_embedding
    print("remaining_ds", remaining_ds.num_rows)

    # Look up the clicked product by the id stored in its caption.
    selected_id = int(evt.value['caption'])
    selected_row = id_to_row[selected_id]
    selected_embedding = selected_row['embeddings']
    selected_image = selected_row['image']

    # Running pairwise average: the newest pick always gets weight 1/2, so
    # earlier picks decay geometrically rather than counting equally.
    if len(preference_embedding) == 0:
        preference_embedding = selected_embedding
    else:
        preference_embedding = average_embedding(preference_embedding, selected_embedding)

    # 5 nearest neighbours of the preference embedding...
    similar_images, remaining_ds = find_similar_images(remaining_ds, 5, preference_embedding)
    similar_result = [
        (image, str(product_id))
        for image, product_id in zip(similar_images['image'], similar_images['id'])
    ]

    # ...topped up with random products for exploration. Never ask for more
    # than are left, or random.sample raises once the pool is nearly exhausted.
    rand_imgs, remaining_ds = get_random_images(remaining_ds, min(5, remaining_ds.num_rows))
    random_result = [
        (image, str(product_id))
        for image, product_id in zip(rand_imgs['image'], rand_imgs['id'])
    ]
    final_result = similar_result + random_result

    # Newest pick goes first in the preference gallery.
    if preference_gallery is None:
        final_preference_gallery = [selected_image]
    else:
        final_preference_gallery = [selected_image] + preference_gallery

    # selected_index=None clears the highlight so the next click registers cleanly.
    return gr.Gallery(value=final_result, selected_index=None), final_preference_gallery
################################################################################### | |
with gr.Blocks() as demo:
    gr.Markdown("""
<center><h1> Product Recommendation using Image Similarity </h1></center>
<center>by <a href="https://www.tonyassi.com/" target="_blank">Tony Assi</a></center><br>
<center><i> This is a demo of product recommendation using image similarity of user preferences. </i><a href="https://huggingface.co/blog/tonyassi/product-recommendation-using-image-similarity/" target="_blank">Read the article.</a></center> <br>
The user selects their favorite product, which is then added to the user preference group. The embeddings of the products in the preference group are combined into a preference embedding using a running average, in which each new pick counts for half. Each round, 10 products are displayed: the 5 products most similar to the preference embedding and 5 random products. Embeddings are generated with [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224). The dataset used is [tonyassi/finesse1-embeddings](https://huggingface.co/datasets/tonyassi/finesse1-embeddings).
""")
    # Clickable gallery of candidate products; filled on page load and refreshed on every click.
    product_gallery = gr.Gallery(columns=5, object_fit='contain', allow_preview=False, label='Products')
    # Non-interactive gallery of the products the user has picked so far.
    preference_gallery = gr.Gallery(columns=5, object_fit='contain', allow_preview=False, label='Preference', interactive=False)
    # Populate the product gallery with random products when the page loads.
    demo.load(load_images, inputs=None, outputs=[product_gallery])
    # A click updates the preference embedding and refreshes both galleries.
    product_gallery.select(select_image, inputs=[product_gallery, preference_gallery], outputs=[product_gallery, preference_gallery])
demo.launch()