jungwoonshin commited on
Commit
a8ff7ce
1 Parent(s): 5cd7059

Delete training

Browse files
training/pipelines/app.py DELETED
@@ -1,67 +0,0 @@
1
- import gradio as gr
2
- import argparse
3
- import os
4
- import re
5
- import time
6
-
7
- import torch
8
- import pandas as pd
9
-
10
- import os, sys
11
- root_folder = os.path.abspath(
12
- os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
- )
14
- sys.path.append(root_folder)
15
- from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
16
- from training.zoo.classifiers import DeepFakeClassifier
17
-
18
-
19
-
20
- def predict(video_index):
21
- video_index = int(video_index)
22
-
23
- frames_per_video = 32
24
- video_reader = VideoReader()
25
- video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
26
- face_extractor = FaceExtractor(video_read_fn)
27
- input_size = 380
28
- strategy = confident_strategy
29
-
30
- test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
31
- print(f"Predicting {video_index} videos")
32
- predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
33
- strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
34
- num_workers=6, test_dir=args.test_dir)
35
- return predictions
36
-
37
- def get_args_models():
38
- parser = argparse.ArgumentParser("Predict test videos")
39
- arg = parser.add_argument
40
- arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
41
- arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files") # nargs='+',
42
- arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
43
- arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
44
- args = parser.parse_args()
45
-
46
- models = []
47
- # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
48
- model_paths = [os.path.join(args.weights_dir, args.models)]
49
- for path in model_paths:
50
- model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
51
- print("loading state dict {}".format(path))
52
- checkpoint = torch.load(path, map_location="cpu")
53
- state_dict = checkpoint.get("state_dict", checkpoint)
54
- model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
55
- model.eval()
56
- del checkpoint
57
- models.append(model.half())
58
- return args, models
59
-
60
- if __name__ == '__main__':
61
- global models, args
62
- stime = time.time()
63
- print("Elapsed:", time.time() - stime)
64
- args, models = get_args_models()
65
- demo = gr.Interface(fn=predict, inputs="text", outputs="text")
66
- demo.launch()
67
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training/pipelines/train_classifier_gradio.py DELETED
@@ -1,67 +0,0 @@
1
- import gradio as gr
2
- import argparse
3
- import os
4
- import re
5
- import time
6
-
7
- import torch
8
- import pandas as pd
9
-
10
- import os, sys
11
- root_folder = os.path.abspath(
12
- os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
- )
14
- sys.path.append(root_folder)
15
- from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
16
- from training.zoo.classifiers import DeepFakeClassifier
17
-
18
-
19
-
20
- def predict(video_index):
21
- video_index = int(video_index)
22
-
23
- frames_per_video = 32
24
- video_reader = VideoReader()
25
- video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
26
- face_extractor = FaceExtractor(video_read_fn)
27
- input_size = 380
28
- strategy = confident_strategy
29
-
30
- test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
31
- print(f"Predicting {video_index} videos")
32
- predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
33
- strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
34
- num_workers=6, test_dir=args.test_dir)
35
- return predictions
36
-
37
- def get_args_models():
38
- parser = argparse.ArgumentParser("Predict test videos")
39
- arg = parser.add_argument
40
- arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
41
- arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files") # nargs='+',
42
- arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
43
- arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
44
- args = parser.parse_args()
45
-
46
- models = []
47
- # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
48
- model_paths = [os.path.join(args.weights_dir, args.models)]
49
- for path in model_paths:
50
- model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
51
- print("loading state dict {}".format(path))
52
- checkpoint = torch.load(path, map_location="cpu")
53
- state_dict = checkpoint.get("state_dict", checkpoint)
54
- model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
55
- model.eval()
56
- del checkpoint
57
- models.append(model.half())
58
- return args, models
59
-
60
- if __name__ == '__main__':
61
- global models, args
62
- stime = time.time()
63
- print("Elapsed:", time.time() - stime)
64
- args, models = get_args_models()
65
- demo = gr.Interface(fn=predict, inputs="text", outputs="text")
66
- demo.launch()
67
-