import pickle
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
from sentence_transformers import util
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
def is_valid_image(content):
    """Return True if `content` is a parseable image payload."""
    try:
        # verify() parses the file header without decoding the full image
        Image.open(BytesIO(content)).verify()
        return True
    except OSError:
        return False
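
# Note: Image.verify() only checks the image header and leaves the Image
# object unusable afterwards, which is why download_image() below re-opens
# the raw bytes with Image.open() before returning them.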

# Define the CLIP model, processor, and tokenizer
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Load the precomputed image embeddings for the Unsplash 25k photo set
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
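
# Sanity check (a minimal sketch; assumes the pickle stores a list of photo
# filenames alongside a matching matrix of CLIP image features):
assert len(img_names) == len(img_emb), "embedding/filename count mismatch"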

def download_image(ids):
    """Download a 320px-wide rendition of an Unsplash photo, or None on failure."""
    photo_id = ids.split(".")[0]  # strip the file extension from the stored name
    url = f"https://unsplash.com/photos/{photo_id}/download?w=320"
    response = requests.get(url)
    if response.status_code == 200 and is_valid_image(response.content):
        # Re-open the bytes: is_valid_image() only verified the header
        return Image.open(BytesIO(response.content))
    return None
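
# Example usage (hedged: the photo id below is hypothetical):
#   img = download_image("abc123XYZ.jpg")
#   if img is not None:
#       img.save("preview.jpg")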

def search_text(query, top_k):
    """Search images based on the text query."""
    # First, encode the query with CLIP's text tower.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)

    # util.semantic_search computes the cosine similarity between the query
    # embedding and all image embeddings, then returns the highest-ranked hits.
    # Request a few more candidates than top_k, since some Unsplash downloads fail.
    if top_k < 10:
        top_k_img = 8
    elif top_k < 15:
        top_k_img = 13
    else:
        top_k_img = 17
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k_img)[0]

    images = []
    for hit in hits:
        photo_name = img_names[hit['corpus_id']]
        img = download_image(photo_name)
        if img is not None:
            images.append(img)
    return images[:top_k]
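
# Quick manual check (hedged: this hits the live Unsplash endpoint, so run it
# deliberately rather than leaving it enabled at import time):
#   print(len(search_text("a red bicycle", 5)))  # expect up to 5 PIL images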

iface = gr.Interface(
    title="Text to Image using CLIP Model 📸",
    description="Gradio demo for the CLIP model.\nTo use it, simply describe the image you are looking for.",
    fn=search_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Write what you are looking for in an image...",
            placeholder="Text Here...",
        ),
        gr.Slider(5, 15, step=5, label="Number of images"),
    ],
    outputs=[
        gr.Gallery(
            label="Retrieved images", show_label=False, elem_id="gallery"
        )
    ],
    examples=[
        ["Dog on the beach", 5],
        ["Paris at night", 10],
        ["A cute kangaroo", 5],
        ["Picnic spots", 10],
        ["Desert", 5],
        ["A racetrack", 15],
    ],
).launch()