import base64
from io import BytesIO
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from torch.nn.functional import cosine_similarity
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import CLIPModel, CLIPProcessor

max_text_list_length = 30
max_image_list_length = 20


class EndpointHandler:
    def __init__(self, path: str = "", image_size: int = 224) -> None:
        """
        Initialize the EndpointHandler with a given model path and image size.

        Args:
            path (str, optional): Path to the pretrained model. Defaults to an empty string.
            image_size (int, optional): The size of the images to be processed. Defaults to 224.
        """
        # `path` is currently unused; the checkpoint is loaded from the hub repo below.
        self.model = CLIPModel.from_pretrained("Superlore/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("Superlore/clip-vit-large-patch14")
        self.image_transform = Compose([
            Resize(image_size, interpolation=3),  # 3 == PIL bicubic
            CenterCrop(image_size),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Dict[str, Any]) -> Dict[str, list]:
        """
        Process input data containing image and text lists, compute image and text embeddings,
        and, if both lists are provided, calculate similarity scores between them.

        Args:
            data (Dict[str, Any]): A dictionary containing the following key:
                - "inputs" (Dict[str, list]): A dictionary containing the following keys:
                    - "image_list" (List[str]): A list of base64-encoded images.
                    - "text_list" (Union[List[str], str]): A list of text strings, or a single string.

        Returns:
            Dict[str, list]: A dictionary containing the following keys:
                - "image_features" (List[List[float]]): A list of image embeddings.
                - "text_features" (List[List[float]]): A list of text embeddings.
                - "similarity_scores" (List[List[float]]): A list of similarity scores between
                  image and text embeddings. Empty if either "image_list" or "text_list" is empty.
""" if not isinstance(data, dict): raise ValueError("Expected input data to be a dict.") inputs = data.get("inputs", {}) if not isinstance(inputs, dict): raise ValueError("Expected 'inputs' to be a dict.") image_list = inputs.get("image_list", []) # list of b64 images text_list = inputs.get("text_list", []) # list of texts (or just plain string) if not isinstance(image_list, list): raise ValueError("Expected 'image_list' to be a list.") if not isinstance(text_list, list) and not isinstance(text_list, str): raise ValueError("Expected 'text_list' to be a list or string.") if not all(isinstance(image, str) for image in image_list): raise ValueError("Expected 'image_list' to contain only strings.") if isinstance(text_list, list) and not all(isinstance(text, str) for text in text_list): raise ValueError("Expected 'text_list' to contain only strings.") # if text_list is a string, convert to list if isinstance(text_list, str): text_list = [text_list] if len(image_list) > max_image_list_length: raise ValueError(f"Expected 'image_list' to have a maximum length of {max_image_list_length}.") if len(text_list) > max_text_list_length: raise ValueError(f"Expected 'text_list' to have a maximum length of {max_text_list_length}.") if not all(is_valid_base64_image(image) for image in image_list): raise ValueError("Expected 'image_list' to contain only valid base64-encoded images.") image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None result = { "image_features": image_features.tolist() if image_features is not None else [], "text_features": text_features.tolist() if text_features is not None else [], "similarity_scores": [] } # if image_features & text_features, compute similarity if image_features is not None and text_features is not None: similarity_scores = [cosine_similarity(img_feat, text_features) for img_feat in image_features] result["similarity_scores"] = [t.tolist() for t in similarity_scores] return result def preprocess_images(self, base64_images: List[str]) -> torch.Tensor: """Loads a list of images and applies preprocessing steps.""" preprocessed_images = [] for base64_image in base64_images: # Decode the base64-encoded image and convert it to an RGB image image_data = base64.b64decode(base64_image) image = Image.open(BytesIO(image_data)).convert("RGB") preprocessed_image = self.image_transform(image).unsqueeze(0) preprocessed_images.append(preprocessed_image) return torch.cat(preprocessed_images, dim=0) def get_image_embeddings(self, base64_images: List[str]) -> torch.Tensor: image_tensors = self.preprocess_images(base64_images) with torch.no_grad(): self.model.eval() image_features = self.model.get_image_features(pixel_values=image_tensors) return image_features def get_text_embeddings(self, text_list: Union[List[str], str]) -> torch.Tensor: with torch.no_grad(): # Tokenize the input text list input_tokens = self.processor(text_list, return_tensors="pt", padding=True, truncation=True) # Generate the embeddings for the text list self.model.eval() text_features = self.model.get_text_features(**input_tokens) return text_features def is_valid_base64_image(data: str) -> bool: try: # Decode the base64 string img_data = base64.b64decode(data) # Open the image using PIL img = Image.open(BytesIO(img_data)) # Check that the image format is supported img.verify() return True except: return False