from transformers import PreTrainedModel
from transformers import PretrainedConfig
from omegaconf import OmegaConf
from models import get_model
import yaml


# Both paths are relative, so they resolve against the current working
# directory at call time, not against this file's location.
MODEL_CONF_PATH = 'pretrained_model/model.yaml'
MODEL_WEIGHTS_PATH = 'pretrained_model/model.pt'


class ModelConfig(PretrainedConfig):
    """Hugging Face config that carries the model's YAML configuration.

    The configuration mapping is stored on ``self.conf`` as a plain ``dict``.
    """

    def __init__(self, conf=None, **kwargs):
        """Build the config.

        Args:
            conf: Optional mapping holding the model configuration. When it is
                ``None`` (the default, and the only behavior of earlier
                versions), the mapping is read from ``MODEL_CONF_PATH``.
                Passing it explicitly lets a ``conf`` restored from a saved
                config dict survive instead of being overwritten by the file.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        Raises:
            OSError: If ``conf`` is None and the YAML file cannot be opened.
            TypeError / ValueError: If the loaded/provided value cannot be
                converted with ``dict(...)`` (e.g. an empty YAML file yields
                ``None``).
        """
        super().__init__(**kwargs)
        if conf is None:
            # Use a context manager so the file handle is closed promptly
            # rather than leaked until garbage collection.
            with open(MODEL_CONF_PATH, encoding='utf-8') as conf_file:
                conf = yaml.safe_load(conf_file)
        self.conf = dict(conf)


class CVLFaceRecognitionModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the network built by ``get_model``."""

    config_class = ModelConfig

    def __init__(self, cfg):
        """Build the wrapped network from ``cfg.conf`` and load its weights.

        Args:
            cfg: A ``ModelConfig``; its ``conf`` dict is converted to an
                OmegaConf object and handed to ``get_model``.
        """
        super().__init__(cfg)
        model_conf = OmegaConf.create(cfg.conf)
        self.model = get_model(model_conf)
        # NOTE(review): load_state_dict_from_path is defined on whatever
        # get_model returns (not visible here); presumably it loads the
        # weights stored at the given path — confirm in `models`.
        self.model.load_state_dict_from_path(MODEL_WEIGHTS_PATH)

    def forward(self, *args, **kwargs):
        """Pass all arguments straight through to the wrapped network."""
        return self.model(*args, **kwargs)