|
from transformers import PreTrainedModel |
|
from transformers import PretrainedConfig |
|
from omegaconf import OmegaConf |
|
from models import get_model |
|
import yaml |
|
|
|
class ModelConfig(PretrainedConfig):
    """Hugging Face config whose ``conf`` attribute is read from a YAML file.

    ``conf`` is loaded from ``pretrained_model/model.yaml``. That path is
    relative, so it is resolved against the current working directory. The
    load happens after the base ``PretrainedConfig`` initialiser runs, so
    whatever ``conf`` value existed at that point is replaced by the file's
    contents.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """Initialise the base config, then load the model YAML into ``self.conf``.

        Args:
            **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.

        Raises:
            FileNotFoundError: if ``pretrained_model/model.yaml`` does not exist
                relative to the current working directory.
            TypeError: if the YAML document cannot be turned into a dict
                (e.g. an empty file, for which ``safe_load`` returns ``None``).
        """
        super().__init__(**kwargs)
        # Context manager closes the handle deterministically instead of
        # leaking it until garbage collection (ResourceWarning).
        with open('pretrained_model/model.yaml', encoding='utf-8') as yaml_file:
            # safe_load only builds plain Python objects; dict() gives a
            # shallow copy that this config instance owns.
            self.conf = dict(yaml.safe_load(yaml_file))
|
|
|
|
|
class CVLFaceRecognitionModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a network built by ``get_model``.

    The inner network is built from the plain-dict configuration stored in
    ``cfg.conf``. Its weights are loaded from ``pretrained_model/model.pt``,
    a relative path resolved against the current working directory.
    """

    config_class = ModelConfig

    def __init__(self, cfg):
        """Build the wrapped network and load its pretrained weights.

        Args:
            cfg: a ``ModelConfig`` (or compatible object) whose ``conf``
                attribute holds the model configuration mapping.
        """
        super().__init__(cfg)
        # Wrap the plain mapping in an OmegaConf container before it is
        # handed to get_model.
        self.model = get_model(OmegaConf.create(cfg.conf))
        self.model.load_state_dict_from_path('pretrained_model/model.pt')

    def forward(self, *args, **kwargs):
        """Pass all positional and keyword arguments straight to the wrapped network."""
        # Call the module itself (not .forward) so any registered hooks still run.
        return self.model(*args, **kwargs)
|
|
|
|
|
|
|
|