import streamlit as st
from sentence_transformers import SentenceTransformer, util
from pathlib import Path
import pickle
import requests
from PIL import Image
from io import BytesIO
import pandas as pd
from loguru import logger
import torch

T2I = "Text 2 Image"
I2I = "Image 2 Image"

def get_match(model, query, img_embs):
    # Encode the query (text or PIL image; the CLIP model handles both) and
    # score it against every precomputed image embedding with cosine similarity
    query_emb = model.encode([query], convert_to_tensor=True)
    cosine_sim = util.pytorch_cos_sim(query_emb, img_embs)
    return cosine_sim
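# Example: get_match(model, "a dog on the beach", img_emb) returns a
# (1, n_images) tensor of similarities, one score per indexed image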


def text_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    st.title("Text to Image")
    st.write("This is the text-to-image mode. Enter a text query to retrieve the most similar images")
    text = st.text_input("Enter a text query")
    if text:
        if st.button("Search"):
            st.write("The images with the most similar embeddings are:")
            cosine_sim = get_match(model, text, img_emb)
            top_k_images_indices = torch.topk(cosine_sim, n_top_k_images, 1).indices.squeeze()
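            # squeeze() collapses the result to a 0-dim tensor when only one
            # image is requested, so the single index is re-wrapped in a list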
            if top_k_images_indices.nelement() == 1:
                top_k_images_indices = [top_k_images_indices.tolist()]
            else:
                top_k_images_indices = top_k_images_indices.tolist()
            images_found = [img_names[top_k_best_image] for top_k_best_image in top_k_images_indices]
            cols = st.columns(n_top_k_images)
            for i, image_found in enumerate(images_found):
                logger.success(f"Image match found: {image_found}")
                img_url_best_match = img_urls.loc[img_urls["photo_id"] == image_found]
                logger.info(img_url_best_match.photo_url)
                if len(img_url_best_match) >= 1:
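                    # "?w=320" asks the Unsplash CDN for a 320px-wide thumbnail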
                    response = requests.get(img_url_best_match.iloc[0]["photo_image_url"] + "?w=320")
                    image = Image.open(BytesIO(response.content))
                    with cols[i]:
                        st.image(image, caption=f"{i+1}/{n_top_k_images} most similar")
                else:
                    st.error("No image found")


def image_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    st.title("Image to Image")
    st.write("This is the image-to-image mode. Upload an image to retrieve the most similar images")
    image = st.file_uploader("Upload a query image", type=["jpg", "png", "jpeg"])
    if image is not None:
        image = Image.open(BytesIO(image.getvalue()))
        st.image(image, caption="Uploaded image")
        if st.button("Search"):
            st.write("The images with the most similar embeddings are:")
            cosine_sim = get_match(model, image.convert("RGB"), img_emb)
            top_k_images_indices = torch.topk(cosine_sim, n_top_k_images, 1).indices.squeeze()
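            # Same edge case as in text_2_image: squeeze() yields a 0-dim
            # tensor for a single requested image, so re-wrap it in a list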
            if top_k_images_indices.nelement() == 1:
                top_k_images_indices = [top_k_images_indices.tolist()]
            else:
                top_k_images_indices = top_k_images_indices.tolist()
            images_found = [img_names[top_k_best_image] for top_k_best_image in top_k_images_indices]
            cols = st.columns(n_top_k_images)
            for i, image_found in enumerate(images_found):
                logger.success(f"Image match found: {image_found}")
                img_url_best_match = img_urls.loc[img_urls["photo_id"] == image_found]
                logger.info(img_url_best_match.photo_url)
                if len(img_url_best_match) >= 1:
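                    # "?w=320" asks the Unsplash CDN for a 320px-wide thumbnail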
                    response = requests.get(img_url_best_match.iloc[0]["photo_image_url"] + "?w=320")
                    image = Image.open(BytesIO(response.content))
                    with cols[i]:
                        st.image(image, caption=f"{i+1}/{n_top_k_images} most similar")
                else:
                    st.error("No image found")


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(name):
    # st.sidebar.info("Loading model")
    model = SentenceTransformer(name)
    # st.sidebar.success(f"Model {name} loaded")
    return model


@st.cache(suppress_st_warning=True)
def load_embeddings(filename):
    st.sidebar.info("Loading Unsplash-Lite image embeddings")
    # The pickle holds a pair: (list of image filenames, tensor of embeddings)
    with open(filename, "rb") as fIn:
        img_names, img_emb = pickle.load(fIn)
    st.sidebar.success("Image embeddings loaded")
    return img_names, img_emb


@st.cache(suppress_st_warning=True)
def load_image_url_list(filename):
    # Tab-separated Unsplash photo metadata (includes photo_id, photo_url and photo_image_url)
    url_list = pd.read_csv(filename, sep='\t', header=0)
    return url_list


def main():
    st.title("CLIP Image Search")
    model = load_model("clip-ViT-B-32")
    st.write(
        "Select a mode to search for a match in the Unsplash (thumbnail size) dataset. "
        "The text-to-image mode takes a text query as input and returns the images whose "
        "embeddings are most similar to it (by cosine similarity). The image-to-image mode "
        "works the same way, but with an input image instead of a text query"
    )
    emb_filename = Path("unsplash-25k-photos-embeddings.pkl")
    urls_file = "photos.tsv000"
    img_urls = load_image_url_list(urls_file)
    img_names, img_emb = load_embeddings(emb_filename)
    # Convert the list of image filenames to a dict mapping each embedding index to its image ID
    img_names = {img_number: img_name.split('.')[0] for img_number, img_name in enumerate(img_names)}
    st.sidebar.title("Settings")
    app_mode = st.sidebar.selectbox("Choose the app mode", [T2I, I2I])
    n_images_to_search = st.sidebar.number_input("Select the number of images to search", min_value=1, max_value=6)
    if app_mode == T2I:
        st.sidebar.info("Text to image mode")
        text_2_image(model, img_emb, img_names, img_urls, n_images_to_search)
    elif app_mode == I2I:
        st.sidebar.info("Image to image mode")
        image_2_image(model, img_emb, img_names, img_urls, n_images_to_search)


if __name__ == "__main__":
    main()