owlvit-base-patch32 / handler.py
Thomasboosinger's picture
Update handler.py
0606046 verified
raw
history blame contribute delete
No virus
1.25 kB
from transformers import pipeline
import torch
from PIL import Image
import base64
from io import BytesIO
class EndpointHandler:
    """Request handler wrapping a zero-shot object-detection pipeline.

    The pipeline is built once in ``__init__`` and reused by every
    ``__call__``, so the model is not reloaded per request.
    """

    # Score threshold forwarded to the pipeline when neither the constructor
    # nor the request supplies one (the value this handler originally
    # hard-coded).
    DEFAULT_THRESHOLD = 0.01

    def __init__(self, model_path="", threshold=DEFAULT_THRESHOLD):
        """Pick a compute device and load the detection pipeline.

        Args:
            model_path: Model identifier or directory forwarded to
                ``transformers.pipeline`` as ``model``.
            threshold: Default score threshold passed to the pipeline for
                requests that do not carry their own ``threshold``.
        """
        # Dynamically assign computing device based on availability.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("Using CPU")

        self.threshold = threshold

        # The pipeline API takes a device index: 0 selects the first GPU,
        # -1 selects the CPU.
        self.pipeline = pipeline(
            "zero-shot-object-detection",
            model=model_path,
            device=0 if self.device == "cuda" else -1,
        )

    @staticmethod
    def _decode_image_bytes(payload):
        """Return the raw bytes of a base64-encoded image payload.

        Accepts plain base64 as well as a ``data:<mime>;base64,<payload>``
        data URI. The header must be removed explicitly: ``b64decode`` is
        lenient by default and would otherwise decode the header's letters
        as extra leading bytes, corrupting the image.
        """
        if isinstance(payload, str) and payload.lstrip().startswith("data:"):
            _header, comma, encoded = payload.partition(",")
            if comma:
                payload = encoded
        return base64.b64decode(payload)

    def __call__(self, data):
        """Decode the image, run zero-shot object detection, return results.

        Args:
            data (dict): Request body. ``data["inputs"]["image"]`` is the
                base64-encoded image (optionally a data URI),
                ``data["inputs"]["candidates"]`` holds the candidate labels,
                and the optional ``data["inputs"]["threshold"]`` overrides
                the handler's default score threshold for this request.

        Returns:
            The pipeline's output, unmodified. Presumably a list of dicts
            holding a label and its score (and, per the transformers
            documentation, a bounding box) -- confirm against the installed
            transformers version.

        Raises:
            KeyError: If ``inputs``, ``image`` or ``candidates`` is missing.
        """
        inputs = data["inputs"]

        # Decode the base64 image to PIL format.
        image = Image.open(BytesIO(self._decode_image_bytes(inputs["image"])))

        # A missing or null per-request threshold falls back to the default.
        threshold = inputs.get("threshold")
        if threshold is None:
            threshold = self.threshold

        # Run detection and obtain results.
        return self.pipeline(
            image=image,
            candidate_labels=inputs["candidates"],
            threshold=threshold,
        )