import os
import pathlib
import pickle
import zipfile

import gradio as gr
from PIL import Image
from sentence_transformers import util
import torch
from tqdm.autonotebook import tqdm
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

cwd = pathlib.Path(__file__).parent.absolute()  # directory containing this script

img_folder = cwd / "photos"
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)

    photo_filename = "unsplash-25k-photos.zip"
    if not os.path.exists(photo_filename):  # Download the dataset if it does not exist
        util.http_get("http://sbert.net/datasets/" + photo_filename, photo_filename)

    # Extract all images
    with zipfile.ZipFile(photo_filename, "r") as zf:
        for member in tqdm(zf.infolist(), desc="Extracting"):
            zf.extract(member, img_folder)

# Load the pretrained CLIP model, processor, and tokenizer
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Download the precomputed image embeddings for the Unsplash photos if not present
emb_filename = "unsplash-25k-photos-embeddings.pkl"
emb_path = cwd / emb_filename
if not os.path.exists(emb_path):
    util.http_get("http://sbert.net/datasets/" + emb_filename, emb_path)
with open(emb_path, "rb") as fIn:
    img_names, img_emb = pickle.load(fIn)
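

# Optional helper (a sketch, not part of the original app): if the embeddings
# download above were unavailable, an (img_names, img_emb) pair with the same
# layout could in principle be rebuilt locally with the CLIP model loaded above.
# The batch size and RGB conversion are assumptions, not the original pipeline.
# Hypothetical usage: compute_image_embeddings(sorted(img_folder.glob("*.jpg")))
def compute_image_embeddings(image_paths, batch_size=32):
    """Encode images with CLIP and return (names, embeddings) in the pickle's layout."""
    names, embeddings = [], []
    for start in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[start:start + batch_size]
        batch_images = [Image.open(p).convert("RGB") for p in batch_paths]
        inputs = processor(images=batch_images, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            embeddings.append(model.get_image_features(**inputs))
        names.extend(os.path.basename(p) for p in batch_paths)
    return names, torch.cat(embeddings)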


def search_text(query):
    """Search the photo collection for images matching a text query.

    Args:
        query (str): The text query to search for.

    Returns:
        list: (PIL.Image, similarity score) tuples for the best-matching images.
    """
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    with torch.no_grad():
        query_emb = model.get_text_features(**inputs)

    hits = util.semantic_search(query_emb, img_emb, top_k=8)[0]
    # Build a list of (image, similarity score) tuples for the gallery
    images = []
    for hit in hits:
        try:
            image = Image.open(cwd / "photos" / img_names[hit["corpus_id"]])
            images.append((image, f"{hit['score'] * 100:.2f}%"))
        except Exception as e:
            print(e)
    return images


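# Gradio UI: a single text box feeding search_text, with results shown in a gallery.
# Note: Gallery.style() follows the Gradio 3.x API and was removed in later Gradio
# releases, so this app presumably targets an older, pinned Gradio version.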
gr.Interface(
    title="Text2Image search using CLIP model",
    description="This is a demo of the CLIP model, a multimodal model that embeds images and text in a shared vector space. It can be used for zero-shot image classification and image-text retrieval, and it was trained on roughly 400 million image-text pairs collected from the internet.",
    article="By Julien Assuied, Elie Brosset, Lucas Chapron and Alexis Japas",
    fn=search_text,
    theme="gstaff/sketch",
    allow_flagging="never",
    inputs=[
        gr.Textbox(
            lines=2,
            label="What do you want to see?",
            placeholder="Write the prompt here...",
        ),
    ],
    outputs=[
        gr.Gallery(
            label="Most similar images", show_label=True, elem_id="gallery"
        ).style(grid=[2], height="auto"),
    ],
    examples=[
        [("Two cats")],
        [("A plane flying")],
        [("A family picture")],
        [("un homme marchant sur le parc")],
    ],
).launch(debug=True)