deepfake / app.py
Intae's picture
Add training and weights
641e847
import os.path
import re
import torch
import time
import tempfile
import streamlit as st
from training.zoo.classifiers import DeepFakeClassifier
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
def load_model():
    """Build the DeepFake classifier and load its trained weights on the CPU.

    The checkpoint is read from ``weights/best.pth``.  It may either be a
    bare state dict or a mapping that stores the weights under the
    ``"state_dict"`` key.  A leading ``module.`` on a parameter name is
    removed before the weights are loaded with ``strict=True``.

    Returns:
        The ``DeepFakeClassifier`` instance, switched to evaluation mode.
    """
    path = 'weights/best.pth'
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # The dot is escaped so only the literal "module." prefix is stripped
    # (presumably left by a DataParallel wrapper at training time -- TODO
    # confirm).  An unescaped dot would also eat one character from keys
    # such as "modules.0.weight".
    model.load_state_dict(
        {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
        strict=True)
    model.eval()
    del checkpoint
    return model
def write_bytesio_to_file(filename, bytesio):
    """Dump the full contents of an in-memory bytes buffer to *filename*.

    Args:
        filename: Destination path; opened in binary write mode, so an
            existing file is overwritten.
        bytesio: Object exposing ``getbuffer()`` (e.g. ``io.BytesIO``).
    """
    with open(filename, mode="wb") as sink:
        payload = bytesio.getbuffer()
        sink.write(payload)
def load_video():
    """Show a file-upload widget and persist the uploaded video to disk.

    Returns:
        Path of a temporary file holding the uploaded bytes, or ``None``
        when nothing has been uploaded.  The file is created with
        ``delete=False``, so it is not removed automatically.
    """
    uploaded_file = st.file_uploader(label='Pick a video (mp4) file to test')
    if uploaded_file is None:
        return None
    # Leaving the ``with`` block closes the handle, which flushes the
    # buffered bytes to disk before any caller opens the returned path.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(uploaded_file.getvalue())
    return tfile.name
def inference(model, test_video):
    """Run the prediction pipeline on one video and display the result.

    Args:
        model: Model handed to ``predict_on_video_set`` as the only entry
            of its ``models`` list.
        test_video: Video identifier/path handed to ``predict_on_video_set``
            as the only entry of its ``videos`` list.

    The first element of the returned predictions is written to the
    Streamlit page; nothing is returned.
    """
    frames_per_video = 32
    input_size = 380
    reader = VideoReader()

    def read_frames(video_path):
        # Frame-reading callback given to the face extractor.
        return reader.read_frames(video_path, num_frames=frames_per_video)

    face_extractor = FaceExtractor(read_frames)
    test_videos = [test_video]
    print("Predicting {} videos".format(len(test_videos)))
    predictions = predict_on_video_set(
        face_extractor=face_extractor,
        input_size=input_size,
        models=[model],
        strategy=confident_strategy,
        frames_per_video=frames_per_video,
        videos=test_videos,
        num_workers=6,
        test_dir="test_video",
    )
    st.write("Prediction: ", predictions[0])
def main():
    """Render the demo page: title, upload widget, preview and run button.

    Inference runs only when the button is pressed and an uploaded video
    is available on disk; otherwise a warning is shown.
    """
    st.title('Deepfake video inference demo')
    # NOTE(review): Streamlit re-executes the script on each interaction,
    # so the weights are reloaded every time -- consider caching the model
    # (confirm what the installed Streamlit version offers).
    model = load_model()
    video_data_path = load_video()
    has_video = (video_data_path is not None
                 and os.path.exists(video_data_path))
    if has_video:
        st.video(video_data_path)
    if not st.button('Run on video'):
        return
    if not has_video:
        # Without this guard the button would pass None to inference().
        st.warning("Please upload a video before running inference.")
        return
    st.write("Inference on video...")
    stime = time.time()
    inference(model, video_data_path)
    st.write("Elapsed time: ", time.time() - stime, " seconds")
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()