import base64
import io
from typing import Any, Dict, List

import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification


def get_inference_transform(processor: AutoImageProcessor, size: int):
    """Get the raw validation transform for direct inference on PIL images."""
    normalize = T.Normalize(mean=processor.image_mean, std=processor.image_std)

    to_rgb = T.Lambda(lambda img: img.convert("RGB"))

    def pad_to_square(img):
        # Pad the shorter side symmetrically with black so the image becomes
        # square before resizing; an odd difference puts the extra pixel on
        # the right/bottom edge.
        w, h = img.size
        max_size = max(w, h)
        pad_w = (max_size - w) // 2
        pad_h = (max_size - h) // 2
        padding = (pad_w, pad_h, max_size - w - pad_w, max_size - h - pad_h)
        return T.Pad(padding, fill=0)(img)

    return T.Compose([
        to_rgb,
        pad_to_square,
        T.Resize(size),  # input is square at this point, so this yields size x size
        T.ToTensor(),
        normalize,
    ])

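
# Example usage for the transform (a minimal sketch, assuming a checkpoint in
# the current directory and a hypothetical local file "sample.jpg"):
#
#     processor = AutoImageProcessor.from_pretrained(".")
#     transform = get_inference_transform(processor, 224)
#     pixel_values = transform(Image.open("sample.jpg")).unsqueeze(0)  # 1xCxHxW
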
class EndpointHandler:
    """
    HF Inference Endpoints entry-point.

    Loads the model/processor once at startup, then runs the *imported*
    preprocessing on every request.
    """

    def __init__(self, path: str = "", image_size: int = 224):
        # `path` is the checkpoint directory the endpoint mounts; fall back
        # to the current directory when it is empty.
        self.processor = AutoImageProcessor.from_pretrained(path or ".")
        self.model = Dinov2ForImageClassification.from_pretrained(path or ".").eval()
        self.transform = get_inference_transform(self.processor, image_size)
        self.id2label = self.model.config.id2label

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        Expects {"inputs": "<base64-encoded image>"}; raw bytes and PIL
        images are also accepted.
        Returns the top-k (k <= 5) predictions with per-class scores.
        """
        image = None

        if isinstance(data, (bytes, bytearray)):
            # Raw request body, e.g. Content-Type: image/jpeg.
            img_bytes = data
        elif isinstance(data, dict) and "inputs" in data:
            inp = data["inputs"]
            if isinstance(inp, str):
                # Base64 string, optionally with a data-URI prefix
                # ("data:image/png;base64,...").
                img_bytes = base64.b64decode(inp.split(",")[-1])
            elif isinstance(inp, (bytes, bytearray)):
                img_bytes = inp
            elif hasattr(inp, "convert"):
                # Already a PIL image.
                image = inp
            else:
                raise ValueError("Unsupported 'inputs' format")
        else:
            raise ValueError("Unsupported request body type")

        if image is None:
            image = Image.open(io.BytesIO(img_bytes))

        pixel_values = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            logits = self.model(pixel_values).logits[0]
            probs = logits.softmax(dim=-1)

        k = min(5, probs.numel())
        topk = torch.topk(probs, k)

        return [
            {"label": self.id2label[idx.item()], "score": prob.item()}
            for prob, idx in zip(topk.values, topk.indices)
        ]
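

# Minimal local smoke test (a sketch, not part of the endpoint contract).
# Assumes a checkpoint in the current directory and a hypothetical image
# file "sample.jpg"; adjust both paths for your setup.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    with open("sample.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}
    print(handler(payload))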