import base64
from io import BytesIO
from typing import Any, Dict, List

from PIL import Image
from transformers import pipeline


class EndpointHandler:
    """Custom inference handler wrapping a zero-shot object detection pipeline.

    The handler is constructed once with a model path and is then called with
    one request payload at a time.
    """

    def __init__(self, model_path=""):
        """Build the zero-shot object detection pipeline.

        Args:
            model_path: Model identifier or local directory passed to
                ``transformers.pipeline`` as ``model``.
        """
        # NOTE: ``device=0`` pins the pipeline to the first GPU, so
        # construction is expected to fail on a host without one.
        self.pipeline = pipeline(
            task="zero-shot-object-detection",
            model=model_path,
            device=0,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run zero-shot object detection for a single request.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64 string>, "candidates": [...]}}``.

        Returns:
            Whatever the underlying pipeline returns for the decoded image and
            the candidate labels; it is passed through unmodified.

        Raises:
            ValueError: If ``inputs`` is not a mapping, if ``image`` is missing
                or empty, or if ``candidates`` is missing or empty. Malformed
                base64 also surfaces as a ``ValueError`` subclass
                (``binascii.Error``) raised by ``base64.b64decode``.
            OSError: If the decoded bytes are not an image PIL can open.
        """
        # Read "inputs" once; a missing key is treated like an empty mapping
        # so the field checks below produce the error message.
        inputs = data.get("inputs", {})
        if not isinstance(inputs, dict):
            raise ValueError(
                "'inputs' must be an object with 'image' and 'candidates' fields"
            )

        image = self._decode_image(inputs.get("image"))

        # Fail fast with a clear message instead of handing the pipeline an
        # empty label set, which gives it nothing to detect.
        candidate_labels = inputs.get("candidates")
        if not candidate_labels:
            raise ValueError(
                "'inputs.candidates' must be a non-empty list of labels"
            )

        return self.pipeline(image=image, candidate_labels=candidate_labels)

    @staticmethod
    def _decode_image(image_data: Any) -> Image.Image:
        """Decode a base64-encoded payload into a PIL image."""
        # An absent or empty payload would otherwise reach PIL as zero bytes
        # and fail with an unhelpful "cannot identify image file" error.
        if not image_data:
            raise ValueError(
                "'inputs.image' must be a non-empty base64-encoded image"
            )
        return Image.open(BytesIO(base64.b64decode(image_data)))