File size: 2,537 Bytes
ad5c7cd
 
 
 
 
 
5cf4e2d
ad5c7cd
 
 
 
 
641e847
ad5c7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cf4e2d
 
ad5c7cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os.path
import re
import torch
import time
import tempfile

import streamlit as st
from training.zoo.classifiers import DeepFakeClassifier
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set


def load_model():
    """Build the classifier and load its weights from ``weights/best.pth``.

    The checkpoint is deserialized onto the CPU. When it contains a
    ``"state_dict"`` entry that entry is used; otherwise the checkpoint
    object itself is treated as the state dict. A literal leading
    ``module.`` is removed from every parameter name before loading
    (presumably left behind by a parallel-training wrapper -- confirm
    against the training code).

    Returns:
        The ``DeepFakeClassifier`` instance, switched to evaluation mode.
    """
    path = 'weights/best.pth'
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # The dot is escaped so that only the exact "module." prefix is removed;
    # an unescaped "." would also strip prefixes such as "module_" or
    # "modules" and silently corrupt those key names.
    cleaned_state = {
        re.sub(r"^module\.", "", key): value
        for key, value in state_dict.items()
    }
    # strict=True: every key must match the model's own parameter names.
    model.load_state_dict(cleaned_state, strict=True)
    model.eval()
    # Drop the reference to the raw checkpoint before returning.
    del checkpoint
    return model


def write_bytesio_to_file(filename, bytesio):
    """Write the full contents of an in-memory bytes buffer to *filename*.

    The target is opened in binary write mode, so an existing file is
    overwritten. ``bytesio`` must provide ``getbuffer()``.
    """
    with open(filename, "wb") as sink:
        view = bytesio.getbuffer()
        sink.write(view)


def load_video():
    """Show a file-upload widget and save the upload to a temporary file.

    Returns:
        The path of the temporary file holding the uploaded bytes, or
        ``None`` when no file has been uploaded.

    The temporary file is created with ``delete=False`` and is not removed
    here; cleaning it up is left to the caller.
    """
    uploaded_file = st.file_uploader(label='Pick a video (mp4) file to test')
    if uploaded_file is None:
        return None

    # The with-block closes the handle, and so flushes every buffered byte,
    # before the path is handed out. delete=False keeps the file on disk
    # after closing so it can be reopened by name.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(uploaded_file.getvalue())
    return tfile.name


def inference(model, test_video):
    """Run the classifier on one video and write the prediction to the page.

    Args:
        model: The model handed to ``predict_on_video_set`` (as a
            one-element list).
        test_video: Path of the video to analyse.

    Nothing is returned; the first prediction is shown via ``st.write``.
    """
    frames_per_video = 32
    reader = VideoReader()

    def read_sampled_frames(video_path):
        # Forward to the reader with the per-video frame count fixed above.
        return reader.read_frames(video_path, num_frames=frames_per_video)

    extractor = FaceExtractor(read_sampled_frames)

    video_batch = [test_video]
    print("Predicting {} videos".format(len(video_batch)))
    results = predict_on_video_set(
        face_extractor=extractor,
        input_size=380,
        models=[model],
        strategy=confident_strategy,
        frames_per_video=frames_per_video,
        videos=video_batch,
        num_workers=6,
        test_dir="test_video",
    )
    st.write("Prediction: ", results[0])


def main():
    """Render the demo page: title, upload widget, preview, and run button.

    The model is loaded on every call. Once an uploaded video exists on
    disk it is previewed, and pressing the button runs inference and
    reports the elapsed wall-clock time.
    """
    st.title('Deepfake video inference demo')
    model = load_model()
    video_data_path = load_video()

    # Nothing more to show until an upload has been saved to disk.
    if video_data_path is None or not os.path.exists(video_data_path):
        return

    st.video(video_data_path)

    if not st.button('Run on video'):
        return

    st.write("Inference on video...")
    started_at = time.time()
    inference(model, video_data_path)
    st.write("Elapsed time: ", time.time() - started_at, " seconds")


# Start the demo only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()