File size: 558 Bytes
d30298c
 
402b016
 
d30298c
402b016
 
d30298c
402b016
 
 
 
 
 
 
 
 
 
 
d30298c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from facenet_pytorch import InceptionResnetV1
from transformers import PreTrainedModel
from .deepfakeconfig import DeepFakeConfig


class DeepFakeModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a facenet-pytorch ``InceptionResnetV1``.

    The wrapped network is created in classification mode with a single
    output unit, requesting the ``"vggface2"`` pretrained weights.
    """

    # Config class that transformers pairs with this model.
    config_class = DeepFakeConfig

    def __init__(self, config):
        """Build the wrapped network.

        Args:
            config: Model configuration. Its ``DEVICE`` attribute is forwarded
                to ``InceptionResnetV1`` as the target device.
        """
        super().__init__(config)
        self.model = InceptionResnetV1(
            pretrained="vggface2",
            classify=True,
            num_classes=1,
            device=config.DEVICE,
        )

    def forward(self, pixel_values):
        """Run the wrapped network on a batch of inputs.

        Without this override, calling the model falls through to the
        ``torch.nn.Module`` default ``forward``, which raises
        ``NotImplementedError``.

        Args:
            pixel_values: Input passed unchanged to the wrapped network.
                NOTE(review): presumably a float image tensor shaped
                (batch, 3, H, W) that is already preprocessed and on
                ``config.DEVICE`` -- confirm against the caller.

        Returns:
            Whatever the wrapped ``InceptionResnetV1`` returns. With
            ``classify=True`` and ``num_classes=1`` this is expected to be one
            raw (pre-sigmoid) score per input -- TODO confirm.
        """
        return self.model(pixel_values)


# Register both custom classes with the transformers auto-class mechanism:
# the config under the library's default auto class (no argument given), and
# the model explicitly under "AutoModelForImageClassification".
DeepFakeConfig.register_for_auto_class()
DeepFakeModel.register_for_auto_class("AutoModelForImageClassification")