jungwoonshin committed
Commit 5cd7059
1 Parent(s): d8e6c94
Files changed (3)
  1. app.py +74 -0
  2. predict/app.py +68 -0
  3. predict/kernel_utils.py +358 -0
app.py ADDED
@@ -0,0 +1,74 @@
+ import gradio as gr
+ # import argparse
+ # import os
+ # import re
+ # import time
+
+ # import torch
+ # import pandas as pd
+
+ # # import os, sys
+ # # root_folder = os.path.abspath(
+ # #     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ # # )
+ # # sys.path.append(root_folder)
+ # from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
+ # from classifiers import DeepFakeClassifier
+ # import gradio as gr
+
+
+
+ # def predict(video):
+ #     # video_index = int(video_index)
+
+ #     frames_per_video = 32
+ #     video_reader = VideoReader()
+ #     video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
+ #     face_extractor = FaceExtractor(video_read_fn)
+ #     input_size = 380
+ #     strategy = confident_strategy
+
+ #     # test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
+ #     # print(f"Predicting {video_index} videos")
+ #     predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
+ #                                        strategy=strategy, frames_per_video=frames_per_video, videos=video,
+ #                                        num_workers=6, test_dir=args.test_dir)
+ #     return predictions
+
+ # def get_args_models():
+ #     parser = argparse.ArgumentParser("Predict test videos")
+ #     arg = parser.add_argument
+ #     arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
+ #     arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files")  # nargs='+',
+ #     arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
+ #     arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
+ #     args = parser.parse_args()
+
+ #     models = []
+ #     # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
+ #     model_paths = [os.path.join(args.weights_dir, args.models)]
+ #     for path in model_paths:
+ #         model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cpu")
+ #         print("loading state dict {}".format(path))
+ #         checkpoint = torch.load(path, map_location="cpu")
+ #         state_dict = checkpoint.get("state_dict", checkpoint)
+ #         model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=True)
+ #         model.eval()
+ #         del checkpoint
+ #         models.append(model.half())
+ #     return args, models
+
+ def greet(name):
+     # Placeholder: with a video input, Gradio passes the uploaded file's
+     # path, so this simply echoes that path back.
+     return "Hello " + name + "!!"
+
+ if __name__ == '__main__':
+     # global args, models
+     # args, models = get_args_models()
+
+     # stime = time.time()
+     # print("Elapsed:", time.time() - stime)
+
+     demo = gr.Interface(fn=greet, inputs="video", outputs="text")
+     demo.launch()
+
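Once the model weights are available, the commented-out pipeline above would replace the `greet` placeholder, mirroring predict/app.py below; the wiring would presumably become:

    demo = gr.Interface(fn=predict, inputs="video", outputs="text")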
predict/app.py ADDED
@@ -0,0 +1,68 @@
+ import gradio as gr
+ import argparse
+ import os
+ import re
+ import time
+
+ import torch
+ import pandas as pd
+
+ import sys
+ root_folder = os.path.abspath(
+     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ )
+ sys.path.append(root_folder)
+ from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
+ from training.zoo.classifiers import DeepFakeClassifier
+
+
+
+ def predict(video):
+     # video_index = int(video_index)
+
+     # `video` is the path of the file uploaded through Gradio.
+     frames_per_video = 32
+     video_reader = VideoReader()
+     video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
+     face_extractor = FaceExtractor(video_read_fn)
+     input_size = 380
+     strategy = confident_strategy
+
+     # test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
+     # print(f"Predicting {video_index} videos")
+     predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
+                                        strategy=strategy, frames_per_video=frames_per_video, videos=video,
+                                        num_workers=6, test_dir=args.test_dir)
+     return predictions
+
+ def get_args_models():
+     parser = argparse.ArgumentParser("Predict test videos")
+     arg = parser.add_argument
+     arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
+     arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files")  # nargs='+',
+     arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
+     arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
+     args = parser.parse_args()
+
+     models = []
+     # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
+     model_paths = [os.path.join(args.weights_dir, args.models)]
+     for path in model_paths:
+         model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cpu")
+         print("loading state dict {}".format(path))
+         checkpoint = torch.load(path, map_location="cpu")
+         state_dict = checkpoint.get("state_dict", checkpoint)
+         # Strip the "module." prefix left over from DataParallel checkpoints.
+         model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=True)
+         model.eval()
+         del checkpoint
+         models.append(model.half())
+     return args, models
+
+ if __name__ == '__main__':
+     # `predict` reads the module-level `args` and `models` defined here.
+     stime = time.time()
+     args, models = get_args_models()
+     print("Elapsed:", time.time() - stime)
+
+     demo = gr.Interface(fn=predict, inputs="video", outputs="text")
+     demo.launch()
+
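To try this file locally (a sketch; it assumes the checkpoint has already been downloaded into weights/ as named by the argparse default):

    python predict/app.py --weights-dir weights --models classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice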
predict/kernel_utils.py ADDED
@@ -0,0 +1,358 @@
+ import os
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from albumentations.augmentations.functional import image_compression
+ from facenet_pytorch.models.mtcnn import MTCNN
+ from concurrent.futures import ThreadPoolExecutor
+
+ from torchvision.transforms import Normalize
+
+ # Standard ImageNet normalization statistics.
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ normalize_transform = Normalize(mean, std)
+
+
+ class VideoReader:
+     """Helper class for reading one or more frames from a video file."""
+
+     def __init__(self, verbose=True, insets=(0, 0)):
+         """Creates a new VideoReader.
+
+         Arguments:
+             verbose: whether to print warnings and error messages
+             insets: amount to inset the image by, as a percentage of
+                 (width, height). This lets you "zoom in" to an image
+                 to remove unimportant content around the borders.
+                 Useful for face detection, which may not work if the
+                 faces are too small.
+         """
+         self.verbose = verbose
+         self.insets = insets
+
+     def read_frames(self, path, num_frames, jitter=0, seed=None):
+         """Reads frames that are always evenly spaced throughout the video.
+
+         Arguments:
+             path: the video file
+             num_frames: how many frames to read, -1 means the entire video
+                 (warning: this will take up a lot of memory!)
+             jitter: if not 0, adds small random offsets to the frame indices;
+                 this is useful so we don't always land on even or odd frames
+             seed: random seed for jittering; if you set this to a fixed value,
+                 you probably want to set it only on the first video
+         """
+         assert num_frames > 0
+
+         capture = cv2.VideoCapture(path)
+         frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+         if frame_count <= 0: return None
+
+         frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=int)
+         if jitter > 0:
+             np.random.seed(seed)
+             jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
+             frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)
+
+         result = self._read_frames_at_indices(path, capture, frame_idxs)
+         capture.release()
+         return result
+
+     def read_random_frames(self, path, num_frames, seed=None):
+         """Picks the frame indices at random.
+
+         Arguments:
+             path: the video file
+             num_frames: how many frames to read, -1 means the entire video
+                 (warning: this will take up a lot of memory!)
+         """
+         assert num_frames > 0
+         np.random.seed(seed)
+
+         capture = cv2.VideoCapture(path)
+         frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+         if frame_count <= 0: return None
+
+         frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
+         result = self._read_frames_at_indices(path, capture, frame_idxs)
+
+         capture.release()
+         return result
+
+     def read_frames_at_indices(self, path, frame_idxs):
+         """Reads frames from a video and puts them into a NumPy array.
+
+         Arguments:
+             path: the video file
+             frame_idxs: a list of frame indices. Important: should be
+                 sorted from low-to-high! If an index appears multiple
+                 times, the frame is still read only once.
+
+         Returns:
+             - a NumPy array of shape (num_frames, height, width, 3)
+             - a list of the frame indices that were read
+
+         Reading stops if loading a frame fails, in which case the first
+         dimension returned may actually be less than num_frames.
+
+         Returns None if an exception is thrown for any reason, or if no
+         frames were read.
+         """
+         assert len(frame_idxs) > 0
+         capture = cv2.VideoCapture(path)
+         result = self._read_frames_at_indices(path, capture, frame_idxs)
+         capture.release()
+         return result
+
+     def _read_frames_at_indices(self, path, capture, frame_idxs):
+         try:
+             frames = []
+             idxs_read = []
+             for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
+                 # Get the next frame, but don't decode if we're not using it.
+                 ret = capture.grab()
+                 if not ret:
+                     if self.verbose:
+                         print("Error grabbing frame %d from movie %s" % (frame_idx, path))
+                     break
+
+                 # Need to look at this frame?
+                 current = len(idxs_read)
+                 if frame_idx == frame_idxs[current]:
+                     ret, frame = capture.retrieve()
+                     if not ret or frame is None:
+                         if self.verbose:
+                             print("Error retrieving frame %d from movie %s" % (frame_idx, path))
+                         break
+
+                     frame = self._postprocess_frame(frame)
+                     frames.append(frame)
+                     idxs_read.append(frame_idx)
+
+             if len(frames) > 0:
+                 return np.stack(frames), idxs_read
+             if self.verbose:
+                 print("No frames read from movie %s" % path)
+             return None
+         except Exception:
+             if self.verbose:
+                 print("Exception while reading movie %s" % path)
+             return None
+
+     def read_middle_frame(self, path):
+         """Reads the frame from the middle of the video."""
+         capture = cv2.VideoCapture(path)
+         frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+         result = self._read_frame_at_index(path, capture, frame_count // 2)
+         capture.release()
+         return result
+
+     def read_frame_at_index(self, path, frame_idx):
+         """Reads a single frame from a video.
+
+         If you just want to read a single frame from the video, this is more
+         efficient than scanning through the video to find the frame. However,
+         for reading multiple frames it's not efficient.
+
+         My guess is that a "streaming" approach is more efficient than a
+         "random access" approach because, unless you happen to grab a keyframe,
+         the decoder still needs to read all the previous frames in order to
+         reconstruct the one you're asking for.
+
+         Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
+         or None if reading failed.
+         """
+         capture = cv2.VideoCapture(path)
+         result = self._read_frame_at_index(path, capture, frame_idx)
+         capture.release()
+         return result
+
+     def _read_frame_at_index(self, path, capture, frame_idx):
+         capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+         ret, frame = capture.read()
+         if not ret or frame is None:
+             if self.verbose:
+                 print("Error retrieving frame %d from movie %s" % (frame_idx, path))
+             return None
+         else:
+             frame = self._postprocess_frame(frame)
+             return np.expand_dims(frame, axis=0), [frame_idx]
+
+     def _postprocess_frame(self, frame):
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+         if self.insets[0] > 0:
+             W = frame.shape[1]
+             p = int(W * self.insets[0])
+             frame = frame[:, p:-p, :]
+
+         if self.insets[1] > 0:
+             H = frame.shape[0]
+             q = int(H * self.insets[1])
+             frame = frame[q:-q, :, :]
+
+         return frame
+
+
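+ # VideoReader usage sketch (mirrors the lambda in predict/app.py; the
+ # filename is hypothetical):
+ #   reader = VideoReader()
+ #   result = reader.read_frames("some_video.mp4", num_frames=32)
+ #   if result is not None:
+ #       frames, idxs = result  # frames: (n, H, W, 3) uint8, RGB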
+ class FaceExtractor:
+     def __init__(self, video_read_fn):
+         self.video_read_fn = video_read_fn
+         # Assumes a CUDA-capable machine; on CPU-only hosts the detector
+         # would need device="cpu".
+         self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cuda")
+
+     def process_videos(self, videos):
+         # `videos` is the (frames, frame indices) tuple produced by video_read_fn.
+         videos_read = []
+         frames_read = []
+         frames = []
+         results = []
+         # for video_idx in video_idxs:
+         # Read the full-size frames from this video.
+         # filename = filenames[video_idx]
+         # video_path = os.path.join(input_dir, filename)
+         # result = self.video_read_fn(video_path)
+         result = videos
+         # Error? Then skip this video.
+         if result is None:
+             return results
+
+         # Keep track of the original frames (need them later).
+         my_frames, my_idxs = result
+
+         frames.append(my_frames)
+         frames_read.append(my_idxs)
+         for i, frame in enumerate(my_frames):
+             frame_h, frame_w = frame.shape[:2]
+             img = Image.fromarray(frame.astype(np.uint8))
+             # Detect on a half-size image for speed; boxes are scaled back up below.
+             img = img.resize(size=[s // 2 for s in img.size])
+
+             batch_boxes, probs = self.detector.detect(img, landmarks=False)
+
+             faces = []
+             scores = []
+             if batch_boxes is None:
+                 continue
+             for bbox, score in zip(batch_boxes, probs):
+                 if bbox is not None:
+                     xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
+                     w = xmax - xmin
+                     h = ymax - ymin
+                     # Expand the crop by a third of the box size on each side.
+                     p_h = h // 3
+                     p_w = w // 3
+                     crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
+                     faces.append(crop)
+                     scores.append(score)
+
+             frame_dict = {"video_idx": 0,  # single video per call
+                           "frame_idx": my_idxs[i],
+                           "frame_w": frame_w,
+                           "frame_h": frame_h,
+                           "faces": faces,
+                           "scores": scores}
+             results.append(frame_dict)
+
+         return results
+
+     def process_video(self, video_path):
+         """Convenience method for doing face extraction on a single video."""
+         result = self.video_read_fn(video_path)
+         return self.process_videos(result)
+
+
+
+ def confident_strategy(pred, t=0.8):
+     pred = np.array(pred)
+     sz = len(pred)
+     fakes = np.count_nonzero(pred > t)
+     # If enough frames (more than 11, and more than ~40% of them) are detected
+     # as fakes with high probability, average only those confident predictions.
+     if fakes > sz // 2.5 and fakes > 11:
+         return np.mean(pred[pred > t])
+     elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
+         return np.mean(pred[pred < 0.2])
+     else:
+         return np.mean(pred)
+
+ strategy = confident_strategy
+
+
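+ # Worked example for confident_strategy (hypothetical numbers): given 32
+ # per-frame scores of which 14 exceed t=0.8, fakes=14 > 32 // 2.5 (=12.0)
+ # and 14 > 11, so the video-level score is the mean of just those 14
+ # confident predictions.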
+ def put_to_center(img, input_size):
+     # Pastes the crop onto the center of a black input_size x input_size canvas.
+     img = img[:input_size, :input_size]
+     image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
+     start_w = (input_size - img.shape[1]) // 2
+     start_h = (input_size - img.shape[0]) // 2
+     image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
+     return image
+
+
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
+     # Resizes so the longer side equals `size` while preserving aspect ratio.
+     h, w = img.shape[:2]
+     if max(w, h) == size:
+         return img
+     if w > h:
+         scale = size / w
+         h = h * scale
+         w = size
+     else:
+         scale = size / h
+         w = w * scale
+         h = size
+     interpolation = interpolation_up if scale > 1 else interpolation_down
+     resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
+     return resized
+
+
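+ # Example: a 760x380 (WxH) crop resized with size=380 becomes 380x190, and
+ # put_to_center then pads it onto a 380x380 black canvas (95 px above and below).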
+ def predict_on_video(face_extractor, video_path, videos, batch_size, input_size, models, strategy=np.mean,
+                      apply_compression=False):
+     batch_size *= 4
+     try:
+         # `videos` is the path of the uploaded video.
+         faces = face_extractor.process_video(videos)
+         if len(faces) > 0:
+             x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
+             n = 0
+             for frame_data in faces:
+                 for face in frame_data["faces"]:
+                     resized_face = isotropically_resize_image(face, input_size)
+                     resized_face = put_to_center(resized_face, input_size)
+                     if apply_compression:
+                         resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
+                     if n < batch_size:
+                         x[n] = resized_face
+                         n += 1
+                     else:
+                         # Batch is full; ignore any remaining faces.
+                         pass
+             if n > 0:
+                 x = torch.tensor(x, device="cuda").float()
+                 # Preprocess the images.
+                 x = x.permute((0, 3, 1, 2))
+                 for i in range(len(x)):
+                     x[i] = normalize_transform(x[i] / 255.)
+                 # Make a prediction, then take the average.
+                 with torch.no_grad():
+                     preds = []
+                     for model in models:
+                         y_pred = model(x[:n].half())
+                         y_pred = torch.sigmoid(y_pred.squeeze())
+                         bpred = y_pred[:n].cpu().numpy()
+                         preds.append(strategy(bpred))
+                     return np.mean(preds)
+     except Exception as e:
+         print("Prediction error on video %s: %s" % (video_path, str(e)))
+
+     return 0.5
+
+
+ def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
+                          strategy=np.mean,
+                          apply_compression=False):
+     def process_file(i):
+         # `i` is unused; `videos` is a single path here. os.path.join leaves
+         # absolute paths (e.g. Gradio temp files) untouched.
+         filename = videos
+         y_pred = predict_on_video(face_extractor=face_extractor, video_path=os.path.join(test_dir, filename),
+                                   videos=videos,
+                                   input_size=input_size,
+                                   batch_size=frames_per_video,
+                                   models=models, strategy=strategy, apply_compression=apply_compression)
+         return y_pred
+
+     # Only a single video is processed, so the executor maps over a
+     # one-element dummy list.
+     with ThreadPoolExecutor(max_workers=num_workers) as ex:
+         predictions = ex.map(process_file, [1])
+     return list(predictions)
+
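For reference, a minimal end-to-end sketch of how these utilities fit together outside Gradio (the checkpoint and video paths are placeholders; DeepFakeClassifier is imported as in predict/app.py, and a CUDA device is assumed):

    import re
    import torch
    from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
    from training.zoo.classifiers import DeepFakeClassifier

    reader = VideoReader()
    extractor = FaceExtractor(lambda p: reader.read_frames(p, num_frames=32))
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
    checkpoint = torch.load("weights/classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice",
                            map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()})
    model.eval()

    scores = predict_on_video_set(face_extractor=extractor, videos="/path/to/video.mp4",
                                  input_size=380, num_workers=1, test_dir=".",
                                  frames_per_video=32, models=[model.half()],
                                  strategy=confident_strategy, apply_compression=False)
    print(scores)  # [probability that the video is fake]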