# dinov2-small / handler.py
# Yusif
# Upload folder using huggingface_hub
# 69cf133 verified
from typing import Dict, List, Any
import torch
from transformers import AutoModel, AutoImageProcessor
import base64
from PIL import Image
import io
class EndpointHandler:
    """Inference Endpoint handler that embeds base64-encoded images.

    Each request carries one or more base64-encoded images under the
    ``"inputs"`` key.  The handler returns one mean-pooled embedding per image
    and, when exactly two images are supplied, their cosine similarity.
    """

    def __init__(self, path: str = "facebook/dinov2-small"):
        """Load the model and its image processor.

        Args:
            path: Local directory or Hub model id passed to ``from_pretrained``.
        """
        # Use the GPU when one is present; otherwise run on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path).to(self.device)
        # Disable dropout and other training-only behaviour for inference.
        self.model.eval()
        self.processor = AutoImageProcessor.from_pretrained(path)

    @staticmethod
    def _extract_inputs(data: Dict[str, Any]) -> List[Any]:
        """Return the list of encoded images carried by the request payload.

        ``data["inputs"]`` may be a single string/bytes object or any iterable
        of them.  The payload dict is read, not mutated.

        Raises:
            ValueError: if ``"inputs"`` is missing or is not iterable.
        """
        if "inputs" not in data:
            raise ValueError('Request payload must contain an "inputs" field.')
        payload = data["inputs"]
        # A lone string would otherwise be iterated one character at a time.
        if isinstance(payload, (str, bytes)):
            return [payload]
        try:
            return list(payload)
        except TypeError as err:
            raise ValueError(
                '"inputs" must be a base64 string or a list of base64 strings.'
            ) from err

    @staticmethod
    def _strip_data_uri(encoded: Any) -> Any:
        """Drop a ``data:<mime>;base64,`` prefix from *encoded* if it has one."""
        if isinstance(encoded, str) and encoded.startswith("data:"):
            # Everything after the first comma is the base64 payload.
            return encoded.partition(",")[2]
        return encoded

    @staticmethod
    def _decode_image(encoded: Any) -> "Image.Image":
        """Decode one base64 string (optionally a data URI) into an RGB image."""
        raw = base64.b64decode(EndpointHandler._strip_data_uri(encoded))
        # convert() forces the pixel data to load and normalises palette,
        # greyscale and alpha images to three channels.
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Embed the request's images and, for a pair, compare them.

        Args:
            data: Request payload; ``data["inputs"]`` holds one base64 string
                or a list of them.

        Returns:
            A one-element list holding ``{"embeddings": [...]}`` (one vector
            per image).  When exactly two images are given, the dict also has
            ``"similarity"``, the cosine similarity of the two embeddings.

        Raises:
            ValueError: if the payload has no usable ``"inputs"`` field.
        """
        encoded_images = self._extract_inputs(data)
        if not encoded_images:
            # The image processor cannot build a batch from zero images.
            return [{"embeddings": []}]

        images = [self._decode_image(item) for item in encoded_images]
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Average over every sequence position of the final hidden layer to get
        # one global vector per image; move to CPU for serialisation.
        image_features = outputs.last_hidden_state.mean(dim=1).cpu()

        if len(images) == 2:
            similarity = torch.cosine_similarity(
                image_features[0], image_features[1], dim=0
            ).item()
            return [{"similarity": similarity, "embeddings": image_features.tolist()}]
        return [{"embeddings": image_features.tolist()}]