"""Abstract base class for models that turn a single image into embeddings."""

import os
import sys

# The project root (two directories up) must be on sys.path *before* the
# ``utils`` imports below, so this statement has to stay ahead of them.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time  # noqa: E402
from abc import ABC, abstractmethod  # noqa: E402

import numpy as np  # noqa: E402
import torch  # noqa: E402

from utils import configs  # noqa: E402
from utils.functional import (  # noqa: E402
    check_data_type_variable,
    get_device,
    image_augmentations,
)


class BaseModelImageEmbeddings(ABC):
    """Common state, argument validation and inference for embedding models.

    Subclasses implement :meth:`init_model`. ``self.model`` starts as ``None``
    in the constructor and must be a callable before :meth:`get_embeddings`
    is used (presumably assigned by ``init_model`` -- confirm in subclasses).
    """

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store the configuration, pick a device and validate the arguments.

        Args:
            name_model: Model identifier; validated against
                ``configs.NAME_MODELS`` (``configs.CLIP_NAME_MODEL`` is
                looked up under the key ``"clip"``).
            freeze_model: Stored as-is; not used by this base class.
            pretrained_model: Stored as-is; not used by this base class.
            support_set_method: Must be one of
                ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or support set method is not
                supported (see :meth:`check_arguments`).
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()
        self.check_arguments()

    def check_arguments(self):
        """Validate the constructor arguments without modifying the instance.

        Raises:
            ValueError: If the (possibly aliased) model name is not a key of
                ``configs.NAME_MODELS`` or the support set method is not in
                ``configs.SUPPORT_SET_METHODS``.
        """
        # presumably raises on a type mismatch -- helper is defined elsewhere.
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # The CLIP model name is registered under the short key "clip".
        # A local variable is used for the lookup so that self.name_model is
        # never left overwritten when one of the checks below raises.
        lookup_name = self.name_model
        if lookup_name == configs.CLIP_NAME_MODEL:
            lookup_name = "clip"

        if lookup_name not in configs.NAME_MODELS:
            raise ValueError(f"Model {lookup_name} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Build the underlying model; must be implemented by subclasses."""
        pass

    def get_embeddings(self, image: np.ndarray) -> dict:
        """Run the model on one image and time the forward pass.

        Args:
            image: Image passed to the ``image_augmentations()`` pipeline as
                the ``image`` keyword argument.

        Returns:
            A dict with two keys:
                ``"embeddings"``: model output converted to a NumPy array.
                ``"inference_time"``: elapsed ``time.perf_counter()`` seconds
                around the ``self.model(...)`` call.
        """
        augmented = image_augmentations()(image=image)["image"]
        # Add a leading dimension of size 1 and move the input to the device.
        batch = augmented.unsqueeze(axis=0).to(self.device)

        with torch.no_grad():
            start_time = time.perf_counter()
            embeddings = self.model(batch)
            # NOTE(review): no device synchronisation happens before the timer
            # stops; on asynchronous backends (e.g. CUDA) this may under-report
            # the real inference time -- confirm whether that matters here.
            elapsed = time.perf_counter() - start_time

        return {
            "embeddings": embeddings.detach().cpu().numpy(),
            "inference_time": elapsed,
        }