"""Custom ``transformers`` model wrapping a facenet-pytorch InceptionResnetV1."""

from facenet_pytorch import InceptionResnetV1
from transformers import PreTrainedModel

from .deepfakeconfig import DeepFakeConfig


class DeepFakeModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around an ``InceptionResnetV1`` backbone.

    The wrapped network is stored on ``self.model``.

    NOTE(review): no ``forward`` method is defined in this class as shown;
    presumably callers invoke ``self.model`` directly -- confirm.
    """

    # Configuration class associated with this model.
    config_class = DeepFakeConfig

    def __init__(self, config):
        """Build the backbone described by ``config``.

        Args:
            config: Configuration object; its ``DEVICE`` attribute is passed
                to the backbone as the ``device`` argument.
        """
        super().__init__(config)

        # Backbone is created in classification mode with a single output
        # class, requesting the "vggface2" pretrained weights.
        backbone = InceptionResnetV1(
            pretrained="vggface2",
            classify=True,
            num_classes=1,
            device=config.DEVICE,
        )
        self.model = backbone


# Register the custom config and model with the transformers auto-class
# machinery; the model is registered under AutoModelForImageClassification.
DeepFakeConfig.register_for_auto_class()
DeepFakeModel.register_for_auto_class("AutoModelForImageClassification")