jungwoonshin commited on
Commit
d8e6c94
1 Parent(s): 4de0f77

second commit

Browse files
Files changed (1) hide show
  1. training/pipelines/app.py +67 -0
training/pipelines/app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import argparse
3
+ import os
4
+ import re
5
+ import time
6
+
7
+ import torch
8
+ import pandas as pd
9
+
10
+ import os, sys
11
+ root_folder = os.path.abspath(
12
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+ )
14
+ sys.path.append(root_folder)
15
+ from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
16
+ from training.zoo.classifiers import DeepFakeClassifier
17
+
18
+
19
+
20
+ def predict(video_index):
21
+ video_index = int(video_index)
22
+
23
+ frames_per_video = 32
24
+ video_reader = VideoReader()
25
+ video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
26
+ face_extractor = FaceExtractor(video_read_fn)
27
+ input_size = 380
28
+ strategy = confident_strategy
29
+
30
+ test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
31
+ print(f"Predicting {video_index} videos")
32
+ predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
33
+ strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
34
+ num_workers=6, test_dir=args.test_dir)
35
+ return predictions
36
+
37
+ def get_args_models():
38
+ parser = argparse.ArgumentParser("Predict test videos")
39
+ arg = parser.add_argument
40
+ arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
41
+ arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files") # nargs='+',
42
+ arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
43
+ arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
44
+ args = parser.parse_args()
45
+
46
+ models = []
47
+ # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
48
+ model_paths = [os.path.join(args.weights_dir, args.models)]
49
+ for path in model_paths:
50
+ model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
51
+ print("loading state dict {}".format(path))
52
+ checkpoint = torch.load(path, map_location="cpu")
53
+ state_dict = checkpoint.get("state_dict", checkpoint)
54
+ model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
55
+ model.eval()
56
+ del checkpoint
57
+ models.append(model.half())
58
+ return args, models
59
+
60
+ if __name__ == '__main__':
61
+ global models, args
62
+ stime = time.time()
63
+ print("Elapsed:", time.time() - stime)
64
+ args, models = get_args_models()
65
+ demo = gr.Interface(fn=predict, inputs="text", outputs="text")
66
+ demo.launch()
67
+