import pickle
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
from sentence_transformers import util
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer


def is_valid_image(content):
    """Return True if the downloaded bytes can be parsed as an image."""
    try:
        Image.open(BytesIO(content)).verify()
        return True
    except Exception:
        return False


# Load the CLIP model, processor, and tokenizer
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Load the precomputed image embeddings for the Unsplash 25k photo set
emb_filename = "unsplash-25k-photos-embeddings.pkl"
with open(emb_filename, "rb") as fIn:
    img_names, img_emb = pickle.load(fIn)


def download_image(ids):
    """Download a 320px-wide copy of an Unsplash photo by its file name."""
    photo_id = ids.split(".")[0]
    url = f"https://unsplash.com/photos/{photo_id}/download?w=320"
    response = requests.get(url)
    if response.status_code == 200 and is_valid_image(response.content):
        return Image.open(BytesIO(response.content))
    return None


def search_text(query, top_k):
    """Return the top_k images that best match the text query."""
    # Encode the query with CLIP's text encoder
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)

    # util.semantic_search computes the cosine similarity between the query
    # embedding and all image embeddings and returns the highest-ranked hits.
    # We over-fetch a few extra candidates because some Unsplash downloads fail.
    if top_k < 10:
        top_k_img = 8
    elif top_k < 15:
        top_k_img = 13
    else:
        top_k_img = 17
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k_img)[0]

    images = []
    for hit in hits:
        photo_name = img_names[hit["corpus_id"]]
        img = download_image(photo_name)
        if img is not None:
            images.append(img)
    return images[:top_k]


iface = gr.Interface(
    title="Text-to-Image Search using the CLIP Model 📸",
    description="Gradio demo for the CLIP model.\nTo use it, simply describe the image you are looking for.",
    fn=search_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Describe the image you are looking for...",
            placeholder="Text Here...",
        ),
        gr.Slider(5, 15, step=5, label="Number of images"),
    ],
    outputs=[
        gr.Gallery(label="Retrieved images", show_label=False, elem_id="gallery")
    ],
    examples=[
        ["Dog on the beach", 5],
        ["Paris at night", 10],
        ["A cute kangaroo", 5],
        ["Picnic spots", 10],
        ["Desert", 5],
        ["A racetrack", 15],
    ],
).launch()
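
# Note: the script above assumes 'unsplash-25k-photos-embeddings.pkl' is already
# present locally. A minimal sketch for fetching it (place before the pickle
# load above; the sbert.net URL is the one used in the sentence-transformers
# image-search example and is an assumption here, so verify it before relying
# on it):
#
#     import os
#     if not os.path.exists(emb_filename):
#         util.http_get("http://sbert.net/datasets/unsplash-25k-photos-embeddings.pkl", emb_filename)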