isitreallyworthit / iirwi /retrieval.py
danibalcells's picture
Clean up
1b59f08
raw
history blame
1.63 kB
import random
import torch
from .feature_extractor import FeatureExtractor
from .storage import ImgFeatureStorage
from .image import random_crop
DEF_N_CROPS = 10
class CroppingImgRetriever:
def __init__(self, extractor, storage, n_crops=DEF_N_CROPS):
self.extractor = extractor
self.storage = storage
self.n_crops = n_crops
def process(self, input_img):
max_similarity = -1
most_similar_img = None
input_img_crop_coords = None
for i in range(self.n_crops):
cropped_img, crop_coords = random_crop(input_img)
similar_img, similarity = self.get_similar_img(cropped_img)
if similarity > max_similarity:
max_similarity = similarity
most_similar_img = similar_img
input_img_crop_coords = crop_coords
return most_similar_img, input_img_crop_coords
def get_similar_img(self, input_img):
input_features = self.extractor.predict(input_img)
storage_features = self.storage.get_features_tensor()
similarities = torch.nn.functional.cosine_similarity(input_features, storage_features)
most_similar_index = torch.argmax(similarities)
similar_img = self.storage.ix(most_similar_index)
return similar_img, similarities[most_similar_index]
@classmethod
def from_filenames(cls, extractor_filename, storage_filename, n_crops=DEF_N_CROPS):
extractor = FeatureExtractor.load(extractor_filename)
storage = ImgFeatureStorage.load(storage_filename)
return cls(extractor, storage, n_crops=n_crops)