"""Streamlit demo: upload a video and run the deepfake classifier on it."""

import os.path
import re
import tempfile
import time

import streamlit as st
import torch

from kernel_utils import (
    FaceExtractor,
    VideoReader,
    confident_strategy,
    predict_on_video_set,
)
from training.zoo.classifiers import DeepFakeClassifier


def load_model():
    """Build the classifier and load its weights from ``weights/best.pth``.

    The checkpoint is loaded onto the CPU. If it is a dict holding a
    ``"state_dict"`` entry that entry is used, otherwise the checkpoint
    itself is treated as the state dict.

    Returns:
        The ``DeepFakeClassifier`` instance, switched to eval mode.
    """
    path = 'weights/best.pth'
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip a leading "module." from every key (presumably left behind by an
    # nn.DataParallel wrapper at training time -- confirm). The dot is
    # escaped so that only the literal prefix is removed; the unescaped
    # pattern would also eat the first character of keys such as "modules.x".
    cleaned_state = {
        re.sub(r"^module\.", "", key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(cleaned_state, strict=True)
    model.eval()
    # Drop the reference to the raw checkpoint once the weights are copied.
    del checkpoint
    return model


def write_bytesio_to_file(filename, bytesio):
    """Write the full contents of a ``BytesIO``-like object to *filename*.

    Args:
        filename: Destination path; opened in binary write mode.
        bytesio: Object exposing ``getbuffer()``.
    """
    with open(filename, "wb") as outfile:
        outfile.write(bytesio.getbuffer())


def load_video():
    """Show a file uploader and persist the uploaded video to a temp file.

    Returns:
        Path of the temporary file holding the uploaded bytes, or ``None``
        when nothing has been uploaded.
    """
    uploaded_file = st.file_uploader(label='Pick a video (mp4) file to test')
    if uploaded_file is None:
        return None

    # Close the handle before handing out the path: this flushes the bytes to
    # disk so that readers opening the file by name see the complete video,
    # and it avoids leaking an open file descriptor per call.
    # NOTE(review): delete=False means the file is never removed; every call
    # with an upload present leaves another temp file behind.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(uploaded_file.getvalue())
    return tfile.name


def inference(model, test_video):
    """Run the model on a single video and write the prediction to the page.

    Args:
        model: Model passed to ``predict_on_video_set`` as a one-element
            list.
        test_video: Path of the video to score.
    """
    frames_per_video = 32
    input_size = 380
    video_reader = VideoReader()

    def read_frames(video_path):
        return video_reader.read_frames(video_path, num_frames=frames_per_video)

    face_extractor = FaceExtractor(read_frames)
    test_videos = [test_video]
    print("Predicting {} videos".format(len(test_videos)))
    predictions = predict_on_video_set(
        face_extractor=face_extractor,
        input_size=input_size,
        models=[model],
        strategy=confident_strategy,
        frames_per_video=frames_per_video,
        videos=test_videos,
        num_workers=6,
        test_dir="test_video",
    )
    st.write("Prediction: ", predictions[0])


def main():
    """Render the demo page: upload, preview, and on-demand inference."""
    st.title('Deepfake video inference demo')
    # NOTE(review): the model is rebuilt on every execution of main();
    # presumably Streamlit re-runs the script on each interaction, so caching
    # the model may be worthwhile -- confirm against the Streamlit version.
    model = load_model()
    video_data_path = load_video()
    if video_data_path is None or not os.path.exists(video_data_path):
        return

    st.video(video_data_path)
    if st.button('Run on video'):
        st.write("Inference on video...")
        stime = time.time()
        inference(model, video_data_path)
        st.write("Elapsed time: ", time.time() - stime, " seconds")


if __name__ == '__main__':
    main()