import pickle
from io import BytesIO
from pathlib import Path

import pandas as pd
import requests
import streamlit as st
import torch
from loguru import logger
from PIL import Image
from sentence_transformers import SentenceTransformer, util

T2I = "Text 2 Image"
I2I = "Image 2 Image"


def get_match(model, query, img_embs):
    # Encode the query (a text string or a PIL image: CLIP maps both into the
    # same embedding space) and score it against every precomputed image embedding.
    query_emb = model.encode([query], convert_to_tensor=True)
    cosine_sim = util.pytorch_cos_sim(query_emb, img_embs)
    return cosine_sim


def show_top_k_matches(model, query, img_emb, img_names, img_urls, n_top_k_images):
    # Shared by both modes: rank all images against the query and display the top k.
    st.write("The images with the most similar embeddings are:")
    cosine_sim = get_match(model, query, img_emb)
    top_k_images_indices = torch.topk(cosine_sim, n_top_k_images, 1).indices.squeeze()
    # squeeze() yields a 0-d tensor when k == 1, so wrap the single index in a list
    if top_k_images_indices.nelement() == 1:
        top_k_images_indices = [top_k_images_indices.tolist()]
    else:
        top_k_images_indices = top_k_images_indices.tolist()
    images_found = [img_names[idx] for idx in top_k_images_indices]
    cols = st.columns(n_top_k_images)
    for i, image_found in enumerate(images_found):
        logger.success(f"Image match found: {image_found}")
        img_url_best_match = img_urls.loc[img_urls["photo_id"] == image_found]
        logger.info(img_url_best_match.photo_url)
        if len(img_url_best_match) >= 1:
            # Fetch a 320px-wide thumbnail rather than the full-size photo
            response = requests.get(img_url_best_match.iloc[0]["photo_image_url"] + "?w=320")
            image = Image.open(BytesIO(response.content))
            with cols[i]:
                st.image(image, caption=f"{i + 1}/{n_top_k_images} most similar")
        else:
            st.error("No image found")


def text_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    st.title("Text to Image")
    st.write("This is the text-to-image mode. Enter a text query to find the most similar images.")
    text = st.text_input("Enter a text query")
    if text and st.button("Convert"):
        show_top_k_matches(model, text, img_emb, img_names, img_urls, n_top_k_images)


def image_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    st.title("Image to Image")
    st.write("This is the image-to-image mode. Upload an image to find the most similar images.")
    image = st.file_uploader("Upload a query image", type=["jpg", "png", "jpeg"])
    if image is not None:
        image = Image.open(BytesIO(image.getvalue()))
        st.image(image, caption="Uploaded image")
        if st.button("Convert"):
            show_top_k_matches(model, image.convert("RGB"), img_emb, img_names, img_urls, n_top_k_images)


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(name):
    return SentenceTransformer(name)


@st.cache(suppress_st_warning=True)
def load_embeddings(filename):
    st.sidebar.info("Loading Unsplash-Lite image embeddings")
    with open(filename, "rb") as fIn:
        img_names, img_emb = pickle.load(fIn)
    st.sidebar.success("Image embeddings loaded")
    return img_names, img_emb


@st.cache(suppress_st_warning=True)
def load_image_url_list(filename):
    # photos.tsv000 (Unsplash Lite dataset): tab-separated, one row per photo,
    # with photo_id, photo_url and photo_image_url among its columns.
    return pd.read_csv(filename, sep="\t", header=0)


def main():
    st.title("CLIP Image Search")
    model = load_model("clip-ViT-B-32")
    st.write(
        "Select the mode to search for a match in the Unsplash Lite (thumbnail size) dataset. "
        "Text-to-image mode takes a text query and returns the images whose embeddings are most "
        "similar to it (by cosine similarity). Image-to-image mode works the same way, but uses "
        "an uploaded image as the query instead of a text."
    )
    emb_filename = Path("unsplash-25k-photos-embeddings.pkl")
    urls_file = "photos.tsv000"
    img_urls = load_image_url_list(urls_file)
    img_names, img_emb = load_embeddings(emb_filename)
    # Map each embedding index to its photo ID (the file name without extension)
    img_names = {img_number: img_name.split(".")[0] for img_number, img_name in enumerate(img_names)}
    st.sidebar.title("Settings")
    app_mode = st.sidebar.selectbox("Choose the app mode", [T2I, I2I])
    n_images_to_search = st.sidebar.number_input("Select the number of images to search", min_value=1, max_value=6)
    if app_mode == T2I:
        st.sidebar.info("Text to image mode")
        text_2_image(model, img_emb, img_names, img_urls, n_images_to_search)
    elif app_mode == I2I:
        st.sidebar.info("Image to image mode")
        image_2_image(model, img_emb, img_names, img_urls, n_images_to_search)


if __name__ == "__main__":
    main()
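

# ---------------------------------------------------------------------------
# Appendix: a minimal sketch of how an embeddings pickle compatible with
# load_embeddings() above could be produced, assuming the Unsplash photos have
# already been downloaded to a local folder. This is NOT part of the app and
# not how the official precomputed file was necessarily built; the function
# name build_embeddings and the "photos/" directory are hypothetical
# placeholders introduced here for illustration.
# ---------------------------------------------------------------------------
def build_embeddings(photo_dir="photos/", out_file="unsplash-25k-photos-embeddings.pkl"):
    model = SentenceTransformer("clip-ViT-B-32")
    img_paths = sorted(Path(photo_dir).glob("*.jpg"))
    # File names like "zVhYcSjd7-Q.jpg": main() strips the extension to recover photo_id
    img_names = [p.name for p in img_paths]
    # CLIP models in sentence-transformers encode PIL images directly; for the
    # full 25k-photo set you would likely want to encode in chunks to bound memory.
    img_emb = model.encode(
        [Image.open(p) for p in img_paths],
        batch_size=128,
        convert_to_tensor=True,
        show_progress_bar=True,
    )
    # Persist the (names, embeddings) pair in the exact shape load_embeddings() expects
    with open(out_file, "wb") as fOut:
        pickle.dump((img_names, img_emb), fOut)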