|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import AutoImageProcessor, Dinov2Model |
|
|
from PIL import Image |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import imagehash |
|
|
|
|
|
# Hugging Face checkpoint used for both the image processor and the backbone.
MODEL_NAME = "facebook/dinov2-small"

# Preprocessing pipeline matching the checkpoint.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Report the processor's resize / crop settings once at import time so the
# effective preprocessing is visible in the service logs.
print(
    f"Resize Strategy: {processor.size}",
    f"Do Center Crop?: {processor.do_center_crop}",
    f"Crop Size: {processor.crop_size}",
    sep="\n",
)

# Backbone whose last hidden state is used to build the embeddings.
model = Dinov2Model.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
# ASGI application object; the keyword arguments are the metadata passed to
# FastAPI for this service (title, description, version).
app = FastAPI(
    version="1.0.0",
    title="API de Embedding de Imagem",
    description="Endpoint para obter o embedding e pHash de uma imagem.",
)
|
|
|
|
|
# Request body accepted by the /embed endpoint.
class ImageRequest(BaseModel):
    # Base64-encoded image; the handler splits it on the first comma, i.e. it
    # is read as a data URL of the form "<header>,<base64 payload>".
    image: str
    # Requested embedding size; the handler accepts only 384 or 512.
    target_dim: int = 512
    # When true, the handler casts the embedding to float16 before returning it.
    use_float16: bool = False
|
|
|
|
|
# Cache of projection layers keyed by (in_features, out_features) so that the
# same weights are reused for every request.
_PROJECTION_CACHE = {}


def _get_projection(in_features, out_features):
    """Return a cached linear projection with deterministically seeded weights.

    The layer is created once per (in_features, out_features) pair and
    initialised from a fixed-seed generator, so the same input always maps to
    the same output. The weights are not learned: this is a fixed random
    linear map.
    """
    key = (in_features, out_features)
    layer = _PROJECTION_CACHE.get(key)
    if layer is None:
        layer = nn.Linear(in_features, out_features)
        # A private generator keeps the global RNG state untouched.
        generator = torch.Generator().manual_seed(0)
        bound = in_features ** -0.5
        with torch.no_grad():
            layer.weight.uniform_(-bound, bound, generator=generator)
            layer.bias.uniform_(-bound, bound, generator=generator)
        layer.requires_grad_(False)
        layer.eval()
        _PROJECTION_CACHE[key] = layer
    return layer


@app.post("/embed")
async def get_embedding(request: ImageRequest):
    """Compute an embedding and a perceptual hash for a base64-encoded image.

    Args:
        request: Body holding the image (a data URL or a bare base64 string),
            the desired embedding size (384 or 512) and whether to cast the
            embedding to float16.

    Returns:
        A dict with "embedding" (list of floats) and "phash" (string form of
        the image's perceptual hash).

    Raises:
        HTTPException: 400 when ``target_dim`` is not 384 or 512, or when the
            image cannot be decoded or processed.
    """
    # Validate up front, outside the try block, so this error is not caught and
    # re-wrapped by the generic handler below.
    if request.target_dim not in (384, 512):
        raise HTTPException(
            status_code=400,
            detail="Dimensão inválida. Escolha entre 384 ou 512.",
        )

    try:
        # Accept both "data:<mime>;base64,<payload>" and a bare payload.
        prefix, separator, payload = request.image.partition(",")
        img_base64 = payload if separator else prefix

        image_data = base64.b64decode(img_base64)
        image = Image.open(BytesIO(image_data)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            # Take the first token of the last hidden state as the embedding.
            embedding = outputs.last_hidden_state[:, 0]

            if request.target_dim == 384:
                final_embedding = embedding
            else:
                # Reuse one fixed projection so 512-d embeddings stay
                # comparable from one request to the next.
                projection = _get_projection(model.config.hidden_size, 512)
                final_embedding = projection(embedding)

        if request.use_float16:
            final_embedding = final_embedding.half()

        phash = str(imagehash.phash(image))

        return {
            "embedding": final_embedding.squeeze().tolist(),
            "phash": phash,
        }
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Erro ao processar a imagem: {e}",
        ) from e