# Handler for Hugging Face Inference Endpoints.
# Commit: Add handler, requirements, and update preprocessor config for
# Inference Endpoints (f0e6dfd, by k-lauren).
import base64
import io
from typing import Any, Dict, List
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints running
    single-image object detection.

    Accepts a base64-encoded image string (optionally with a
    ``data:image/...;base64,`` data-URI prefix), raw image bytes, or a
    PIL image, and returns a list of detections with score, label, and
    pixel-coordinate bounding box.
    """

    def __init__(self, path: str = ""):
        """Load the image processor and detection model.

        Args:
            path: Model repo id or local directory passed to
                ``from_pretrained``.
        """
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForObjectDetection.from_pretrained(path)
        # Use the GPU when the endpoint has one; otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run object detection on the request payload.

        Args:
            data: Request payload. The image is taken from
                ``data["inputs"]`` (or ``data`` itself when no ``inputs``
                key is present). ``data["parameters"]["threshold"]``
                optionally overrides the default score threshold of 0.5.

        Returns:
            A list of ``{"score", "label", "box"}`` dicts, where ``box``
            holds pixel coordinates ``xmin``/``ymin``/``xmax``/``ymax``.

        Raises:
            ValueError: If the input is not a supported image type.
        """
        if isinstance(data, dict):
            inputs = data.get("inputs", data)
            parameters = data.get("parameters") or {}
        else:
            inputs, parameters = data, {}
        threshold = float(parameters.get("threshold", 0.5))

        image = self._decode_image(inputs)

        # Run inference
        with torch.no_grad():
            encoded = self.processor(images=image, return_tensors="pt").to(self.device)
            outputs = self.model(**encoded)

        # Post-process: post_process_object_detection expects (height, width)
        # per image, but PIL's .size is (width, height) — hence the reversal.
        target_size = torch.tensor([image.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_size
        )[0]

        detections = []
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            xmin, ymin, xmax, ymax = box.tolist()
            detections.append(
                {
                    "score": round(score.item(), 4),
                    "label": self.model.config.id2label[label.item()],
                    "box": {
                        "xmin": round(xmin, 2),
                        "ymin": round(ymin, 2),
                        "xmax": round(xmax, 2),
                        "ymax": round(ymax, 2),
                    },
                }
            )
        return detections

    @staticmethod
    def _decode_image(inputs: Any) -> Image.Image:
        """Normalize a supported payload into an RGB PIL image.

        Args:
            inputs: Base64 string (plain or ``data:`` URI), raw bytes,
                or a PIL image.

        Raises:
            ValueError: If *inputs* is none of the supported types.
        """
        if isinstance(inputs, str):
            # Tolerate data-URI payloads ("data:image/png;base64,....")
            # as sent by browsers/JS clients: keep only the encoded part.
            stripped = inputs.strip()
            if stripped.startswith("data:") and "," in stripped:
                stripped = stripped.split(",", 1)[1]
            image_bytes = base64.b64decode(stripped)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if isinstance(inputs, bytes):
            return Image.open(io.BytesIO(inputs)).convert("RGB")
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        raise ValueError(
            "Unsupported input type. Provide a base64-encoded image string or raw bytes."
        )