File size: 5,790 Bytes
9bb4aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea9d736
 
 
 
 
 
9bb4aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea9d736
 
 
 
 
 
9bb4aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eef445c
9bb4aae
 
 
 
 
 
 
eef445c
9bb4aae
 
 
 
 
 
 
 
f2611db
 
9bb4aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from pathlib import Path
import pickle
import requests
from PIL import Image
from io import BytesIO
import pandas as pd
from loguru import logger
import torch

# Labels of the two app modes; main() offers them in the sidebar mode selector.
T2I = "Text 2 Image"
I2I = "Image 2 Image"
def get_match(model, query, img_embs):
    """Score one query against all image embeddings.

    Args:
        model: Encoder object; its ``encode`` method is called with a
            one-element batch and ``convert_to_tensor=True``.
        query: The item to embed (callers in this file pass either a text
            string or a PIL image).
        img_embs: Embeddings the query is compared against.

    Returns:
        Whatever ``util.pytorch_cos_sim`` yields for the encoded query and
        ``img_embs``.
    """
    single_item_batch = [query]
    encoded_query = model.encode(single_item_batch, convert_to_tensor=True)
    return util.pytorch_cos_sim(encoded_query, img_embs)
def _show_top_matches(cosine_sim, img_names, img_urls, n_top_k_images, timeout=10):
    """Rank the similarity scores and render the best images side by side.

    Shared by ``text_2_image`` and ``image_2_image`` so that both modes
    rank, download and display their results in exactly the same way.

    Args:
        cosine_sim: Similarity scores as returned by ``get_match``; treated
            as a 2-D tensor whose first row belongs to the query.
        img_names: Mapping from embedding index to photo id.
        img_urls: DataFrame providing the ``photo_id``, ``photo_url`` and
            ``photo_image_url`` columns.
        n_top_k_images: Number of results requested by the user.
        timeout: Seconds to wait for each thumbnail download before giving up.
    """
    # torch.topk raises when asked for more entries than there are scores.
    k = min(int(n_top_k_images), cosine_sim.shape[1])
    if k < 1:
        st.error("No image found")
        return

    # Row 0 holds the scores of the single query. Indexing it (rather than
    # squeezing) always produces a list, even when k == 1.
    best_indices = torch.topk(cosine_sim, k, dim=1).indices[0].tolist()
    columns = st.columns(k)

    for rank, (column, emb_index) in enumerate(zip(columns, best_indices), start=1):
        photo_id = img_names[emb_index]
        logger.success(f"Image match found: {photo_id}")
        matches = img_urls.loc[img_urls["photo_id"] == photo_id]
        logger.info(matches.photo_url)
        if matches.empty:
            with column:
                st.error("No image found")
            continue

        # "?w=320" is appended to the stored URL to request a small rendition.
        thumbnail_url = matches.iloc[0]["photo_image_url"] + "?w=320"
        try:
            # Without a timeout a stalled connection would block the page forever.
            response = requests.get(thumbnail_url, timeout=timeout)
            response.raise_for_status()
            thumbnail = Image.open(BytesIO(response.content))
            # Force decoding now so corrupt payloads are caught here.
            thumbnail.load()
        except (requests.RequestException, OSError) as err:
            logger.error(f"Could not fetch {thumbnail_url}: {err}")
            with column:
                st.error("Image could not be loaded")
            continue

        with column:
            st.image(thumbnail, caption=f"{rank}/{k} most similar")


def text_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    """Render the text-to-image page and show the best matches for a text query.

    Args:
        model: Encoder passed on to ``get_match``.
        img_emb: Image embeddings the query is compared against.
        img_names: Mapping from embedding index to photo id.
        img_urls: DataFrame used to look up the URL of each photo id.
        n_top_k_images: Number of most similar images to display.
    """
    st.title("Text to Image")
    st.write("This is the text to image mode. Enter a text to be converted to an image")
    text = st.text_input("Enter the text to be converted to an image")
    # The button is only drawn once a query has been typed.
    if text and st.button("Convert"):
        st.write("The image with the most similar embedding is:")
        cosine_sim = get_match(model, text, img_emb)
        _show_top_matches(cosine_sim, img_names, img_urls, n_top_k_images)


