# CLIP / app.py
import pickle
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
from sentence_transformers import util
from transformers import CLIPModel, CLIPTokenizer
def is_valid_image(content):
    """Return True if the raw bytes can be parsed as an image."""
    try:
        Image.open(BytesIO(content)).verify()
        return True
    except OSError:
        return False
# Load the CLIP model and tokenizer used to embed text queries
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# Load the precomputed image embeddings: a list of photo filenames and the
# matching CLIP image-embedding matrix
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
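# For reference, an embeddings file like this can be precomputed offline.
# The sketch below is an assumption, not the exact script that produced the
# pickle above; it assumes a local `photos/` directory of Unsplash images and
# uses the sentence-transformers CLIP checkpoint ('clip-ViT-B-32'), whose
# image embeddings live in the same space as the text encoder used here.
#
#   import os
#   from sentence_transformers import SentenceTransformer
#   from PIL import Image
#
#   st_model = SentenceTransformer('clip-ViT-B-32')
#   names = sorted(os.listdir('photos'))
#   embs = st_model.encode(
#       [Image.open(os.path.join('photos', n)) for n in names],
#       batch_size=32, convert_to_tensor=True, show_progress_bar=True,
#   )
#   with open('unsplash-25k-photos-embeddings.pkl', 'wb') as fOut:
#       pickle.dump((names, embs), fOut)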
def download_image(ids):
    """Download a single Unsplash photo by its ID and return a PIL image."""
    photo_id = ids.split(".")[0]
    url = f"https://unsplash.com/photos/{photo_id}/download?w=320"
    response = requests.get(url)
    if response.status_code == 200 and is_valid_image(response.content):
        # Open the image directly from the response content
        return Image.open(BytesIO(response.content))
    # The request failed or the payload was not a valid image
    return None
def search_text(query, top_k):
    """Return up to top_k images matching the text query."""
    # First, encode the query with the CLIP text encoder
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)

    # util.semantic_search computes the cosine similarity between the query
    # embedding and all image embeddings and returns the highest-ranked hits.
    # Over-fetch a few extra candidates so that failed downloads can be
    # dropped while still returning top_k images.
    if top_k < 10:
        top_k_img = 8
    elif top_k < 15:
        top_k_img = 13
    else:
        top_k_img = 17
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k_img)[0]

    images = []
    for hit in hits:
        photo_name = img_names[hit['corpus_id']]
        img = download_image(photo_name)
        if img is not None:
            images.append(img)
    return images[:top_k]
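# Quick sanity check (assumes the embeddings file is present and the Unsplash
# download endpoint is reachable):
#   imgs = search_text("Two dogs playing in the snow", 5)
#   print(len(imgs))  # up to 5 PIL images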
iface = gr.Interface(
    title="Text to Image using CLIP Model 📸",
    description="Gradio demo for the CLIP model. To use it, simply describe the image you are looking for.",
    fn=search_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Write what you are looking for in an image...",
            placeholder="Text Here...",
        ),
        gr.Slider(5, 15, step=5, label="Number of images"),
    ],
    outputs=[
        gr.Gallery(
            label="Retrieved images", show_label=False, elem_id="gallery"
        )
    ],
    examples=[
        ["Dog on the beach", 5],
        ["Paris during night.", 10],
        ["A cute kangaroo", 5],
        ["Picnic Spots", 10],
        ["Desert", 5],
        ["A racetrack", 15],
    ],
).launch()