chicelli's picture
Upload predict.py
ed54141 verified
raw
history blame contribute delete
969 Bytes
import os
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.compute_embeddings import search_image
def predict(img: Image.Image) -> list:
    """Search the image index for entries similar to a query image.

    The query is converted to RGB, passed through the project's ``transform``,
    given a leading batch dimension, moved to CUDA when available (CPU
    otherwise) and handed to ``search_image``.

    Args:
        img: Query image. ``None`` yields an empty list.

    Returns:
        A list of ``(image, distance)`` tuples, one per hit returned by
        ``search_image``. ``image`` is a ``PIL.Image.Image`` decoded from the
        bytes returned for that hit; ``distance`` is the hit's distance
        converted with ``str``.
    """
    if img is None:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    query = transform(img.convert("RGB")).unsqueeze(0).to(device)

    # NOTE(review): despite its name, search_image's first return value is
    # treated here as a sequence of encoded image bytes (each is wrapped in
    # BytesIO below), not as integer indices.
    indices, distances = search_image(query)

    results = []
    for image_bytes, distance in zip(indices, distances):
        # Decode the bytes, copy the pixels into a NumPy array, then rebuild a
        # standalone image so the result does not depend on the opened file.
        with Image.open(BytesIO(image_bytes)) as encoded:
            pixels = np.array(encoded)
        results.append((Image.fromarray(pixels), str(distance)))
    return results