File size: 558 Bytes
d30298c 402b016 d30298c 402b016 d30298c 402b016 d30298c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from facenet_pytorch import InceptionResnetV1
from transformers import PreTrainedModel
from .deepfakeconfig import DeepFakeConfig
class DeepFakeModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around facenet-pytorch's
    ``InceptionResnetV1``.

    The backbone is built with VGGFace2-pretrained weights. Its classification
    head is sized to a single output (``classify=True``, ``num_classes=1``), so
    the model emits one logit per input image.

    Args:
        config: A ``DeepFakeConfig``; its ``DEVICE`` attribute is forwarded to
            ``InceptionResnetV1`` as the device to build the network on.
    """

    config_class = DeepFakeConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): pretrained="vggface2" makes facenet-pytorch load (and,
        # if not cached, download) the VGGFace2 checkpoint every time this
        # class is constructed -- presumably also inside from_pretrained,
        # where the saved state dict then replaces those weights. Confirm
        # this is acceptable.
        self.model = InceptionResnetV1(
            pretrained="vggface2",
            classify=True,
            num_classes=1,
            device=config.DEVICE,
        )

    def forward(self, pixel_values):
        """Run the wrapped backbone and return its raw output.

        ``PreTrainedModel`` (a ``torch.nn.Module``) provides no ``forward`` of
        its own. Without this method, calling ``model(x)`` raised
        ``NotImplementedError``. Accessing ``model.model`` directly still
        works as before.

        Args:
            pixel_values: Image batch passed straight to ``InceptionResnetV1``.
                Assumed to be a float tensor of shape (batch, 3, H, W);
                facenet-pytorch normally uses 160x160 face crops.
                TODO confirm the expected preprocessing.

        Returns:
            The backbone's output. With ``classify=True`` and ``num_classes=1``
            this is a (batch, 1) tensor of unnormalized logits.
        """
        return self.model(pixel_values)
# Register both custom classes with the transformers Auto* factories: the
# config under AutoConfig (the method's default, spelled out here) and the
# model under AutoModelForImageClassification -- presumably so the Hub repo
# can be loaded with trust_remote_code=True.
DeepFakeConfig.register_for_auto_class("AutoConfig")
DeepFakeModel.register_for_auto_class(auto_class="AutoModelForImageClassification")
|