import base64
from io import BytesIO
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from torch.nn.functional import cosine_similarity
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
                                    Normalize, Resize, ToTensor)
from transformers import CLIPModel, CLIPProcessor

# Per-request caps, so a single call cannot queue unbounded work.
MAX_TEXT_LIST_LENGTH = 30
MAX_IMAGE_LIST_LENGTH = 20


class EndpointHandler:
    def __init__(self, path: str = "", image_size: int = 224) -> None:
        """
        Initialize the EndpointHandler with a given model path and image size.

        Args:
            path (str, optional): Path to the pretrained model. Currently unused:
                the model is always loaded from the pinned Hub repository below.
            image_size (int, optional): Side length, in pixels, to which input
                images are resized and center-cropped. Defaults to 224.
        """
        self.model = CLIPModel.from_pretrained("Superlore/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("Superlore/clip-vit-large-patch14")
        self.image_transform = Compose([
            # InterpolationMode.BICUBIC replaces the deprecated integer code 3.
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            # CLIP's published per-channel mean and std.
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
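
        # Note: the torchvision pipeline above is meant to mirror the image
        # preprocessing that CLIPProcessor performs. Assuming the two are
        # interchangeable (an assumption, not verified here), the equivalent
        # processor-based call would be:
        #
        #     pixel_values = self.processor(images=pil_image,
        #                                   return_tensors="pt").pixel_values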

    def __call__(self, data: Dict[str, Any]) -> Dict[str, list]:
        """
        Compute embeddings for the supplied images and texts and, when both
        lists are non-empty, the pairwise similarity scores between them.

        Args:
            data (Dict[str, Any]): A dictionary with the key:
                - "inputs" (Dict[str, list]): A dictionary with the keys:
                    - "image_list" (List[str]): A list of base64-encoded images.
                    - "text_list" (Union[List[str], str]): A list of text
                      strings, or a single string.

        Returns:
            Dict[str, list]: A dictionary with the keys:
                - "image_features" (List[List[float]]): One embedding per image.
                - "text_features" (List[List[float]]): One embedding per text.
                - "similarity_scores" (List[List[float]]): For each image, the
                  cosine similarity to every text embedding. Empty if either
                  "image_list" or "text_list" is empty.
        """
        if not isinstance(data, dict):
            raise ValueError("Expected input data to be a dict.")

        inputs = data.get("inputs", {})

        if not isinstance(inputs, dict):
            raise ValueError("Expected 'inputs' to be a dict.")

        image_list = inputs.get("image_list", [])
        text_list = inputs.get("text_list", [])

        if not isinstance(image_list, list):
            raise ValueError("Expected 'image_list' to be a list.")
        if not isinstance(text_list, (list, str)):
            raise ValueError("Expected 'text_list' to be a list or string.")
        if not all(isinstance(image, str) for image in image_list):
            raise ValueError("Expected 'image_list' to contain only strings.")
        if isinstance(text_list, list) and not all(isinstance(text, str) for text in text_list):
            raise ValueError("Expected 'text_list' to contain only strings.")

        # Normalize a bare string to a one-element list.
        if isinstance(text_list, str):
            text_list = [text_list]

        if len(image_list) > MAX_IMAGE_LIST_LENGTH:
            raise ValueError(f"Expected 'image_list' to have a maximum length of {MAX_IMAGE_LIST_LENGTH}.")
        if len(text_list) > MAX_TEXT_LIST_LENGTH:
            raise ValueError(f"Expected 'text_list' to have a maximum length of {MAX_TEXT_LIST_LENGTH}.")
        if not all(is_valid_base64_image(image) for image in image_list):
            raise ValueError("Expected 'image_list' to contain only valid base64-encoded images.")

        image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None
        text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None

        result = {
            "image_features": image_features.tolist() if image_features is not None else [],
            "text_features": text_features.tolist() if text_features is not None else [],
            "similarity_scores": [],
        }

        if image_features is not None and text_features is not None:
            # For each image embedding, compute its cosine similarity against
            # every text embedding; unsqueeze makes the broadcasting explicit.
            similarity_scores = [
                cosine_similarity(img_feat.unsqueeze(0), text_features)
                for img_feat in image_features
            ]
            result["similarity_scores"] = [t.tolist() for t in similarity_scores]

        return result

    def preprocess_images(self, base64_images: List[str]) -> torch.Tensor:
        """Decodes a list of base64-encoded images and applies the preprocessing
        pipeline, returning a batch of shape (N, 3, image_size, image_size)."""
        preprocessed_images = []
        for base64_image in base64_images:
            image_data = base64.b64decode(base64_image)
            image = Image.open(BytesIO(image_data)).convert("RGB")
            # unsqueeze adds the batch dimension so the tensors can be concatenated.
            preprocessed_image = self.image_transform(image).unsqueeze(0)
            preprocessed_images.append(preprocessed_image)

        return torch.cat(preprocessed_images, dim=0)

    def get_image_embeddings(self, base64_images: List[str]) -> torch.Tensor:
        """Returns one CLIP image embedding per base64-encoded input image."""
        image_tensors = self.preprocess_images(base64_images)

        with torch.no_grad():
            self.model.eval()
            image_features = self.model.get_image_features(pixel_values=image_tensors)

        return image_features

    def get_text_embeddings(self, text_list: Union[List[str], str]) -> torch.Tensor:
        """Returns one CLIP text embedding per input string."""
        # Tokenization needs no gradient context; only the forward pass does.
        input_tokens = self.processor(text=text_list, return_tensors="pt", padding=True, truncation=True)

        with torch.no_grad():
            self.model.eval()
            text_features = self.model.get_text_features(**input_tokens)

        return text_features


def is_valid_base64_image(data: str) -> bool:
    """Returns True if `data` decodes to bytes that PIL recognizes as an image."""
    try:
        img_data = base64.b64decode(data)
        img = Image.open(BytesIO(img_data))
        # verify() checks file integrity without decoding the full image.
        img.verify()
        return True
    except Exception:
        return False
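

if __name__ == "__main__":
    # Minimal local smoke test -- an illustrative sketch, not part of the
    # Inference Endpoints contract (the platform instantiates EndpointHandler
    # itself and never runs this block). It builds a tiny in-memory image so
    # no fixture files are needed; constructing the handler still downloads
    # the model weights from the Hub.
    buffer = BytesIO()
    Image.new("RGB", (64, 64), color=(255, 0, 0)).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    handler = EndpointHandler()
    response = handler({
        "inputs": {
            "image_list": [encoded],
            "text_list": ["a red square", "a photo of a dog"],
        }
    })
    print(f"{len(response['image_features'])} image embedding(s), "
          f"{len(response['text_features'])} text embedding(s)")
    print("similarity_scores:", response["similarity_scores"])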