import os
import pathlib
import pickle
import zipfile

import gradio as gr
from PIL import Image
from sentence_transformers import util
from tqdm.autonotebook import tqdm
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
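
# Text-to-image search: embed a text query with CLIP and match it against
# precomputed CLIP embeddings of ~25k Unsplash photos, served through Gradio.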
img_folder = "photos/"
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)

    photo_filename = "unsplash-25k-photos.zip"
    if not os.path.exists(photo_filename):  # Download the dataset if it does not exist
        util.http_get("http://sbert.net/datasets/" + photo_filename, photo_filename)

    # Extract all images
    with zipfile.ZipFile(photo_filename, "r") as zf:
        for member in tqdm(zf.infolist(), desc="Extracting"):
            zf.extract(member, img_folder)
cwd = pathlib.Path(__file__).parent.absolute()
# Load the pretrained CLIP model, processor, and tokenizer
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
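# Note: the processor is loaded for completeness, but only the tokenizer is
# used below; queries are text-only and the image embeddings are precomputed.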
emb_filename = "unsplash-25k-photos-embeddings.pkl"
emb_path = cwd / emb_filename
if not os.path.exists(emb_path):  # check the same absolute path we download to
    util.http_get("http://sbert.net/datasets/" + emb_filename, emb_path)
with open(emb_path, "rb") as fIn:
    img_names, img_emb = pickle.load(fIn)
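# img_names holds the photo filenames and img_emb their precomputed CLIP
# image embeddings, so no image encoding is needed at query time.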


def search_text(query):
    """Search for images matching a text query.

    Args:
        query (str): the text query to search for.

    Returns:
        list: (image, score) tuples for the images most related to the query.
    """
    # Embed the query with CLIP's text encoder, then run a cosine-similarity
    # search against the precomputed image embeddings.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)
    hits = util.semantic_search(query_emb, img_emb, top_k=8)[0]
    images = [
        (
            Image.open(cwd / "photos" / img_names[hit["corpus_id"]]),
            f"{hit['score']*100:.2f}%",
        )
        for hit in hits
    ]
    return images
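

# Standalone usage sketch (outside the Gradio UI):
#   results = search_text("Two cats")  # -> list of 8 (PIL.Image, "xx.xx%") tuples
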
gr.Interface(
    title="Text2Image search using CLIP model 🤗 ⚡️ 📸",
    description="A demo of the CLIP model: it searches the Unsplash 25k dataset for images matching a text query.",
    article="By Julien Assuied, Elie Brosset, Lucas Chapron & Alexis Japas",
    fn=search_text,
    theme="gradio/soft",
    allow_flagging="never",
    inputs=[
        gr.Textbox(
            lines=2,
            label="What do you want to see?",
            placeholder="Write the prompt here...",
        ),
    ],
    outputs=[
        # .style() is Gradio 3.x syntax; on Gradio 4+, pass columns=2 and
        # height="auto" to the gr.Gallery constructor instead.
        gr.Gallery(
            label="8 most similar images", show_label=True, elem_id="gallery"
        ).style(grid=[2], height="auto"),
    ],
    examples=[
        ["Two cats"],
        ["A plane flying"],
        ["A fox"],
        ["A man walking in the park"],
    ],
).launch(debug=True)