import os
import sys
import time
from abc import ABC, abstractmethod

# Make the directory two levels above this file importable so the
# project-local `utils` package imported below can be resolved.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import numpy as np  # noqa: E402
import torch  # noqa: E402

from utils import configs  # noqa: E402
from utils.functional import (  # noqa: E402
    check_data_type_variable,
    euclidean_distance_normalized,
    get_device,
    image_augmentations,
)


class BaseModelImageSimilarity(ABC):
    """Abstract base class for models that compare two images.

    Subclasses implement :meth:`init_model` and are expected to assign a
    callable to ``self.model`` (presumably a ``torch.nn.Module`` that maps a
    batched image tensor to an embedding — confirm in subclasses).
    :meth:`get_similarity` then runs both images through that model and
    compares the outputs with ``euclidean_distance_normalized``.
    """

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ) -> None:
        """Store the configuration and validate it.

        Args:
            name_model: Model name; must map to a key of
                ``configs.NAME_MODELS`` (the CLIP name maps to ``"clip"``).
            freeze_model: Stored for subclasses; not used in this class.
            pretrained_model: Stored for subclasses; not used in this class.
            support_set_method: Must be in ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or support-set method is unknown.
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        # Populated by the subclass's `init_model`.
        self.model = None
        self.device = get_device()
        self.check_arguments()

    def _config_key(self) -> str:
        """Return the key under which this model is looked up in ``configs.NAME_MODELS``.

        The CLIP model's public name (``configs.CLIP_NAME_MODEL``) is looked
        up under the short key ``"clip"``; every other name is used unchanged.
        Both validation and threshold lookup go through this helper so they
        always agree.
        """
        if self.name_model == configs.CLIP_NAME_MODEL:
            return "clip"
        return self.name_model

    def check_arguments(self):
        """Validate constructor arguments' types and supported values.

        Raises:
            ValueError: If the model name has no entry in
                ``configs.NAME_MODELS`` or the support-set method is not in
                ``configs.SUPPORT_SET_METHODS``. Type errors are reported by
                ``check_data_type_variable``.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # Use a local key instead of temporarily mutating `self.name_model`.
        if self._config_key() not in configs.NAME_MODELS:
            raise ValueError(f"Model {self.name_model} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Build the underlying model; subclasses should assign ``self.model``."""

    def _synchronize_device(self) -> None:
        """Wait for queued CUDA work on ``self.device`` to finish.

        CUDA kernels are launched asynchronously, so without this the wall
        clock timer would only capture launch overhead. Does nothing when the
        device is not a CUDA device.
        """
        if str(self.device).startswith("cuda"):
            torch.cuda.synchronize(self.device)

    def get_similarity(self, image1: np.ndarray, image2: np.ndarray) -> dict:
        """Embed two images with ``self.model`` and compare the embeddings.

        Args:
            image1: First image as passed to ``image_augmentations()``.
            image2: Second image, same format as ``image1``.

        Returns:
            A dict with:
                - ``"similarity"``: value from ``euclidean_distance_normalized``.
                - ``"result_similarity"``: ``"same image"`` if that value is
                  above the model's ``image_similarity_threshold``, else
                  ``"not same image"``.
                - ``"inference_time"``: seconds spent running the model on
                  both images.

        Raises:
            RuntimeError: If ``init_model`` has not set ``self.model``.
        """
        if self.model is None:
            raise RuntimeError("Model is not initialized; call init_model() first")

        transform = image_augmentations()
        # Add a leading batch dimension of size 1 and move to the model device.
        image1_input = transform(image=image1)["image"].unsqueeze(0).to(self.device)
        image2_input = transform(image=image2)["image"].unsqueeze(0).to(self.device)

        with torch.no_grad():
            start_time = time.perf_counter()
            embedding1 = self.model(image1_input)
            embedding2 = self.model(image2_input)
            # Make sure GPU work is done before reading the clock.
            self._synchronize_device()
            inference_time = time.perf_counter() - start_time

        embedding1 = embedding1.detach().cpu().numpy()
        embedding2 = embedding2.detach().cpu().numpy()
        similarity = euclidean_distance_normalized(embedding1, embedding2)

        threshold = configs.NAME_MODELS[self._config_key()][
            "image_similarity_threshold"
        ]
        # NOTE(review): "same image" is declared when the score is ABOVE the
        # threshold, which assumes euclidean_distance_normalized returns a
        # similarity (higher = closer) rather than a raw distance — confirm.
        result_similarity = (
            "same image" if similarity > threshold else "not same image"
        )
        return {
            "similarity": similarity,
            "result_similarity": result_similarity,
            "inference_time": inference_time,
        }