Hamidreza-Hashemp's picture
Update app.py
38a364a
raw
history blame contribute delete
No virus
2.44 kB
import argparse
import os
import re
import time
from functools import lru_cache

import cv2
import torch
import pandas as pd
import gradio as gr

from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
from training.zoo.classifiers import DeepFakeClassifier
# Directory holding the model checkpoints used by the app.
WEIGHTS_DIR = "./weights"
# Checkpoint file names (inside WEIGHTS_DIR) selected by the UI option.
PRETRAINED_MODEL = "Original_DeepFakeClassifier_tf_efficientnet_b7_ns"
SCRATCH_MODEL = "Custom_classifier_DeepFakeClassifier_tf_efficientnet_b7_ns"
# Number of frames sampled per video and the face-crop size fed to the model.
FRAMES_PER_VIDEO = 32
INPUT_SIZE = 380


@lru_cache(maxsize=2)
def _load_model(path):
    """Build a DeepFakeClassifier, load the checkpoint at *path* on the CPU, and cache it.

    The checkpoint is read from disk once per path for the lifetime of the
    process instead of on every request. The cached model is put in eval mode
    and shared between calls.
    NOTE(review): assumes predict_on_video_set uses the models read-only —
    confirm in kernel_utils.
    """
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cpu")
    print("loading state dict {}".format(path))
    # torch.load unpickles the file: only ever point it at trusted local checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    # Accept either a {"state_dict": ...} wrapper or a bare state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip a literal leading "module." from parameter names (presumably left by a
    # DataParallel-wrapped training model); the dot is escaped so only "module." matches.
    model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=True)
    model.eval()
    del checkpoint
    return model.float()


def deepfakeclassifier(potential_test_video, option):
    """Classify a video and return a message with its predicted probability of being fake.

    Args:
        potential_test_video: Path of the video to classify, as handed over by the
            Gradio ``Video`` input (assumed to be a file-path string). ``None`` or an
            empty string means nothing was uploaded.
        option: ``"Pretrained"`` selects the original checkpoint; any other value
            (the UI also offers ``"Scratch"``) selects the custom checkpoint.

    Returns:
        ``"This video is FAKE with <p> probability!"``, where ``<p>`` is the first
        prediction returned by ``predict_on_video_set``. If no video was given, a
        prompt to upload one is returned instead.
    """
    if not potential_test_video:
        return "Please upload a video to classify."

    model_name = PRETRAINED_MODEL if option == "Pretrained" else SCRATCH_MODEL
    models = [_load_model(os.path.join(WEIGHTS_DIR, model_name))]

    # os.path.split respects the platform separator and copes with a bare file
    # name (empty directory part), which a manual split on "/" did not.
    test_dir, video_name = os.path.split(potential_test_video)
    test_videos = [video_name]

    video_reader = VideoReader()

    def video_read_fn(path):
        return video_reader.read_frames(path, num_frames=FRAMES_PER_VIDEO)

    face_extractor = FaceExtractor(video_read_fn)

    start = time.perf_counter()
    print("Predicting {} videos".format(len(test_videos)))
    predictions = predict_on_video_set(face_extractor=face_extractor, input_size=INPUT_SIZE, models=models,
                                       strategy=confident_strategy, frames_per_video=FRAMES_PER_VIDEO,
                                       videos=test_videos, num_workers=6,
                                       # Fall back to "." so a bare file name resolves against the CWD.
                                       test_dir=test_dir or os.curdir)
    print("Elapsed:", time.perf_counter() - start)
    return "This video is FAKE with {} probability!".format(predictions[0])
# Gradio UI: a video upload plus a radio button that picks which checkpoint to use.
demo = gr.Interface(
    fn=deepfakeclassifier,
    inputs=[
        gr.Video(),
        # Default to "Pretrained" so the classifier never receives option=None,
        # which would otherwise silently fall through to the Scratch model.
        gr.Radio(["Pretrained", "Scratch"], value="Pretrained"),
    ],
    outputs="text",
    description=(
        "The Pretrained option is trained on top of the winning solution. "
        "Scratch is my own model, trained from scratch. "
        "Pretrained performs better because it was trained on much more data!"
    ),
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch(debug=True)