# deepfake / app.py
# Intae's picture
# Add training and weights
# 641e847
# raw
# history blame contribute delete
# No virus
# 2.54 kB
import os.path
import re
import torch
import time
import tempfile
import streamlit as st
from training.zoo.classifiers import DeepFakeClassifier
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
def load_model():
    """Build the DeepFake classifier and load its trained weights on the CPU.

    Reads the checkpoint at ``weights/best.pth``. The checkpoint may be a raw
    state dict or a dict wrapping it under the ``"state_dict"`` key. A leading
    ``module.`` prefix is stripped from every parameter name before the
    weights are loaded with ``strict=True``.

    Returns:
        The ``DeepFakeClassifier`` instance, switched to eval mode.
    """
    path = 'weights/best.pth'
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # The dot must be escaped: an unescaped "." matches any character, so
    # names such as "modules.x" or "module_x" would lose their first seven
    # characters. Only the literal "module." prefix should be removed
    # (presumably added by a DataParallel-style wrapper during training --
    # TODO confirm).
    prefix = re.compile(r"^module\.")
    model.load_state_dict(
        {prefix.sub("", key): value for key, value in state_dict.items()},
        strict=True)
    model.eval()
    del checkpoint
    return model
def write_bytesio_to_file(filename, bytesio):
    """Dump the full contents of an in-memory bytes buffer into a file.

    Args:
        filename: Destination path; created or truncated, opened in binary
            mode.
        bytesio: Object exposing ``getbuffer()`` (e.g. ``io.BytesIO``).
    """
    with open(filename, "wb") as sink:
        # getbuffer() exposes the buffer's contents without copying them.
        view = bytesio.getbuffer()
        sink.write(view)
def load_video():
    """Show a file-upload widget and persist the uploaded video to disk.

    Returns:
        Path of a temporary file holding the uploaded bytes, or ``None``
        when no file has been uploaded.
    """
    uploaded_file = st.file_uploader(label='Pick a video (mp4) file to test')
    if uploaded_file is None:
        return None
    # The ``with`` block closes (and therefore flushes) the handle before the
    # path is handed out, so readers that reopen the file by name see all the
    # bytes. delete=False keeps the file on disk after closing; this function
    # never removes it, so cleanup is left to the caller.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(uploaded_file.getvalue())
    return tfile.name
def inference(model, test_video):
    """Run the classifier on one video and show the prediction in the page.

    Args:
        model: Model passed (as a single-element list) to
            ``predict_on_video_set``.
        test_video: Video identifier/path passed (as a single-element list)
            to ``predict_on_video_set``.
    """
    frames_per_video = 32
    reader = VideoReader()

    def read_video_frames(video):
        # Frame-reading callback handed to the face extractor.
        return reader.read_frames(video, num_frames=frames_per_video)

    extractor = FaceExtractor(read_video_frames)
    video_batch = [test_video]
    print("Predicting {} videos".format(len(video_batch)))
    predictions = predict_on_video_set(
        face_extractor=extractor,
        input_size=380,  # presumably the model's input resolution -- confirm
        models=[model],
        strategy=confident_strategy,
        frames_per_video=frames_per_video,
        videos=video_batch,
        num_workers=6,
        test_dir="test_video",
    )
    # Only one video was submitted, so only the first prediction is shown.
    st.write("Prediction: ", predictions[0])
def main():
    """Render the demo page: title, upload widget, preview and run button."""
    st.title('Deepfake video inference demo')
    model = load_model()
    video_data_path = load_video()

    # Nothing to preview or analyse until a video exists on disk.
    if video_data_path is None or not os.path.exists(video_data_path):
        return

    st.video(video_data_path)
    if not st.button('Run on video'):
        return

    st.write("Inference on video...")
    started_at = time.time()
    inference(model, video_data_path)
    st.write("Elapsed time: ", time.time() - started_at, " seconds")
# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()