from typing import Union import gradio as gr from numpy import empty import open_clip import torch import PIL.Image as Image # Set device to GPU if available device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"PyTorch Device {device}") # Load the OpenCLIP model and the necessary preprocessors # openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K' # openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K' openclip_model_name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K" openclip_model = "hf-hub:" + openclip_model_name model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms( model_name=openclip_model, device=device ) # Define function to generate text embeddings # @spaces.GPU def generate_text_embedding(text_data: Union[str, tuple[str]]) -> list[str]: """ Generate embeddings for text data using the OpenCLIP model. Parameters ---------- text_data : str or tuple of str Text data to embed. Returns ------- text_embeddings : list of str List of text embeddings. """ # Embed text data text_embeddings = [] empty_data_indices = [] if text_data: # If text_data is a string, convert to list of strings if isinstance(text_data, str): text_data = [text_data] # If text_data is a tuple of strings, convert to list of strings if isinstance(text_data, tuple): text_data = list(text_data) # If text_data is not a list of strings, raise error if not isinstance(text_data, list): raise TypeError("text_data must be a string or a tuple of strings.") # Keep track of indices of empty text strings empty_data_indices = [i for i, text in enumerate(text_data) if text == ""] # Remove empty text strings text_data = [text for text in text_data if text != ""] if text_data: # Tokenize text_data and convert to tensor text_data = open_clip.tokenize(text_data).to(device) # Generate text embeddings with torch.no_grad(): text_embeddings = model.encode_text(text_data) # Convert embeddings to list of strings text_embeddings = [ embedding.detach().cpu().numpy().tolist() for embedding in text_embeddings ] # Insert empty strings at indices of empty text strings for i in empty_data_indices: text_embeddings.insert(i, "") return text_embeddings # Define function to generate image embeddings def generate_image_embedding( image_data: Union[Image.Image, tuple[Image.Image]] ) -> list[str]: """ Generate embeddings for image data using the OpenCLIP model. Parameters ---------- image_data : PIL.Image.Image or tuple of PIL.Image.Image Image data to embed. Returns ------- image_embeddings : list of str List of image embeddings. """ # Embed image data image_embeddings = [] empty_data_indices = [] if image_data: # If image_data is a single PIL image, convert to list of PIL images if isinstance(image_data, Image.Image): image_data = [image_data] # If image_data is a tuple of images, convert to list of images if isinstance(image_data, tuple): image_data = list(image_data) # Keep track of indices of None images empty_data_indices = [i for i, img in enumerate(image_data) if img is None] # Remove None images image_data = [img for img in image_data if img is not None] if image_data: # Preprocess image_data and convert to tensor image_data = [preprocess_val(img).unsqueeze(0) for img in image_data] image_data = torch.stack(image_data).squeeze(1).to(device) # Generate image embeddings with torch.no_grad(): image_embeddings = model.encode_image(image_data) # Convert embeddings to list of strings image_embeddings = [ embedding.detach().cpu().numpy().tolist() for embedding in image_embeddings ] # Insert empty strings at indices of empty images for i in empty_data_indices: image_embeddings.insert(i, "") return image_embeddings # Define function to generate embeddings def generate_embedding( text_data: Union[str, tuple[str]], image_data: Union[Image.Image, tuple[Image.Image]], ) -> tuple[list[str], list[str], list[str]]: """ Generate embeddings for text and image data using the OpenCLIP model. Parameters ---------- text_data : str or tuple of str Text data to embed. image_data : PIL.Image.Image or tuple of PIL.Image.Image Image data to embed. Returns ------- text_embeddings : list of str List of text embeddings. image_embeddings : list of str List of image embeddings. similarity : list of str List of cosine similarity between text and image embeddings. """ # Embed text data text_embeddings = generate_text_embedding(text_data) # Embed image data image_embeddings = generate_image_embedding(image_data) # Calculate cosine similarity between text and image embeddings similarity = [] empty_data_indices = [] if text_embeddings and image_embeddings: # Filter out embedding pairs with either empty text or image embeddings, tracking indices of empty embeddings text_embeddings_filtered = [] image_embeddings_filtered = [] for i, (text_embedding, image_embedding) in enumerate( zip(text_embeddings, image_embeddings) ): if text_embedding != "" and image_embedding != "": text_embeddings_filtered.append(text_embedding) image_embeddings_filtered.append(image_embedding) else: empty_data_indices.append(i) # Calculate cosine similarity if there are any non-empty embedding pairs if image_embeddings_filtered and text_embeddings_filtered: # Convert lists back to tensors for processing text_embeddings_tensor = torch.tensor(text_embeddings_filtered) image_embeddings_tensor = torch.tensor(image_embeddings_filtered) # Normalize the embeddings text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm( dim=-1, keepdim=True ) image_embedding_norm = ( image_embeddings_tensor / image_embeddings_tensor.norm(dim=-1, keepdim=True) ) # Calculate cosine similarity similarity = torch.nn.functional.cosine_similarity( text_embedding_norm, image_embedding_norm, dim=-1 ) # Convert to percentage as text similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity] # Insert empty text strings in similarity for i in empty_data_indices: similarity.insert(i, "") return (text_embeddings, image_embeddings, similarity, openclip_model_name) # Define Gradio interface demo = gr.Interface( fn=generate_embedding, inputs=[ gr.Textbox( lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed", ), gr.Image(height=512, type="pil", label="Image to Embed"), ], outputs=[ gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False), gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False), gr.Textbox(label="Cosine Similarity"), gr.Textbox(label="Embedding Model"), ], title="OpenCLIP Embedding Generator", description="Generate embeddings using OpenCLIP model for text and images.", allow_flagging="never", batch=False, api_name="embed", ) # Enable queueing and launch the app if __name__ == "__main__": demo.queue(api_open=True).launch(show_api=True)