Spaces:
Running
Running
import os | |
from io import BytesIO | |
import numpy as np | |
import torch | |
from PIL import Image | |
from img2art_search.data.dataset import ImageRetrievalDataset | |
from img2art_search.data.transforms import transform | |
from img2art_search.models.compute_embeddings import search_image | |
def predict(img: Image.Image) -> list:
    """Search the image index for entries similar to a query image.

    Args:
        img: Query image, or ``None`` when no image was supplied.

    Returns:
        A list of ``(PIL.Image.Image, str)`` tuples, one per search hit, in the
        order returned by ``search_image``. The string is ``str(distance)``.
        Returns an empty list when ``img`` is ``None``.
    """
    if img is None:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Batch of one: add a leading dimension before moving to the device.
    query = transform(img.convert("RGB")).unsqueeze(0).to(device)

    # The first value returned by search_image is handed to BytesIO below, so
    # it must yield bytes-like encoded images (not integer indices).
    encoded_images, distances = search_image(query)

    return [
        (_decode_image(data), str(distance))
        for data, distance in zip(encoded_images, distances)
    ]


def _decode_image(data) -> Image.Image:
    """Decode encoded image bytes into a fully loaded, detached PIL image."""
    # copy() reads the pixel data and returns an image independent of the
    # source, so closing the source on exit from the ``with`` is safe. Unlike a
    # np.array -> Image.fromarray round trip, it keeps the original mode: palette
    # images keep their palette instead of becoming grayscale indices, and CMYK
    # is not reinterpreted as RGBA.
    with Image.open(BytesIO(data)) as image:
        return image.copy()