import os
import sys

# Put the directory two levels above this file on the import path before the
# `utils` imports below run (presumably that is where `utils` lives -- confirm
# against the repository layout).
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from abc import ABC, abstractmethod

from utils import configs
from utils.functional import check_data_type_variable, get_device


class BaseModelMainModel(ABC):
    """Abstract base class holding the shared configuration of a model.

    The constructor stores the configuration values, stores the result of
    ``get_device()`` in ``self.device`` and then validates the configuration
    via :meth:`check_arguments`. Concrete subclasses must implement
    :meth:`init_model` and :meth:`predict`.
    """

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ) -> None:
        """Store the configuration and validate it.

        Args:
            name_model: Model name; must be ``configs.CLIP_NAME_MODEL`` or a
                key of ``configs.NAME_MODELS``.
            freeze_model: Flag stored as ``self.freeze_model``.
            pretrained_model: Flag stored as ``self.pretrained_model``.
            support_set_method: Must be a member of
                ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or support set method is not
                supported (see :meth:`check_arguments`).
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.device = get_device()
        self.check_arguments()

    def check_arguments(self) -> None:
        """Validate the stored configuration without modifying it.

        Each attribute is first passed to ``check_data_type_variable``
        together with its expected type (presumably that helper raises on a
        mismatch -- see ``utils.functional``). The model name and the support
        set method are then checked against the values known to ``configs``.

        Raises:
            ValueError: If the model name is not a key of
                ``configs.NAME_MODELS`` or the support set method is not in
                ``configs.SUPPORT_SET_METHODS``.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # ``configs.CLIP_NAME_MODEL`` is looked up under the alias "clip".
        # The alias is kept in a local variable so that ``self.name_model``
        # keeps the caller-supplied value even when validation fails.
        if self.name_model == configs.CLIP_NAME_MODEL:
            lookup_name = "clip"
        else:
            lookup_name = self.name_model

        if lookup_name not in configs.NAME_MODELS:
            raise ValueError(f"Model {lookup_name} not supported")

        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Initialise the underlying model; must be implemented by subclasses."""

    @abstractmethod
    def predict(self):
        """Run a prediction; must be implemented by subclasses."""