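"""Custom handler for a Hugging Face Inference Endpoint.

Wraps the Superlore/clip-vit-large-patch14 CLIP checkpoint to embed
base64-encoded images and/or text strings, and to score image-text
similarity when both are supplied. The expected request shape below is
a sketch inferred from the parsing in `EndpointHandler.__call__`:

    {"inputs": {"image_list": ["<base64 PNG/JPEG>", ...],
                "text_list": ["a photo of a cat", ...]}}
"""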
import base64
from io import BytesIO
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from torch.nn.functional import cosine_similarity
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    Resize,
    ToTensor,
)
from transformers import CLIPModel, CLIPProcessor

MAX_TEXT_LIST_LENGTH = 30
MAX_IMAGE_LIST_LENGTH = 20
class EndpointHandler:
    def __init__(self, path: str = "", image_size: int = 224) -> None:
        """
        Initialize the EndpointHandler with a given model path and image size.

        Args:
            path (str, optional): Path to the pretrained model. Defaults to an
                empty string. Note: the checkpoint is currently loaded from the
                Hugging Face Hub regardless of this value.
            image_size (int, optional): The size of the images to be processed.
                Defaults to 224.
        """
        self.model = CLIPModel.from_pretrained("Superlore/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("Superlore/clip-vit-large-patch14")
        # Standard CLIP preprocessing: bicubic resize, center crop, and
        # normalization with the CLIP mean/std.
        self.image_transform = Compose([
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Dict[str, Any]) -> Dict[str, list]:
        """
        Process input data containing image and text lists, compute image and
        text embeddings, and, if both lists are non-empty, calculate similarity
        scores between them.

        Args:
            data (Dict[str, Any]): A dictionary containing the following key:
                - "inputs" (Dict[str, Any]): A dictionary containing the following keys:
                    - "image_list" (List[str]): A list of base64-encoded images.
                    - "text_list" (Union[List[str], str]): A list of text strings,
                      or a single string.

        Returns:
            Dict[str, list]: A dictionary containing the following keys:
                - "image_features" (List[List[float]]): A list of image embeddings.
                - "text_features" (List[List[float]]): A list of text embeddings.
                - "similarity_scores" (List[List[float]]): One row of cosine
                  similarities per image (one score per text). Empty if either
                  "image_list" or "text_list" is empty.
        """
        if not isinstance(data, dict):
            raise ValueError("Expected input data to be a dict.")
        inputs = data.get("inputs", {})
        if not isinstance(inputs, dict):
            raise ValueError("Expected 'inputs' to be a dict.")

        image_list = inputs.get("image_list", [])  # list of base64-encoded images
        text_list = inputs.get("text_list", [])  # list of texts, or a plain string

        if not isinstance(image_list, list):
            raise ValueError("Expected 'image_list' to be a list.")
        if not isinstance(text_list, (list, str)):
            raise ValueError("Expected 'text_list' to be a list or string.")
        if not all(isinstance(image, str) for image in image_list):
            raise ValueError("Expected 'image_list' to contain only strings.")
        if isinstance(text_list, list) and not all(isinstance(text, str) for text in text_list):
            raise ValueError("Expected 'text_list' to contain only strings.")

        # Normalize a bare string into a single-element list.
        if isinstance(text_list, str):
            text_list = [text_list]

        if len(image_list) > MAX_IMAGE_LIST_LENGTH:
            raise ValueError(f"Expected 'image_list' to have a maximum length of {MAX_IMAGE_LIST_LENGTH}.")
        if len(text_list) > MAX_TEXT_LIST_LENGTH:
            raise ValueError(f"Expected 'text_list' to have a maximum length of {MAX_TEXT_LIST_LENGTH}.")
        if not all(is_valid_base64_image(image) for image in image_list):
            raise ValueError("Expected 'image_list' to contain only valid base64-encoded images.")
        image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None
        text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None

        result = {
            "image_features": image_features.tolist() if image_features is not None else [],
            "text_features": text_features.tolist() if text_features is not None else [],
            "similarity_scores": [],
        }

        # When both modalities are present, compute one row of cosine
        # similarities per image embedding (one score per text).
        if image_features is not None and text_features is not None:
            similarity_scores = [cosine_similarity(img_feat, text_features) for img_feat in image_features]
            result["similarity_scores"] = [t.tolist() for t in similarity_scores]
        return result

    def preprocess_images(self, base64_images: List[str]) -> torch.Tensor:
        """Decode a list of base64-encoded images and apply preprocessing steps."""
        preprocessed_images = []
        for base64_image in base64_images:
            # Decode the base64 payload and convert it to an RGB image.
            image_data = base64.b64decode(base64_image)
            image = Image.open(BytesIO(image_data)).convert("RGB")
            # Transform to a normalized tensor and add a batch dimension.
            preprocessed_image = self.image_transform(image).unsqueeze(0)
            preprocessed_images.append(preprocessed_image)
        return torch.cat(preprocessed_images, dim=0)

    def get_image_embeddings(self, base64_images: List[str]) -> torch.Tensor:
        """Compute CLIP image embeddings for a list of base64-encoded images."""
        image_tensors = self.preprocess_images(base64_images)
        self.model.eval()
        with torch.no_grad():
            image_features = self.model.get_image_features(pixel_values=image_tensors)
        return image_features

    def get_text_embeddings(self, text_list: Union[List[str], str]) -> torch.Tensor:
        """Compute CLIP text embeddings for a list of strings (or a single string)."""
        self.model.eval()
        with torch.no_grad():
            # Tokenize the input text, padding and truncating to the model's limits.
            input_tokens = self.processor(text=text_list, return_tensors="pt", padding=True, truncation=True)
            text_features = self.model.get_text_features(**input_tokens)
        return text_features


def is_valid_base64_image(data: str) -> bool:
    """Return True if `data` decodes to an image that PIL can parse, else False."""
    try:
        # Decode the base64 string.
        img_data = base64.b64decode(data)
        # Open the image with PIL and verify its integrity.
        img = Image.open(BytesIO(img_data))
        img.verify()
        return True
    except Exception:
        return False
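

# A minimal local smoke test, assuming the checkpoint can be downloaded from
# the Hub. The invocation mirrors how an Inference Endpoints runtime would
# call the handler; the sample image is generated in-memory with PIL, so no
# external files are required.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Build a small solid-color test image and base64-encode it as PNG.
    buffer = BytesIO()
    Image.new("RGB", (224, 224), color=(255, 0, 0)).save(buffer, format="PNG")
    b64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

    payload = {
        "inputs": {
            "image_list": [b64_image],
            "text_list": ["a red square", "a photo of a cat"],
        }
    }
    result = handler(payload)
    print(len(result["image_features"]), "image embedding(s)")
    print(len(result["text_features"]), "text embedding(s)")
    print("similarity_scores:", result["similarity_scores"])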