def image_2_image(model, img_emb, img_names, img_urls,n_top_k_images):
    """Render the image-to-image page and show the best matches for an upload.

    Args:
        model: Encoder passed on to ``get_match``.
        img_emb: Image embeddings the query is compared against.
        img_names: Mapping from embedding index to photo id.
        img_urls: DataFrame used to look up the URL of each photo id.
        n_top_k_images: Number of most similar images to display.
    """
    st.title("Image to Image")
    st.write("This is the image to image mode. Enter an image to be converted to an image")
    uploaded = st.file_uploader("Upload an image to be converted to an image", type=["jpg", "png", "jpeg"])
    if uploaded is None:
        return

    query_image = Image.open(BytesIO(uploaded.getvalue()))
    st.image(query_image, caption="Uploaded image")
    if st.button("Convert"):
        st.write("The image with the most similar embedding is:")
        # The upload is converted to RGB before it is handed to the encoder.
        cosine_sim = get_match(model, query_image.convert("RGB"), img_emb)
        _show_top_matches(cosine_sim, img_names, img_urls, n_top_k_images)

# NOTE(review): newer Streamlit releases deprecate st.cache in favour of
# st.cache_resource / st.cache_data — confirm the pinned version before migrating.
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model(name):
    """Build the SentenceTransformer called ``name``; cached through ``st.cache``.

    Args:
        name: Model identifier handed to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(name)

@st.cache(suppress_st_warning=True)
def load_embeddings(filename):
    """Read the pickled image names and embeddings; cached through ``st.cache``.

    Progress is reported in the sidebar before and after the file is read.

    Args:
        filename: Path of a pickle file holding an ``(img_names, img_emb)`` pair.

    Returns:
        The ``(img_names, img_emb)`` tuple stored in the file.
    """
    st.sidebar.info("Loading Unsplash-Lite image embeddings")
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, "rb") as pickle_file:
        stored_pair = pickle.load(pickle_file)
    img_names, img_emb = stored_pair
    st.sidebar.success("Images embeddings loaded")
    return img_names, img_emb

@st.cache(suppress_st_warning=True)
def load_image_url_list(filename):
    """Load the tab-separated photo table; cached through ``st.cache``.

    Args:
        filename: Path of a TSV file whose first row is the header.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(filename, sep="\t", header=0)

def main():
    """Entry point of the app: load resources, draw the sidebar, run the chosen mode."""
    st.title("CLIP Image Search")
    model = load_model("clip-ViT-B-32")
    st.write(
        "Select the mode to search for a match in Unsplash (thumbnail size) dataset. "
        "text2image mode needs a text as input and outputs the image with the most "
        "similar embedding (following cosine similarity). "
        "The Image to image mode is similar, but an input image is used instead of a text query"
    )
    embeddings_path = Path("unsplash-25k-photos-embeddings.pkl")
    img_urls = load_image_url_list("photos.tsv000")
    raw_names, img_emb = load_embeddings(embeddings_path)
    # Map every embedding index to its image ID (file name up to the first dot).
    img_names = dict(enumerate(raw_name.split('.')[0] for raw_name in raw_names))

    st.sidebar.title("Settings")
    # Each mode label maps to its sidebar banner text and the page function to run.
    modes = {
        T2I: ("Text to image mode", text_2_image),
        I2I: ("Image to image mode", image_2_image),
    }
    app_mode = st.sidebar.selectbox("Choose the app mode", list(modes))
    n_images_to_search = st.sidebar.number_input(
        "Select the number of images to search", min_value=1, max_value=6
    )

    selected = modes.get(app_mode)
    if selected is not None:
        banner, run_mode = selected
        st.sidebar.info(banner)
        run_mode(model, img_emb, img_names, img_urls, n_images_to_search)
# Start the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()