Spaces:
Sleeping
Sleeping
import random | |
import torch | |
from .feature_extractor import FeatureExtractor | |
from .storage import ImgFeatureStorage | |
from .image import random_crop | |
# Default number of random crops evaluated per query image.
DEF_N_CROPS = 10


class CroppingImgRetriever:
    """Find the stored image that best matches a random crop of a query image.

    Several random crops of the query are taken; each crop is embedded with
    the feature extractor and compared (cosine similarity) against every
    feature vector held by the storage. The best-scoring crop wins.
    """

    def __init__(self, extractor, storage, n_crops=DEF_N_CROPS):
        """Store the collaborators used during retrieval.

        Args:
            extractor: Object exposing ``predict(img)`` that returns the
                feature tensor of an image.
            storage: Object exposing ``get_features_tensor()`` and
                ``ix(index)``.
            n_crops: Number of random crops tried by :meth:`process`.
        """
        self.extractor = extractor
        self.storage = storage
        self.n_crops = n_crops

    def process(self, input_img):
        """Return the best match over ``n_crops`` random crops of ``input_img``.

        Args:
            input_img: Query image, passed as-is to ``random_crop``.

        Returns:
            A tuple ``(most_similar_img, crop_coords)``: the stored image with
            the highest similarity and the coordinates of the crop that
            produced it. Both are ``None`` when ``n_crops`` is not positive.
        """
        # Cosine similarity can legitimately equal -1, so start strictly
        # below it; otherwise a crop scoring exactly -1 could never be chosen.
        max_similarity = float("-inf")
        most_similar_img = None
        input_img_crop_coords = None

        for _ in range(self.n_crops):
            cropped_img, crop_coords = random_crop(input_img)
            similar_img, similarity = self.get_similar_img(cropped_img)
            if similarity > max_similarity:
                max_similarity = similarity
                most_similar_img = similar_img
                input_img_crop_coords = crop_coords

        return most_similar_img, input_img_crop_coords

    def get_similar_img(self, input_img):
        """Return the stored image closest to ``input_img`` and its score.

        Args:
            input_img: Image passed to ``extractor.predict``.

        Returns:
            A tuple ``(similar_img, similarity)`` where ``similar_img`` is
            ``storage.ix(index)`` for the best-scoring index and
            ``similarity`` is the corresponding entry of the similarity
            tensor.
        """
        input_features = self.extractor.predict(input_img)
        storage_features = self.storage.get_features_tensor()
        # cosine_similarity reduces over dim=1 by default.
        # NOTE(review): assumes storage features are (n_images, feature_dim)
        # and the input features broadcast against them - confirm.
        similarities = torch.nn.functional.cosine_similarity(
            input_features, storage_features
        )
        most_similar_index = torch.argmax(similarities)
        similar_img = self.storage.ix(most_similar_index)
        return similar_img, similarities[most_similar_index]

    @classmethod
    def from_filenames(cls, extractor_filename, storage_filename, n_crops=DEF_N_CROPS):
        """Build a retriever from a saved extractor and a saved storage.

        Args:
            extractor_filename: Passed to ``FeatureExtractor.load``.
            storage_filename: Passed to ``ImgFeatureStorage.load``.
            n_crops: Number of random crops tried by :meth:`process`.

        Returns:
            A new instance of ``cls``.
        """
        extractor = FeatureExtractor.load(extractor_filename)
        storage = ImgFeatureStorage.load(storage_filename)
        return cls(extractor, storage, n_crops=n_crops)