# owlvit-base-patch32 / handler.py
# Last commit: 0606046 (verified) by Thomasboosinger — "Update handler.py"
from transformers import pipeline
import torch
from PIL import Image
import base64
from io import BytesIO
class EndpointHandler:
    """Inference-endpoint handler wrapping a zero-shot object detection pipeline.

    The instance is constructed once with a model path and then called with a
    request payload for every inference request.
    """

    # Score cut-off applied when a request does not supply its own threshold
    # (this was the value previously hard-coded in ``__call__``).
    DEFAULT_THRESHOLD = 0.01

    def __init__(self, model_path=""):
        """Load the zero-shot-object-detection pipeline.

        Args:
            model_path (str): Model identifier or local directory passed as
                ``model=`` to ``transformers.pipeline``.
        """
        # Dynamically assign computing device based on availability.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {'GPU: ' + torch.cuda.get_device_name(0) if self.device == 'cuda' else 'CPU'}")
        # transformers.pipeline takes a device index: 0 = first GPU, -1 = CPU.
        self.pipeline = pipeline(
            "zero-shot-object-detection",
            model=model_path,
            device=0 if self.device == "cuda" else -1,
        )

    def __call__(self, data):
        """Decode an image, run zero-shot object detection, and return results.

        Args:
            data (dict): Request payload. ``data["inputs"]["image"]`` is a
                base64-encoded image and ``data["inputs"]["candidates"]`` holds
                the candidate labels. An optional
                ``data["parameters"]["threshold"]`` overrides
                ``DEFAULT_THRESHOLD`` for this request.

        Returns:
            list[dict]: The detections returned by the pipeline; per the
            transformers documentation each dict has ``score``, ``label`` and
            ``box`` entries.

        Raises:
            KeyError: If ``inputs``, ``image`` or ``candidates`` is missing.
        """
        inputs = data["inputs"]
        # Decode the base64 image to PIL format.
        image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        # Run detection and obtain results.
        return self.pipeline(
            image=image,
            candidate_labels=inputs["candidates"],
            threshold=self._resolve_threshold(data),
        )

    @classmethod
    def _resolve_threshold(cls, data):
        """Return the score threshold to use for the request ``data``.

        Uses ``data["parameters"]["threshold"]`` (converted to ``float``) when
        present. Falls back to ``DEFAULT_THRESHOLD`` when the payload has no
        ``"parameters"`` mapping (absent or ``None``) or that mapping has no
        ``"threshold"`` value.
        """
        parameters = data.get("parameters") or {}
        threshold = parameters.get("threshold")
        if threshold is None:
            return cls.DEFAULT_THRESHOLD
        return float(threshold)