# Hugging Face Space status captured together with this source: Runtime error
import io | |
import torch | |
from torch.utils.model_zoo import load_url | |
from PIL import Image | |
from scipy.special import expit | |
import matplotlib.pyplot as plt | |
import sys | |
sys.path.append('./icpr2020dfdc/') | |
from blazeface import FaceExtractor, BlazeFace, VideoReader | |
from architectures import fornet,weights | |
from isplutils import utils | |
import gradio as gr | |
""" | |
Choose an architecture between | |
- EfficientNetB4 | |
- EfficientNetB4ST | |
- EfficientNetAutoAttB4 | |
- EfficientNetAutoAttB4ST | |
- Xception | |
""" | |
net_model = 'EfficientNetAutoAttB4' | |
""" | |
Choose a training dataset between | |
- DFDC | |
- FFPP | |
""" | |
train_db = 'DFDC' | |
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') | |
face_policy = 'scale' | |
face_size = 224 | |
frames_per_video = 32 | |
# Classifier: build the selected architecture in eval mode on `device`, then
# download (hash-checked) and load the weights for the selected training set.
model_url = weights.weight_url[f'{net_model}_{train_db}']
net = getattr(fornet, net_model)().eval().to(device)
net.load_state_dict(load_url(model_url, map_location=device, check_hash=True))

# Inference-time preprocessing applied to every face crop, using the
# normalizer that matches the loaded network.
transf = utils.get_transformer(face_policy, face_size, net.get_normalizer(), train=False)

# BlazeFace face detector with its weights and anchor boxes loaded from the
# bundled repository checkout (paths are relative to the working directory).
facedet = BlazeFace().to(device)
facedet.load_weights("./icpr2020dfdc/blazeface/blazeface.pth")
facedet.load_anchors("./icpr2020dfdc/blazeface/anchors.npy")

videoreader = VideoReader(verbose=False)


def video_read_fn(x):
    """Read `frames_per_video` frames from video *x* via `videoreader`."""
    return videoreader.read_frames(x, num_frames=frames_per_video)


# Pipeline that samples frames with `video_read_fn` and detects faces with `facedet`.
face_extractor = FaceExtractor(video_read_fn=video_read_fn, facedet=facedet)
title = "Face Manipulation Detection Through Ensemble of CNNs" | |
def inference(vid):
    """Classify a video as real or fake from the faces detected in it.

    Faces are extracted with ``face_extractor``. For every frame that has at
    least one detection, the first face crop is preprocessed with ``transf``
    and scored by ``net``. Per-frame logits are mapped to [0, 1] with the
    logistic function (``expit``) and drawn as a stem plot. The video-level
    score is the logistic of the *mean per-frame logit*.

    Args:
        vid: The video received from the Gradio video input, passed straight
            to ``face_extractor.process_video`` (presumably a file path).

    Returns:
        Tuple ``(label_path, plot, score)``:
        the path of the "Fake" or "Real" label image (Fake when the
        video-level score is >= 0.5), a PIL image of the per-frame score
        plot, and the video-level score formatted as a percentage string
        such as ``"87.31%"``.

    Raises:
        gr.Error: If no face was detected in any sampled frame.
    """
    frames = face_extractor.process_video(vid)
    # Only frames with at least one detection can be scored.
    face_frames = [frame for frame in frames if len(frame['faces'])]
    if not face_frames:
        # torch.stack() fails on an empty list; surface a readable message instead.
        raise gr.Error("No face detected in the video.")

    faces_t = torch.stack([transf(image=frame['faces'][0])['image'] for frame in face_frames])
    with torch.no_grad():
        logits = net(faces_t.to(device)).cpu().numpy().flatten()
    frame_idxs = [frame['frame_idx'] for frame in face_frames]

    # Draw on a dedicated figure and always close it. Drawing on pyplot's
    # implicit current figure left it open between calls, so the stems of
    # every previous request accumulated on the returned plot.
    fig, ax = plt.subplots()
    try:
        # `use_line_collection` was removed in Matplotlib 3.8 (its behaviour
        # is now the only one), so it must not be passed anymore.
        ax.stem(frame_idxs, expit(logits))
        ax.set_title('Score per Frame')
        ax.set_xlabel('Frame')
        ax.set_ylabel('Score')
        ax.set_ylim([0, 1])
        ax.grid()
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png')
    finally:
        plt.close(fig)
    img_buf.seek(0)
    im = Image.open(img_buf)

    res = expit(logits.mean())
    label_path = "./Labels/Fake.jpg" if res >= 0.5 else "./Labels/Real.jpg"
    return label_path, im, f"{res*100:.2f}%"
# Web UI: one video input and three outputs (verdict label image, per-frame
# score plot, overall score text).
# `gr.inputs` / `gr.outputs` were removed in Gradio 4, so the top-level
# components are used; they also exist in Gradio 3.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.Video(format="mp4", label="In")],
    outputs=[
        gr.Image(type="pil", label="Label"),
        gr.Image(type="pil", label="Score per Frame"),
        gr.Label(label="Score"),
    ],
    title=title,
)
# Launch separately so `demo` refers to the Interface itself rather than to
# whatever `.launch()` returns.
demo.launch(debug=True)