import base64
import logging
from io import BytesIO
from typing import Any, Callable, Dict, List

import torch
from PIL import Image


class EndpointHandler:
    """
    A handler class for processing image and text data, generating embeddings
    using a specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images and text
            before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The batch size used when a request does not
            supply its own ``batch_size``.
        logger: Module-level logger used for progress and error messages.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Load the model and processor from ``path`` and select the device.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for processing images
                and text data.

        Raises:
            Exception: Any error raised while loading the model or processor
                is logged and re-raised unchanged.
        """
        # Initialize logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Imported lazily so the module can be imported without colpali_engine.
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.logger.info("Initializing model and processor.")
        try:
            self.model = ColQwen2.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            ).eval()
            self.processor = ColQwen2Processor.from_pretrained(path)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # NOTE(review): device_map="auto" above already places the weights;
            # the explicit move is kept from the original — confirm it is still
            # wanted on multi-GPU hosts.
            self.model.to(self.device)
            self.default_batch_size = default_batch_size
            self.logger.info("Initialization complete.")
        except Exception as e:
            self.logger.error("Failed to initialize model or processor: %s", e)
            raise

    def _process_image_batch(self, images: List[Image.Image]) -> List[Any]:
        """
        Run one batch of images through the processor and the model.

        Args:
            images (List[Image.Image]): Images to embed.

        Returns:
            List[Any]: The model output moved to CPU and converted with
            ``tolist()``, one entry per input image. (Presumably each entry is
            a list of vectors, since the result is later passed to
            ``score_multi_vector`` — confirm against the model.)

        Raises:
            Exception: Any processor/model error is logged and re-raised.
        """
        self.logger.debug("Processing batch of %d images.", len(images))
        try:
            batch_images = self.processor.process_images(images).to(self.device)
            with torch.no_grad():
                image_embeddings = self.model(**batch_images)
            self.logger.debug("Image batch processing complete.")
            return image_embeddings.cpu().tolist()
        except Exception as e:
            self.logger.error("Error processing image batch: %s", e)
            raise

    def _process_text_batch(self, texts: List[str]) -> List[Any]:
        """
        Run one batch of text queries through the processor and the model.

        Args:
            texts (List[str]): Text queries to embed.

        Returns:
            List[Any]: The model output moved to CPU and converted with
            ``tolist()``, one entry per input query. (Presumably each entry is
            a list of vectors — confirm against the model.)

        Raises:
            Exception: Any processor/model error is logged and re-raised.
        """
        self.logger.debug("Processing batch of %d text queries.", len(texts))
        try:
            batch_queries = self.processor.process_queries(texts).to(self.device)
            with torch.no_grad():
                query_embeddings = self.model(**batch_queries)
            self.logger.debug("Text batch processing complete.")
            return query_embeddings.cpu().tolist()
        except Exception as e:
            self.logger.error("Error processing text batch: %s", e)
            raise

    @staticmethod
    def _embed_in_batches(
        items: List[Any],
        batch_size: int,
        embed_fn: Callable[[List[Any]], List[Any]],
    ) -> List[Any]:
        """Apply ``embed_fn`` to consecutive ``batch_size`` slices of ``items`` and concatenate the results."""
        embeddings: List[Any] = []
        for start in range(0, len(items), batch_size):
            embeddings.extend(embed_fn(items[start : start + batch_size]))
        return embeddings

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode the request's base64 images and text queries, embed them, and
        score queries against images when both are present.

        Args:
            data (Dict[str, Any]): Request payload. Recognized keys:
                ``"image"`` — a base64-encoded image string or a list of them;
                ``"text"`` — a query string or a list of them;
                ``"batch_size"`` — optional positive integer overriding
                ``default_batch_size``.

        Returns:
            Dict[str, Any]: On success, ``{"image": ..., "text": ...,
            "scores": ...}`` where ``scores`` is empty unless both images and
            text were supplied. On any failure, ``{"error": <message>}``.
        """
        images_data = data.get("image", [])
        text_data = data.get("text", [])
        batch_size = data.get("batch_size", self.default_batch_size)

        # An explicit null batch size means "use the default".
        if batch_size is None:
            batch_size = self.default_batch_size
        # A zero step makes range() raise and a negative one makes it empty
        # (silently producing no embeddings), so reject both up front.
        if not isinstance(batch_size, int) or batch_size <= 0:
            self.logger.error("Invalid batch_size: %r", batch_size)
            return {"error": "batch_size must be a positive integer."}

        # Accept a single item as shorthand for a one-element list; slicing a
        # bare string would otherwise split it into character chunks.
        if images_data is None:
            images_data = []
        elif isinstance(images_data, str):
            images_data = [images_data]
        if text_data is None:
            text_data = []
        elif isinstance(text_data, str):
            text_data = [text_data]

        if not isinstance(images_data, (list, tuple)):
            self.logger.error("Images should be base64-encoded strings.")
            return {"error": "Images should be base64-encoded strings."}
        if not isinstance(text_data, (list, tuple)) or not all(
            isinstance(query, str) for query in text_data
        ):
            self.logger.error("Text queries should be strings.")
            return {"error": "Text queries should be strings."}
        text_data = list(text_data)

        # Decode and process images
        images = []
        if images_data:
            self.logger.info("Decoding images from base64.")
            for img_data in images_data:
                if not isinstance(img_data, str):
                    self.logger.error("Images should be base64-encoded strings.")
                    return {"error": "Images should be base64-encoded strings."}
                try:
                    image_bytes = base64.b64decode(img_data)
                    images.append(Image.open(BytesIO(image_bytes)).convert("RGB"))
                except Exception as e:
                    self.logger.error("Invalid image data: %s", e)
                    return {"error": f"Invalid image data: {e}"}

        image_embeddings = []
        if images:
            self.logger.info("Processing image embeddings.")
            try:
                image_embeddings = self._embed_in_batches(
                    images, batch_size, self._process_image_batch
                )
            except Exception as e:
                self.logger.error("Error generating image embeddings: %s", e)
                return {"error": f"Error generating image embeddings: {e}"}

        # Process text data
        text_embeddings = []
        if text_data:
            self.logger.info("Processing text embeddings.")
            try:
                text_embeddings = self._embed_in_batches(
                    text_data, batch_size, self._process_text_batch
                )
            except Exception as e:
                self.logger.error("Error generating text embeddings: %s", e)
                return {"error": f"Error generating text embeddings: {e}"}

        # Compute similarity scores if both image and text embeddings are available
        scores = []
        if image_embeddings and text_embeddings:
            self.logger.info("Computing similarity scores.")
            try:
                # Build one tensor per item instead of a single stacked tensor:
                # items coming from different batches may have been padded to
                # different lengths, which a single torch.tensor() call cannot
                # represent. score_multi_vector accepts lists of tensors.
                image_tensors = [
                    torch.tensor(embedding).to(self.device)
                    for embedding in image_embeddings
                ]
                text_tensors = [
                    torch.tensor(embedding).to(self.device)
                    for embedding in text_embeddings
                ]
                with torch.no_grad():
                    scores = (
                        self.processor.score_multi_vector(text_tensors, image_tensors)
                        .cpu()
                        .tolist()
                    )
                self.logger.info("Similarity scoring complete.")
            except Exception as e:
                self.logger.error("Error computing similarity scores: %s", e)
                return {"error": f"Error computing similarity scores: {e}"}

        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}