Intae committed on
Commit
641e847
1 Parent(s): ad5c7cd

Add training and weights

.gitignore ADDED
@@ -0,0 +1,3 @@
+.DS_Store
+__pycache__
+.idea
app.py CHANGED
@@ -10,7 +10,7 @@ from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict
 
 
 def load_model():
-    path = 'weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23'
+    path = 'weights/best.pth'
     model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
     print("loading state dict {}".format(path))
     checkpoint = torch.load(path, map_location="cpu")
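For reference, a minimal sketch of how the reloaded checkpoint plugs into the rest of load_model(); the lines beyond this hunk are not part of the diff, so the import path and the state-dict key handling below are assumptions (training saves a DataParallel-wrapped model, so keys typically carry a "module." prefix):

import torch
from training.zoo.classifiers import DeepFakeClassifier  # assumed import path

def load_model():
    path = 'weights/best.pth'
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # strip the "module." prefix left by DataParallel during training
    model.load_state_dict({k.replace("module.", ""): v for k, v in state_dict.items()}, strict=True)
    return model.eval()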
kernel_utils.py ADDED
@@ -0,0 +1,359 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from albumentations.augmentations.functional import image_compression
8
+ from facenet_pytorch.models.mtcnn import MTCNN
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ from torchvision.transforms import Normalize
12
+
13
+ mean = [0.485, 0.456, 0.406]
14
+ std = [0.229, 0.224, 0.225]
15
+ normalize_transform = Normalize(mean, std)
16
+
17
+
18
+ class VideoReader:
19
+ """Helper class for reading one or more frames from a video file."""
20
+
21
+ def __init__(self, verbose=True, insets=(0, 0)):
22
+ """Creates a new VideoReader.
23
+
24
+ Arguments:
25
+ verbose: whether to print warnings and error messages
26
+ insets: amount to inset the image by, as a percentage of
27
+ (width, height). This lets you "zoom in" to an image
28
+ to remove unimportant content around the borders.
29
+ Useful for face detection, which may not work if the
30
+ faces are too small.
31
+ """
32
+ self.verbose = verbose
33
+ self.insets = insets
34
+
35
+ def read_frames(self, path, num_frames, jitter=0, seed=None):
36
+ """Reads frames that are always evenly spaced throughout the video.
37
+
38
+ Arguments:
39
+ path: the video file
40
+ num_frames: how many frames to read, -1 means the entire video
41
+ (warning: this will take up a lot of memory!)
42
+ jitter: if not 0, adds small random offsets to the frame indices;
43
+ this is useful so we don't always land on even or odd frames
44
+ seed: random seed for jittering; if you set this to a fixed value,
45
+ you probably want to set it only on the first video
46
+ """
47
+ assert num_frames > 0
48
+
49
+ capture = cv2.VideoCapture(path)
50
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
51
+ if frame_count <= 0: return None
52
+
53
+ frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int)
54
+ if jitter > 0:
55
+ np.random.seed(seed)
56
+ jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
57
+ frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)
58
+
59
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
60
+ capture.release()
61
+ return result
62
+
63
+ def read_random_frames(self, path, num_frames, seed=None):
64
+ """Picks the frame indices at random.
65
+
66
+ Arguments:
67
+ path: the video file
68
+ num_frames: how many frames to read, -1 means the entire video
69
+ (warning: this will take up a lot of memory!)
70
+ """
71
+ assert num_frames > 0
72
+ np.random.seed(seed)
73
+
74
+ capture = cv2.VideoCapture(path)
75
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
76
+ if frame_count <= 0: return None
77
+
78
+ frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
79
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
80
+
81
+ capture.release()
82
+ return result
83
+
84
+ def read_frames_at_indices(self, path, frame_idxs):
85
+ """Reads frames from a video and puts them into a NumPy array.
86
+
87
+ Arguments:
88
+ path: the video file
89
+ frame_idxs: a list of frame indices. Important: should be
90
+ sorted from low-to-high! If an index appears multiple
91
+ times, the frame is still read only once.
92
+
93
+ Returns:
94
+ - a NumPy array of shape (num_frames, height, width, 3)
95
+ - a list of the frame indices that were read
96
+
97
+ Reading stops if loading a frame fails, in which case the first
98
+ dimension returned may actually be less than num_frames.
99
+
100
+ Returns None if an exception is thrown for any reason, or if no
101
+ frames were read.
102
+ """
103
+ assert len(frame_idxs) > 0
104
+ capture = cv2.VideoCapture(path)
105
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
106
+ capture.release()
107
+ return result
108
+
109
+ def _read_frames_at_indices(self, path, capture, frame_idxs):
110
+ try:
111
+ frames = []
112
+ idxs_read = []
113
+ for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
114
+ # Get the next frame, but don't decode if we're not using it.
115
+ ret = capture.grab()
116
+ if not ret:
117
+ if self.verbose:
118
+ print("Error grabbing frame %d from movie %s" % (frame_idx, path))
119
+ break
120
+
121
+ # Need to look at this frame?
122
+ current = len(idxs_read)
123
+ if frame_idx == frame_idxs[current]:
124
+ ret, frame = capture.retrieve()
125
+ if not ret or frame is None:
126
+ if self.verbose:
127
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
128
+ break
129
+
130
+ frame = self._postprocess_frame(frame)
131
+ frames.append(frame)
132
+ idxs_read.append(frame_idx)
133
+
134
+ if len(frames) > 0:
135
+ return np.stack(frames), idxs_read
136
+ if self.verbose:
137
+ print("No frames read from movie %s" % path)
138
+ return None
139
+ except:
140
+ if self.verbose:
141
+ print("Exception while reading movie %s" % path)
142
+ return None
143
+
144
+ def read_middle_frame(self, path):
145
+ """Reads the frame from the middle of the video."""
146
+ capture = cv2.VideoCapture(path)
147
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
148
+ result = self._read_frame_at_index(path, capture, frame_count // 2)
149
+ capture.release()
150
+ return result
151
+
152
+ def read_frame_at_index(self, path, frame_idx):
153
+ """Reads a single frame from a video.
154
+
155
+ If you just want to read a single frame from the video, this is more
156
+ efficient than scanning through the video to find the frame. However,
157
+ for reading multiple frames it's not efficient.
158
+
159
+ My guess is that a "streaming" approach is more efficient than a
160
+ "random access" approach because, unless you happen to grab a keyframe,
161
+ the decoder still needs to read all the previous frames in order to
162
+ reconstruct the one you're asking for.
163
+
164
+ Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
165
+ or None if reading failed.
166
+ """
167
+ capture = cv2.VideoCapture(path)
168
+ result = self._read_frame_at_index(path, capture, frame_idx)
169
+ capture.release()
170
+ return result
171
+
172
+ def _read_frame_at_index(self, path, capture, frame_idx):
173
+ capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
174
+ ret, frame = capture.read()
175
+ if not ret or frame is None:
176
+ if self.verbose:
177
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
178
+ return None
179
+ else:
180
+ frame = self._postprocess_frame(frame)
181
+ return np.expand_dims(frame, axis=0), [frame_idx]
182
+
183
+ def _postprocess_frame(self, frame):
184
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
+
186
+ if self.insets[0] > 0:
187
+ W = frame.shape[1]
188
+ p = int(W * self.insets[0])
189
+ frame = frame[:, p:-p, :]
190
+
191
+ if self.insets[1] > 0:
192
+ H = frame.shape[1]
193
+ q = int(H * self.insets[1])
194
+ frame = frame[q:-q, :, :]
195
+
196
+ return frame
197
+
198
+
199
+ class FaceExtractor:
200
+ def __init__(self, video_read_fn):
201
+ self.video_read_fn = video_read_fn
202
+ self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8])
203
+
204
+ def process_videos(self, input_dir, filenames, video_idxs):
205
+ videos_read = []
206
+ frames_read = []
207
+ frames = []
208
+ results = []
209
+ for video_idx in video_idxs:
210
+ # Read the full-size frames from this video.
211
+ filename = filenames[video_idx]
212
+ video_path = os.path.join(input_dir, filename)
213
+ result = self.video_read_fn(video_path)
214
+ # Error? Then skip this video.
215
+ if result is None: continue
216
+
217
+ videos_read.append(video_idx)
218
+
219
+ # Keep track of the original frames (need them later).
220
+ my_frames, my_idxs = result
221
+
222
+ frames.append(my_frames)
223
+ frames_read.append(my_idxs)
224
+ for i, frame in enumerate(my_frames):
225
+ h, w = frame.shape[:2]
226
+ img = Image.fromarray(frame.astype(np.uint8))
227
+ img = img.resize(size=[s // 2 for s in img.size])
228
+
229
+ batch_boxes, probs = self.detector.detect(img, landmarks=False)
230
+
231
+ faces = []
232
+ scores = []
233
+ if batch_boxes is None:
234
+ continue
235
+ for bbox, score in zip(batch_boxes, probs):
236
+ if bbox is not None:
237
+ xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
238
+ w = xmax - xmin
239
+ h = ymax - ymin
240
+ p_h = h // 3
241
+ p_w = w // 3
242
+ crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
243
+ faces.append(crop)
244
+ scores.append(score)
245
+
246
+ frame_dict = {"video_idx": video_idx,
247
+ "frame_idx": my_idxs[i],
248
+ "frame_w": w,
249
+ "frame_h": h,
250
+ "faces": faces,
251
+ "scores": scores}
252
+ results.append(frame_dict)
253
+
254
+ return results
255
+
256
+ def process_video(self, video_path):
257
+ """Convenience method for doing face extraction on a single video."""
258
+ input_dir = os.path.dirname(video_path)
259
+ filenames = [os.path.basename(video_path)]
260
+ return self.process_videos(input_dir, filenames, [0])
261
+
262
+
263
+
264
+ def confident_strategy(pred, t=0.8):
265
+ pred = np.array(pred)
266
+ sz = len(pred)
267
+ fakes = np.count_nonzero(pred > t)
268
+ # 11 frames are detected as fakes with high probability
269
+ if fakes > sz // 2.5 and fakes > 11:
270
+ return np.mean(pred[pred > t])
271
+ elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
272
+ return np.mean(pred[pred < 0.2])
273
+ else:
274
+ return np.mean(pred)
275
+
276
+ strategy = confident_strategy
277
+
278
+
279
+ def put_to_center(img, input_size):
280
+ img = img[:input_size, :input_size]
281
+ image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
282
+ start_w = (input_size - img.shape[1]) // 2
283
+ start_h = (input_size - img.shape[0]) // 2
284
+ image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
285
+ return image
286
+
287
+
288
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
289
+ h, w = img.shape[:2]
290
+ if max(w, h) == size:
291
+ return img
292
+ if w > h:
293
+ scale = size / w
294
+ h = h * scale
295
+ w = size
296
+ else:
297
+ scale = size / h
298
+ w = w * scale
299
+ h = size
300
+ interpolation = interpolation_up if scale > 1 else interpolation_down
301
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
302
+ return resized
303
+
304
+
305
+ def predict_on_video(face_extractor, video_path, batch_size, input_size, models, strategy=np.mean,
306
+ apply_compression=False):
307
+ batch_size *= 4
308
+ try:
309
+ faces = face_extractor.process_video(video_path)
310
+ if len(faces) > 0:
311
+ x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
312
+ n = 0
313
+ for frame_data in faces:
314
+ for face in frame_data["faces"]:
315
+ resized_face = isotropically_resize_image(face, input_size)
316
+ resized_face = put_to_center(resized_face, input_size)
317
+ if apply_compression:
318
+ resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
319
+ if n + 1 < batch_size:
320
+ x[n] = resized_face
321
+ n += 1
322
+ else:
323
+ pass
324
+ if n > 0:
325
+ x = torch.tensor(x).float()
326
+ # Preprocess the images.
327
+ x = x.permute((0, 3, 1, 2))
328
+ for i in range(len(x)):
329
+ x[i] = normalize_transform(x[i] / 255.)
330
+ # Make a prediction, then take the average.
331
+ with torch.no_grad():
332
+ preds = []
333
+ for model in models:
334
+ y_pred = model(x[:n])
335
+ y_pred = torch.sigmoid(y_pred.squeeze())
336
+ bpred = y_pred[:n].cpu().numpy()
337
+ preds.append(strategy(bpred))
338
+ return np.mean(preds)
339
+ except Exception as e:
340
+ print("Prediction error on video %s: %s" % (video_path, str(e)))
341
+
342
+ return 0.5
343
+
344
+
345
+ def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
346
+ strategy=np.mean,
347
+ apply_compression=False):
348
+ def process_file(i):
349
+ filename = videos[i]
350
+ y_pred = predict_on_video(face_extractor=face_extractor, video_path=os.path.join(test_dir, filename),
351
+ input_size=input_size,
352
+ batch_size=frames_per_video,
353
+ models=models, strategy=strategy, apply_compression=apply_compression)
354
+ return y_pred
355
+
356
+ with ThreadPoolExecutor(max_workers=num_workers) as ex:
357
+ predictions = ex.map(process_file, range(len(videos)))
358
+ return list(predictions)
359
+
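Taken together, the helpers above form the inference pipeline; a minimal usage sketch (frames_per_video, input_size, worker count and the video path are illustrative values, and load_model refers to the app.py helper above):

from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set

frames_per_video = 32            # illustrative
input_size = 380                 # illustrative crop size fed to the classifier
video_reader = VideoReader()
video_read_fn = lambda path: video_reader.read_frames(path, num_frames=frames_per_video)
face_extractor = FaceExtractor(video_read_fn)

model = load_model()             # see app.py above
predictions = predict_on_video_set(face_extractor, videos=["example.mp4"], input_size=input_size,
                                   num_workers=2, test_dir=".", frames_per_video=frames_per_video,
                                   models=[model], strategy=confident_strategy)
print(predictions)               # one fake probability per video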
requirements.txt ADDED
@@ -0,0 +1,65 @@
+albumentations==1.2.1
+altair==4.2.0
+attrs==22.1.0
+blinker==1.5
+cachetools==5.2.0
+certifi==2022.6.15
+charset-normalizer==2.1.1
+click==8.1.3
+commonmark==0.9.1
+decorator==5.1.1
+entrypoints==0.4
+facenet-pytorch==2.5.2
+gitdb==4.0.9
+GitPython==3.1.27
+idna==3.3
+imageio==2.21.2
+importlib-metadata==4.12.0
+Jinja2==3.1.2
+joblib==1.1.0
+jsonschema==4.15.0
+MarkupSafe==2.1.1
+networkx==2.8.6
+numpy==1.23.2
+opencv-python==4.6.0.66
+opencv-python-headless==4.6.0.66
+packaging==21.3
+pandas==1.4.4
+Pillow==9.2.0
+protobuf==3.20.1
+pyarrow==9.0.0
+pydeck==0.8.0b1
+Pygments==2.13.0
+Pympler==1.0.1
+pyparsing==3.0.9
+pyrsistent==0.18.1
+python-dateutil==2.8.2
+pytz==2022.2.1
+pytz-deprecation-shim==0.1.0.post0
+PyWavelets==1.3.0
+PyYAML==6.0
+qudida==0.0.4
+requests==2.28.1
+rich==12.5.1
+scikit-image==0.19.3
+scikit-learn==1.1.2
+scipy==1.9.1
+semver==2.13.0
+six==1.16.0
+smmap==5.0.0
+streamlit==1.12.2
+threadpoolctl==3.1.0
+tifffile==2022.8.12
+timm==0.6.7
+toml==0.10.2
+toolz==0.12.0
+torch==1.12.1
+torchvision==0.13.1
+tornado==6.2
+typing_extensions==4.3.0
+tzdata==2022.2
+tzlocal==4.2
+urllib3==1.26.12
+validators==0.20.0
+watchdog==2.1.9
+zipp==3.8.1
training/__init__.py ADDED
File without changes
training/datasets/__init__.py ADDED
File without changes
training/datasets/classifier_dataset.py ADDED
@@ -0,0 +1,378 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import sys
5
+ import traceback
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import pandas as pd
10
+ import skimage.draw
11
+ from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
12
+ from albumentations.augmentations.functional import image_compression, rot90
13
+ from albumentations.pytorch.functional import img_to_tensor
14
+ from scipy.ndimage import binary_erosion, binary_dilation
15
+ from skimage import measure
16
+ from torch.utils.data import Dataset
17
+ import dlib
18
+
19
+ from training.datasets.validation_set import PUBLIC_SET
20
+
21
+
22
+ def prepare_bit_masks(mask):
23
+ h, w = mask.shape
24
+ mid_w = w // 2
25
+ mid_h = w // 2
26
+ masks = []
27
+ ones = np.ones_like(mask)
28
+ ones[:mid_h] = 0
29
+ masks.append(ones)
30
+ ones = np.ones_like(mask)
31
+ ones[mid_h:] = 0
32
+ masks.append(ones)
33
+ ones = np.ones_like(mask)
34
+ ones[:, :mid_w] = 0
35
+ masks.append(ones)
36
+ ones = np.ones_like(mask)
37
+ ones[:, mid_w:] = 0
38
+ masks.append(ones)
39
+ ones = np.ones_like(mask)
40
+ ones[:mid_h, :mid_w] = 0
41
+ ones[mid_h:, mid_w:] = 0
42
+ masks.append(ones)
43
+ ones = np.ones_like(mask)
44
+ ones[:mid_h, mid_w:] = 0
45
+ ones[mid_h:, :mid_w] = 0
46
+ masks.append(ones)
47
+ return masks
48
+
49
+
50
+ detector = dlib.get_frontal_face_detector()
51
+ predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')
52
+
53
+
54
+ def blackout_convex_hull(img):
55
+ try:
56
+ rect = detector(img)[0]
57
+ sp = predictor(img, rect)
58
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
59
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
60
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
61
+ cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
62
+ cropped_img[Y, X] = 1
63
+ # if random.random() > 0.5:
64
+ # img[cropped_img == 0] = 0
65
+ # #leave only face
66
+ # return img
67
+
68
+ y, x = measure.centroid(cropped_img)
69
+ y = int(y)
70
+ x = int(x)
71
+ first = random.random() > 0.5
72
+ if random.random() > 0.5:
73
+ if first:
74
+ cropped_img[:y, :] = 0
75
+ else:
76
+ cropped_img[y:, :] = 0
77
+ else:
78
+ if first:
79
+ cropped_img[:, :x] = 0
80
+ else:
81
+ cropped_img[:, x:] = 0
82
+
83
+ img[cropped_img > 0] = 0
84
+ except Exception as e:
85
+ pass
86
+
87
+
88
+ def dist(p1, p2):
89
+ return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
90
+
91
+
92
+ def remove_eyes(image, landmarks):
93
+ image = image.copy()
94
+ (x1, y1), (x2, y2) = landmarks[:2]
95
+ mask = np.zeros_like(image[..., 0])
96
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
97
+ w = dist((x1, y1), (x2, y2))
98
+ dilation = int(w // 4)
99
+ line = binary_dilation(line, iterations=dilation)
100
+ image[line, :] = 0
101
+ return image
102
+
103
+
104
+ def remove_nose(image, landmarks):
105
+ image = image.copy()
106
+ (x1, y1), (x2, y2) = landmarks[:2]
107
+ x3, y3 = landmarks[2]
108
+ mask = np.zeros_like(image[..., 0])
109
+ x4 = int((x1 + x2) / 2)
110
+ y4 = int((y1 + y2) / 2)
111
+ line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
112
+ w = dist((x1, y1), (x2, y2))
113
+ dilation = int(w // 4)
114
+ line = binary_dilation(line, iterations=dilation)
115
+ image[line, :] = 0
116
+ return image
117
+
118
+
119
+ def remove_mouth(image, landmarks):
120
+ image = image.copy()
121
+ (x1, y1), (x2, y2) = landmarks[-2:]
122
+ mask = np.zeros_like(image[..., 0])
123
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
124
+ w = dist((x1, y1), (x2, y2))
125
+ dilation = int(w // 3)
126
+ line = binary_dilation(line, iterations=dilation)
127
+ image[line, :] = 0
128
+ return image
129
+
130
+
131
+ def remove_landmark(image, landmarks):
132
+ if random.random() > 0.5:
133
+ image = remove_eyes(image, landmarks)
134
+ elif random.random() > 0.5:
135
+ image = remove_mouth(image, landmarks)
136
+ elif random.random() > 0.5:
137
+ image = remove_nose(image, landmarks)
138
+ return image
139
+
140
+
141
+ def change_padding(image, part=5):
142
+ h, w = image.shape[:2]
143
+ # original padding was done with 1/3 from each side, too much
144
+ pad_h = int(((3 / 5) * h) / part)
145
+ pad_w = int(((3 / 5) * w) / part)
146
+ image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
147
+ return image
148
+
149
+
150
+ def blackout_random(image, mask, label):
151
+ binary_mask = mask > 0.4 * 255
152
+ h, w = binary_mask.shape[:2]
153
+
154
+ tries = 50
155
+ current_try = 1
156
+ while current_try < tries:
157
+ first = random.random() < 0.5
158
+ if random.random() < 0.5:
159
+ pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
160
+ bitmap_msk = np.ones_like(binary_mask)
161
+ if first:
162
+ bitmap_msk[:pivot, :] = 0
163
+ else:
164
+ bitmap_msk[pivot:, :] = 0
165
+ else:
166
+ pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
167
+ bitmap_msk = np.ones_like(binary_mask)
168
+ if first:
169
+ bitmap_msk[:, :pivot] = 0
170
+ else:
171
+ bitmap_msk[:, pivot:] = 0
172
+
173
+ if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
174
+ or np.count_nonzero(binary_mask * bitmap_msk) > 40:
175
+ mask *= bitmap_msk
176
+ image *= np.expand_dims(bitmap_msk, axis=-1)
177
+ break
178
+ current_try += 1
179
+ return image
180
+
181
+
182
+ def blend_original(img):
183
+ img = img.copy()
184
+ h, w = img.shape[:2]
185
+ rect = detector(img)
186
+ if len(rect) == 0:
187
+ return img
188
+ else:
189
+ rect = rect[0]
190
+ sp = predictor(img, rect)
191
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
192
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
193
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
194
+ raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
195
+ raw_mask[Y, X] = 1
196
+ face = img * np.expand_dims(raw_mask, -1)
197
+
198
+ # add warping
199
+ h1 = random.randint(h - h // 2, h + h // 2)
200
+ w1 = random.randint(w - w // 2, w + w // 2)
201
+ while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
202
+ h1 = random.randint(h - h // 2, h + h // 2)
203
+ w1 = random.randint(w - w // 2, w + w // 2)
204
+ face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
205
+ face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
206
+
207
+ raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
208
+ img[raw_mask, :] = face[raw_mask, :]
209
+ if random.random() < 0.2:
210
+ img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
211
+ # image compression
212
+ if random.random() < 0.5:
213
+ img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
214
+ return img
215
+
216
+
217
+ class DeepFakeClassifierDataset(Dataset):
218
+
219
+ def __init__(self,
220
+ data_path="/mnt/sota/datasets/deepfake",
221
+ fold=0,
222
+ label_smoothing=0.01,
223
+ padding_part=3,
224
+ hardcore=True,
225
+ crops_dir="crops",
226
+ folds_csv="folds.csv",
227
+ normalize={"mean": [0.485, 0.456, 0.406],
228
+ "std": [0.229, 0.224, 0.225]},
229
+ rotation=False,
230
+ mode="train",
231
+ reduce_val=True,
232
+ oversample_real=True,
233
+ transforms=None
234
+ ):
235
+ super().__init__()
236
+ self.data_root = data_path
237
+ self.fold = fold
238
+ self.folds_csv = folds_csv
239
+ self.mode = mode
240
+ self.rotation = rotation
241
+ self.padding_part = padding_part
242
+ self.hardcore = hardcore
243
+ self.crops_dir = crops_dir
244
+ self.label_smoothing = label_smoothing
245
+ self.normalize = normalize
246
+ self.transforms = transforms
247
+ self.df = pd.read_csv(self.folds_csv)
248
+ self.oversample_real = oversample_real
249
+ self.reduce_val = reduce_val
250
+
251
+ def __getitem__(self, index: int):
252
+
253
+ while True:
254
+ video, img_file, label, ori_video, frame, fold = self.data[index]
255
+ try:
256
+ if self.mode == "train":
257
+ label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
258
+ img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
259
+ image = cv2.imread(img_path, cv2.IMREAD_COLOR)
260
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
261
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
262
+ diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
263
+ try:
264
+ msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
265
+ if msk is not None:
266
+ mask = msk
267
+ except:
268
+ print("not found mask", diff_path)
269
+ pass
270
+ if self.mode == "train" and self.hardcore and not self.rotation:
271
+ landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
272
+ if os.path.exists(landmark_path) and random.random() < 0.7:
273
+ landmarks = np.load(landmark_path)
274
+ image = remove_landmark(image, landmarks)
275
+ elif random.random() < 0.2:
276
+ blackout_convex_hull(image)
277
+ elif random.random() < 0.1:
278
+ binary_mask = mask > 0.4 * 255
279
+ masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
280
+ tries = 6
281
+ current_try = 1
282
+ while current_try < tries:
283
+ bitmap_msk = random.choice(masks)
284
+ if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
285
+ mask *= bitmap_msk
286
+ image *= np.expand_dims(bitmap_msk, axis=-1)
287
+ break
288
+ current_try += 1
289
+ if self.mode == "train" and self.padding_part > 3:
290
+ image = change_padding(image, self.padding_part)
291
+ valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
292
+ valid_label = 1 if valid_label else 0
293
+ rotation = 0
294
+ if self.transforms:
295
+ data = self.transforms(image=image, mask=mask)
296
+ image = data["image"]
297
+ mask = data["mask"]
298
+ if self.mode == "train" and self.hardcore and self.rotation:
299
+ # landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
300
+ dropout = 0.8 if label > 0.5 else 0.6
301
+ if self.rotation:
302
+ dropout *= 0.7
303
+ elif random.random() < dropout:
304
+ blackout_random(image, mask, label)
305
+
306
+ #
307
+ # os.makedirs("../images", exist_ok=True)
308
+ # cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])
309
+
310
+ if self.mode == "train" and self.rotation:
311
+ rotation = random.randint(0, 3)
312
+ image = rot90(image, rotation)
313
+
314
+ image = img_to_tensor(image, self.normalize)
315
+ return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
316
+ "valid": valid_label, "rotations": rotation}
317
+ except Exception as e:
318
+ traceback.print_exc(file=sys.stdout)
319
+ print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
320
+ index = random.randint(0, len(self.data) - 1)
321
+
322
+ def random_blackout_landmark(self, image, mask, landmarks):
323
+ x, y = random.choice(landmarks)
324
+ first = random.random() > 0.5
325
+ # crop half face either vertically or horizontally
326
+ if random.random() > 0.5:
327
+ # width
328
+ if first:
329
+ image[:, :x] = 0
330
+ mask[:, :x] = 0
331
+ else:
332
+ image[:, x:] = 0
333
+ mask[:, x:] = 0
334
+ else:
335
+ # height
336
+ if first:
337
+ image[:y, :] = 0
338
+ mask[:y, :] = 0
339
+ else:
340
+ image[y:, :] = 0
341
+ mask[y:, :] = 0
342
+
343
+ def reset(self, epoch, seed):
344
+ self.data = self._prepare_data(epoch, seed)
345
+
346
+ def __len__(self) -> int:
347
+ return len(self.data)
348
+
349
+ def _prepare_data(self, epoch, seed):
350
+ df = self.df
351
+ if self.mode == "train":
352
+ rows = df[df["fold"] != self.fold]
353
+ else:
354
+ rows = df[df["fold"] == self.fold]
355
+ seed = (epoch + 1) * seed
356
+ if self.oversample_real:
357
+ rows = self._oversample(rows, seed)
358
+ if self.mode == "val" and self.reduce_val:
359
+ # every 2nd frame, to speed up validation
360
+ rows = rows[rows["frame"] % 20 == 0]
361
+ # another option is to use public validation set
362
+ #rows = rows[rows["video"].isin(PUBLIC_SET)]
363
+
364
+ print(
365
+ "real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
366
+ data = rows.values
367
+
368
+ np.random.seed(seed)
369
+ np.random.shuffle(data)
370
+ return data
371
+
372
+ def _oversample(self, rows: pd.DataFrame, seed):
373
+ real = rows[rows["label"] == 0]
374
+ fakes = rows[rows["label"] == 1]
375
+ num_real = real["video"].count()
376
+ if self.mode == "train":
377
+ fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
378
+ return pd.concat([real, fakes])
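The training pipeline below builds this dataset roughly as follows; a condensed sketch with illustrative paths (the data directory, folds.csv and the crops layout are whatever your local preprocessing produced):

from torch.utils.data import DataLoader

from training.datasets.classifier_dataset import DeepFakeClassifierDataset

data_train = DeepFakeClassifierDataset(mode="train",
                                       data_path="/mnt/sota/datasets/deepfake",  # illustrative
                                       crops_dir="crops",
                                       folds_csv="folds.csv",
                                       fold=0,
                                       transforms=None)  # pass create_train_transforms(size) in practice
data_train.reset(epoch=0, seed=777)   # must be called before iterating: it samples self.data for the epoch
loader = DataLoader(data_train, batch_size=32, shuffle=True, drop_last=True)
batch = next(iter(loader))            # dict with "image", "labels", "img_name", "valid", "rotations"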
training/datasets/validation_set.py ADDED
@@ -0,0 +1,60 @@
1
+
2
+
3
+ PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
4
+ 'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
5
+ 'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
6
+ 'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
7
+ 'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
8
+ 'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
9
+ 'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
10
+ 'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
11
+ 'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
12
+ 'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
13
+ 'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
14
+ 'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
15
+ 'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
16
+ 'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
17
+ 'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
18
+ 'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
19
+ 'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
20
+ 'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
21
+ 'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
22
+ 'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
23
+ 'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
24
+ 'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
25
+ 'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
26
+ 'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
27
+ 'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
28
+ 'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
29
+ 'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
30
+ 'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
31
+ 'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
32
+ 'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
33
+ 'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
34
+ 'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
35
+ 'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
36
+ 'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
37
+ 'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
38
+ 'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
39
+ 'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
40
+ 'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
41
+ 'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
42
+ 'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
43
+ 'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
44
+ 'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
45
+ 'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
46
+ 'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
47
+ 'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
48
+ 'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
49
+ 'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
50
+ 'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
51
+ 'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
52
+ 'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
53
+ 'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
54
+ 'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
55
+ 'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
56
+ 'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
57
+ 'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
58
+ 'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
59
+ 'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
60
+ 'txmnoyiyte'}
training/losses.py ADDED
@@ -0,0 +1,28 @@
+from typing import Any
+
+from pytorch_toolbelt.losses import BinaryFocalLoss
+from torch import nn
+from torch.nn.modules.loss import BCEWithLogitsLoss
+
+
+class WeightedLosses(nn.Module):
+    def __init__(self, losses, weights):
+        super().__init__()
+        self.losses = losses
+        self.weights = weights
+
+    def forward(self, *input: Any, **kwargs: Any):
+        cum_loss = 0
+        for loss, w in zip(self.losses, self.weights):
+            cum_loss += w * loss.forward(*input, **kwargs)
+        return cum_loss
+
+
+class BinaryCrossentropy(BCEWithLogitsLoss):
+    pass
+
+
+class FocalLoss(BinaryFocalLoss):
+    def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
+                 reduced_threshold=None):
+        super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
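A small illustrative sketch of how these losses combine (the class names mirror what conf["losses"] refers to in train_classifier.py below; the weights and tensors here are made up):

import torch

from training.losses import BinaryCrossentropy, FocalLoss, WeightedLosses

# conf["losses"] in train_classifier.py maps loss class names to weights, e.g. {"BinaryCrossentropy": 1.0, "FocalLoss": 0.5}
loss_fn = WeightedLosses([BinaryCrossentropy(), FocalLoss(gamma=2)], [1.0, 0.5])
logits = torch.randn(8, 1)
targets = torch.rand(8, 1)
print(loss_fn(logits, targets))  # 1.0 * BCE-with-logits + 0.5 * binary focal loss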
training/pipelines/__init__.py ADDED
File without changes
training/pipelines/train_classifier.py ADDED
@@ -0,0 +1,360 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ from sklearn.metrics import log_loss
7
+ from torch import topk
8
+
9
+ from training import losses
10
+ from training.datasets.classifier_dataset import DeepFakeClassifierDataset
11
+ from training.losses import WeightedLosses
12
+ from training.tools.config import load_config
13
+ from training.tools.utils import create_optimizer, AverageMeter
14
+ from training.transforms.albu import IsotropicResize
15
+ from training.zoo import classifiers
16
+
17
+ os.environ["MKL_NUM_THREADS"] = "1"
18
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
19
+ os.environ["OMP_NUM_THREADS"] = "1"
20
+
21
+ import cv2
22
+
23
+ cv2.ocl.setUseOpenCL(False)
24
+ cv2.setNumThreads(0)
25
+ import numpy as np
26
+ from albumentations import Compose, RandomBrightnessContrast, \
27
+ HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
28
+ ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
29
+
30
+ from apex.parallel import DistributedDataParallel, convert_syncbn_model
31
+ from tensorboardX import SummaryWriter
32
+
33
+ from apex import amp
34
+
35
+ import torch
36
+ from torch.backends import cudnn
37
+ from torch.nn import DataParallel
38
+ from torch.utils.data import DataLoader
39
+ from tqdm import tqdm
40
+ import torch.distributed as dist
41
+
42
+ torch.backends.cudnn.benchmark = True
43
+
44
+
45
+ def create_train_transforms(size=300):
46
+ return Compose([
47
+ ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
48
+ GaussNoise(p=0.1),
49
+ GaussianBlur(blur_limit=3, p=0.05),
50
+ HorizontalFlip(),
51
+ OneOf([
52
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
53
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
54
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
55
+ ], p=1),
56
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
57
+ OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
58
+ ToGray(p=0.2),
59
+ ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
60
+ ]
61
+ )
62
+
63
+
64
+ def create_val_transforms(size=300):
65
+ return Compose([
66
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
67
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
68
+ ])
69
+
70
+
71
+ def main():
72
+ parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
73
+ arg = parser.add_argument
74
+ arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
75
+ arg('--workers', type=int, default=6, help='number of cpu threads to use')
76
+ arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
77
+ arg('--output-dir', type=str, default='weights/')
78
+ arg('--resume', type=str, default='')
79
+ arg('--fold', type=int, default=0)
80
+ arg('--prefix', type=str, default='classifier_')
81
+ arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
82
+ arg('--folds-csv', type=str, default='folds.csv')
83
+ arg('--crops-dir', type=str, default='crops')
84
+ arg('--label-smoothing', type=float, default=0.01)
85
+ arg('--logdir', type=str, default='logs')
86
+ arg('--zero-score', action='store_true', default=False)
87
+ arg('--from-zero', action='store_true', default=False)
88
+ arg('--distributed', action='store_true', default=False)
89
+ arg('--freeze-epochs', type=int, default=0)
90
+ arg("--local_rank", default=0, type=int)
91
+ arg("--seed", default=777, type=int)
92
+ arg("--padding-part", default=3, type=int)
93
+ arg("--opt-level", default='O1', type=str)
94
+ arg("--test_every", type=int, default=1)
95
+ arg("--no-oversample", action="store_true")
96
+ arg("--no-hardcore", action="store_true")
97
+ arg("--only-changed-frames", action="store_true")
98
+
99
+ args = parser.parse_args()
100
+ os.makedirs(args.output_dir, exist_ok=True)
101
+ if args.distributed:
102
+ torch.cuda.set_device(args.local_rank)
103
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
104
+ else:
105
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
106
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
107
+
108
+ cudnn.benchmark = True
109
+
110
+ conf = load_config(args.config)
111
+ model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
112
+
113
+ model = model
114
+ if args.distributed:
115
+ model = convert_syncbn_model(model)
116
+ ohem = conf.get("ohem_samples", None)
117
+ reduction = "mean"
118
+ if ohem:
119
+ reduction = "none"
120
+ loss_fn = []
121
+ weights = []
122
+ for loss_name, weight in conf["losses"].items():
123
+ loss_fn.append(losses.__dict__[loss_name](reduction=reduction))
124
+ weights.append(weight)
125
+ loss = WeightedLosses(loss_fn, weights)
126
+ loss_functions = {"classifier_loss": loss}
127
+ optimizer, scheduler = create_optimizer(conf['optimizer'], model)
128
+ bce_best = 100
129
+ start_epoch = 0
130
+ batch_size = conf['optimizer']['batch_size']
131
+
132
+ data_train = DeepFakeClassifierDataset(mode="train",
133
+ oversample_real=not args.no_oversample,
134
+ fold=args.fold,
135
+ padding_part=args.padding_part,
136
+ hardcore=not args.no_hardcore,
137
+ crops_dir=args.crops_dir,
138
+ data_path=args.data_dir,
139
+ label_smoothing=args.label_smoothing,
140
+ folds_csv=args.folds_csv,
141
+ transforms=create_train_transforms(conf["size"]),
142
+ normalize=conf.get("normalize", None))
143
+ data_val = DeepFakeClassifierDataset(mode="val",
144
+ fold=args.fold,
145
+ padding_part=args.padding_part,
146
+ crops_dir=args.crops_dir,
147
+ data_path=args.data_dir,
148
+ folds_csv=args.folds_csv,
149
+ transforms=create_val_transforms(conf["size"]),
150
+ normalize=conf.get("normalize", None))
151
+ val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
152
+ pin_memory=False)
153
+ os.makedirs(args.logdir, exist_ok=True)
154
+ summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
155
+ if args.resume:
156
+ if os.path.isfile(args.resume):
157
+ print("=> loading checkpoint '{}'".format(args.resume))
158
+ checkpoint = torch.load(args.resume, map_location='cpu')
159
+ state_dict = checkpoint['state_dict']
160
+ state_dict = {k[7:]: w for k, w in state_dict.items()}
161
+ model.load_state_dict(state_dict, strict=False)
162
+ if not args.from_zero:
163
+ start_epoch = checkpoint['epoch']
164
+ if not args.zero_score:
165
+ bce_best = checkpoint.get('bce_best', 0)
166
+ print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
167
+ .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
168
+ else:
169
+ print("=> no checkpoint found at '{}'".format(args.resume))
170
+ if args.from_zero:
171
+ start_epoch = 0
172
+ current_epoch = start_epoch
173
+
174
+ if conf['fp16']:
175
+ model, optimizer = amp.initialize(model, optimizer,
176
+ opt_level=args.opt_level,
177
+ loss_scale='dynamic')
178
+
179
+ snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
180
+
181
+ if args.distributed:
182
+ model = DistributedDataParallel(model, delay_allreduce=True)
183
+ else:
184
+ model = DataParallel(model)
185
+ data_val.reset(1, args.seed)
186
+ max_epochs = conf['optimizer']['schedule']['epochs']
187
+ for epoch in range(start_epoch, max_epochs):
188
+ data_train.reset(epoch, args.seed)
189
+ train_sampler = None
190
+ if args.distributed:
191
+ train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
192
+ train_sampler.set_epoch(epoch)
193
+ if epoch < args.freeze_epochs:
194
+ print("Freezing encoder!!!")
195
+ model.module.encoder.eval()
196
+ for p in model.module.encoder.parameters():
197
+ p.requires_grad = False
198
+ else:
199
+ model.module.encoder.train()
200
+ for p in model.module.encoder.parameters():
201
+ p.requires_grad = True
202
+
203
+ train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
204
+ shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
205
+ drop_last=True)
206
+
207
+ train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
208
+ args.local_rank, args.only_changed_frames)
209
+ model = model.eval()
210
+
211
+ if args.local_rank == 0:
212
+ torch.save({
213
+ 'epoch': current_epoch + 1,
214
+ 'state_dict': model.state_dict(),
215
+ 'bce_best': bce_best,
216
+ }, args.output_dir + '/' + snapshot_name + "_last")
217
+ torch.save({
218
+ 'epoch': current_epoch + 1,
219
+ 'state_dict': model.state_dict(),
220
+ 'bce_best': bce_best,
221
+ }, args.output_dir + snapshot_name + "_{}".format(current_epoch))
222
+ if (epoch + 1) % args.test_every == 0:
223
+ bce_best = evaluate_val(args, val_data_loader, bce_best, model,
224
+ snapshot_name=snapshot_name,
225
+ current_epoch=current_epoch,
226
+ summary_writer=summary_writer)
227
+ current_epoch += 1
228
+
229
+
230
+ def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
231
+ print("Test phase")
232
+ model = model.eval()
233
+
234
+ bce, probs, targets = validate(model, data_loader=data_val)
235
+ if args.local_rank == 0:
236
+ summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
237
+ if bce < bce_best:
238
+ print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
239
+ if args.output_dir is not None:
240
+ torch.save({
241
+ 'epoch': current_epoch + 1,
242
+ 'state_dict': model.state_dict(),
243
+ 'bce_best': bce,
244
+ }, args.output_dir + snapshot_name + "_best_dice")
245
+ bce_best = bce
246
+ with open("predictions_{}.json".format(args.fold), "w") as f:
247
+ json.dump({"probs": probs, "targets": targets}, f)
248
+ torch.save({
249
+ 'epoch': current_epoch + 1,
250
+ 'state_dict': model.state_dict(),
251
+ 'bce_best': bce_best,
252
+ }, args.output_dir + snapshot_name + "_last")
253
+ print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
254
+ return bce_best
255
+
256
+
257
+ def validate(net, data_loader, prefix=""):
258
+ probs = defaultdict(list)
259
+ targets = defaultdict(list)
260
+
261
+ with torch.no_grad():
262
+ for sample in tqdm(data_loader):
263
+ imgs = sample["image"]
264
+ img_names = sample["img_name"]
265
+ labels = sample["labels"].float()
266
+ out = net(imgs)
267
+ labels = labels.cpu().numpy()
268
+ preds = torch.sigmoid(out).cpu().numpy()
269
+ for i in range(out.shape[0]):
270
+ video, img_id = img_names[i].split("/")
271
+ probs[video].append(preds[i].tolist())
272
+ targets[video].append(labels[i].tolist())
273
+ data_x = []
274
+ data_y = []
275
+ for vid, score in probs.items():
276
+ score = np.array(score)
277
+ lbl = targets[vid]
278
+
279
+ score = np.mean(score)
280
+ lbl = np.mean(lbl)
281
+ data_x.append(score)
282
+ data_y.append(lbl)
283
+ y = np.array(data_y)
284
+ x = np.array(data_x)
285
+ fake_idx = y > 0.1
286
+ real_idx = y < 0.1
287
+ fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
288
+ real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
289
+ print("{}fake_loss".format(prefix), fake_loss)
290
+ print("{}real_loss".format(prefix), real_loss)
291
+
292
+ return (fake_loss + real_loss) / 2, probs, targets
293
+
294
+
295
+ def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
296
+ local_rank, only_valid):
297
+ losses = AverageMeter()
298
+ fake_losses = AverageMeter()
299
+ real_losses = AverageMeter()
300
+ max_iters = conf["batches_per_epoch"]
301
+ print("training epoch {}".format(current_epoch))
302
+ model.train()
303
+ pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
304
+ if conf["optimizer"]["schedule"]["mode"] == "epoch":
305
+ scheduler.step(current_epoch)
306
+ for i, sample in pbar:
307
+ imgs = sample["image"]
308
+ labels = sample["labels"].float()
309
+ out_labels = model(imgs)
310
+ if only_valid:
311
+ valid_idx = sample["valid"].float() > 0
312
+ out_labels = out_labels[valid_idx]
313
+ labels = labels[valid_idx]
314
+ if labels.size(0) == 0:
315
+ continue
316
+
317
+ fake_loss = 0
318
+ real_loss = 0
319
+ fake_idx = labels > 0.5
320
+ real_idx = labels <= 0.5
321
+
322
+ ohem = conf.get("ohem_samples", None)
323
+ if torch.sum(fake_idx * 1) > 0:
324
+ fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
325
+ if torch.sum(real_idx * 1) > 0:
326
+ real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
327
+ if ohem:
328
+ fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
329
+ real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()
330
+
331
+ loss = (fake_loss + real_loss) / 2
332
+ losses.update(loss.item(), imgs.size(0))
333
+ fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
334
+ real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
335
+
336
+ optimizer.zero_grad()
337
+ pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
338
+ "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})
339
+
340
+ if conf['fp16']:
341
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
342
+ scaled_loss.backward()
343
+ else:
344
+ loss.backward()
345
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
346
+ optimizer.step()
347
+ if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
348
+ scheduler.step(i + current_epoch * max_iters)
349
+ if i == max_iters - 1:
350
+ break
351
+ pbar.close()
352
+ if local_rank == 0:
353
+ for idx, param_group in enumerate(optimizer.param_groups):
354
+ lr = param_group['lr']
355
+ summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
356
+ summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
357
+
358
+
359
+ if __name__ == '__main__':
360
+ main()
training/tools/__init__.py ADDED
File without changes
training/tools/config.py ADDED
@@ -0,0 +1,43 @@
+import json
+
+DEFAULTS = {
+    "network": "dpn",
+    "encoder": "dpn92",
+    "model_params": {},
+    "optimizer": {
+        "batch_size": 32,
+        "type": "SGD",  # supported: SGD, Adam
+        "momentum": 0.9,
+        "weight_decay": 0,
+        "clip": 1.,
+        "learning_rate": 0.1,
+        "classifier_lr": -1,
+        "nesterov": True,
+        "schedule": {
+            "type": "constant",  # supported: constant, step, multistep, exponential, linear, poly
+            "mode": "epoch",  # supported: epoch, step
+            "epochs": 10,
+            "params": {}
+        }
+    },
+    "normalize": {
+        "mean": [0.485, 0.456, 0.406],
+        "std": [0.229, 0.224, 0.225]
+    }
+}
+
+
+def _merge(src, dst):
+    for k, v in src.items():
+        if k in dst:
+            if isinstance(v, dict):
+                _merge(src[k], dst[k])
+        else:
+            dst[k] = v
+
+
+def load_config(config_file, defaults=DEFAULTS):
+    with open(config_file, "r") as fd:
+        config = json.load(fd)
+    _merge(defaults, config)
+    return config
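For illustration, a config in the shape load_config expects, shown as the Python dict it yields after merging with DEFAULTS; it would normally live in a JSON file passed via --config. The network/encoder/loss names mirror app.py and train_classifier.py, while size, batches_per_epoch and the optimizer numbers are placeholder values, not taken from this commit:

conf = {
    "network": "DeepFakeClassifier",       # looked up in training.zoo.classifiers
    "encoder": "tf_efficientnet_b7_ns",
    "size": 380,                            # placeholder crop size
    "fp16": True,
    "batches_per_epoch": 2500,              # placeholder
    "losses": {"BinaryCrossentropy": 1.0},
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",
        "learning_rate": 0.01,
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "nesterov": True,
        "classifier_lr": -1,
        "schedule": {"type": "poly", "mode": "step", "epochs": 20, "params": {"max_iter": 50000}},
    },
    "normalize": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
}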
training/tools/schedulers.py ADDED
@@ -0,0 +1,46 @@
+from bisect import bisect_right
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class LRStepScheduler(_LRScheduler):
+    def __init__(self, optimizer, steps, last_epoch=-1):
+        self.lr_steps = steps
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
+        return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]
+
+
+class PolyLR(_LRScheduler):
+    """Sets the learning rate of each parameter group according to poly learning rate policy
+    """
+    def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
+        self.max_iter = max_iter
+        self.power = power
+        super(PolyLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        self.last_epoch = (self.last_epoch + 1) % self.max_iter
+        return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]
+
+class ExponentialLRScheduler(_LRScheduler):
+    """Decays the learning rate of each parameter group by gamma every epoch.
+    When last_epoch=-1, sets initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        gamma (float): Multiplicative factor of learning rate decay.
+        last_epoch (int): The index of last epoch. Default: -1.
+    """
+
+    def __init__(self, optimizer, gamma, last_epoch=-1):
+        self.gamma = gamma
+        super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch <= 0:
+            return self.base_lrs
+        return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
+
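A quick sketch of one of these schedulers in isolation (the model, learning rate and max_iter are arbitrary values):

from torch import nn, optim

from training.tools.schedulers import PolyLR

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = PolyLR(optimizer, max_iter=100, power=0.9)
for step in range(5):
    optimizer.step()
    scheduler.step()   # lr follows base_lr * (1 - step / max_iter) ** 0.9
print(optimizer.param_groups[0]["lr"])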
training/tools/utils.py ADDED
@@ -0,0 +1,121 @@
1
+ import cv2
2
+ from apex.optimizers import FusedAdam, FusedSGD
3
+ from timm.optim import AdamW
4
+ from torch import optim
5
+ from torch.optim import lr_scheduler
6
+ from torch.optim.rmsprop import RMSprop
7
+ from torch.optim.adamw import AdamW
8
+ from torch.optim.lr_scheduler import MultiStepLR, CyclicLR
9
+
10
+ from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler
11
+
12
+ cv2.ocl.setUseOpenCL(False)
13
+ cv2.setNumThreads(0)
14
+
15
+
16
+ class AverageMeter(object):
17
+ """Computes and stores the average and current value"""
18
+
19
+ def __init__(self):
20
+ self.reset()
21
+
22
+ def reset(self):
23
+ self.val = 0
24
+ self.avg = 0
25
+ self.sum = 0
26
+ self.count = 0
27
+
28
+ def update(self, val, n=1):
29
+ self.val = val
30
+ self.sum += val * n
31
+ self.count += n
32
+ self.avg = self.sum / self.count
33
+
34
+ def create_optimizer(optimizer_config, model, master_params=None):
35
+ """Creates optimizer and schedule from configuration
36
+
37
+ Parameters
38
+ ----------
39
+ optimizer_config : dict
40
+ Dictionary containing the configuration options for the optimizer.
41
+ model : Model
42
+ The network model.
43
+
44
+ Returns
45
+ -------
46
+ optimizer : Optimizer
47
+ The optimizer.
48
+ scheduler : LRScheduler
49
+ The learning rate scheduler.
50
+ """
51
+ if optimizer_config.get("classifier_lr", -1) != -1:
52
+ # Separate classifier parameters from all others
53
+ net_params = []
54
+ classifier_params = []
55
+ for k, v in model.named_parameters():
56
+ if not v.requires_grad:
57
+ continue
58
+ if k.find("encoder") != -1:
59
+ net_params.append(v)
60
+ else:
61
+ classifier_params.append(v)
62
+ params = [
63
+ {"params": net_params},
64
+ {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
65
+ ]
66
+ else:
67
+ if master_params:
68
+ params = master_params
69
+ else:
70
+ params = model.parameters()
71
+
72
+ if optimizer_config["type"] == "SGD":
73
+ optimizer = optim.SGD(params,
74
+ lr=optimizer_config["learning_rate"],
75
+ momentum=optimizer_config["momentum"],
76
+ weight_decay=optimizer_config["weight_decay"],
77
+ nesterov=optimizer_config["nesterov"])
78
+ elif optimizer_config["type"] == "FusedSGD":
79
+ optimizer = FusedSGD(params,
80
+ lr=optimizer_config["learning_rate"],
81
+ momentum=optimizer_config["momentum"],
82
+ weight_decay=optimizer_config["weight_decay"],
83
+ nesterov=optimizer_config["nesterov"])
84
+ elif optimizer_config["type"] == "Adam":
85
+ optimizer = optim.Adam(params,
86
+ lr=optimizer_config["learning_rate"],
87
+ weight_decay=optimizer_config["weight_decay"])
88
+ elif optimizer_config["type"] == "FusedAdam":
89
+ optimizer = FusedAdam(params,
90
+ lr=optimizer_config["learning_rate"],
91
+ weight_decay=optimizer_config["weight_decay"])
92
+ elif optimizer_config["type"] == "AdamW":
93
+ optimizer = AdamW(params,
94
+ lr=optimizer_config["learning_rate"],
95
+ weight_decay=optimizer_config["weight_decay"])
96
+ elif optimizer_config["type"] == "RmsProp":
97
+ optimizer = RMSprop(params,
98
+ lr=optimizer_config["learning_rate"],
99
+ weight_decay=optimizer_config["weight_decay"])
100
+ else:
101
+ raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))
102
+
103
+ if optimizer_config["schedule"]["type"] == "step":
104
+ scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
105
+ elif optimizer_config["schedule"]["type"] == "clr":
106
+ scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
107
+ elif optimizer_config["schedule"]["type"] == "multistep":
108
+ scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
109
+ elif optimizer_config["schedule"]["type"] == "exponential":
110
+ scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
111
+ elif optimizer_config["schedule"]["type"] == "poly":
112
+ scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
113
+ elif optimizer_config["schedule"]["type"] == "constant":
114
+ scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
115
+ elif optimizer_config["schedule"]["type"] == "linear":
116
+ def linear_lr(it):
117
+ return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]
118
+
119
+ scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
120
+
121
+ return optimizer, scheduler
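
A sketch of the configuration shape that create_optimizer expects, inferred from the lookups above; the concrete values are illustrative and are not the settings used to train the released weights (apex and timm must be installed for this module to import):

    from training.tools.utils import create_optimizer
    from training.zoo.classifiers import DeepFakeClassifier

    optimizer_config = {
        "type": "SGD",
        "learning_rate": 0.01,
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "nesterov": True,
        "classifier_lr": 0.001,   # optional: separate lr for the non-encoder (classifier) params
        "schedule": {"type": "poly", "params": {"max_iter": 100000, "power": 0.9}},
    }

    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    optimizer, scheduler = create_optimizer(optimizer_config, model)
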
training/transforms/__init__.py ADDED
File without changes
training/transforms/albu.py ADDED
@@ -0,0 +1,99 @@
+ import random
+
+ import cv2
+ import numpy as np
+ from albumentations import DualTransform, ImageOnlyTransform
+ from albumentations.augmentations.functional import crop
+
+
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
+     h, w = img.shape[:2]
+     if max(w, h) == size:
+         return img
+     if w > h:
+         scale = size / w
+         h = h * scale
+         w = size
+     else:
+         scale = size / h
+         w = w * scale
+         h = size
+     interpolation = interpolation_up if scale > 1 else interpolation_down
+     resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
+     return resized
+
+
+ class IsotropicResize(DualTransform):
+     def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
+                  always_apply=False, p=1):
+         super(IsotropicResize, self).__init__(always_apply, p)
+         self.max_side = max_side
+         self.interpolation_down = interpolation_down
+         self.interpolation_up = interpolation_up
+
+     def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
+         return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
+                                           interpolation_up=interpolation_up)
+
+     def apply_to_mask(self, img, **params):
+         return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)
+
+     def get_transform_init_args_names(self):
+         return ("max_side", "interpolation_down", "interpolation_up")
+
+
+ class Resize4xAndBack(ImageOnlyTransform):
+     def __init__(self, always_apply=False, p=0.5):
+         super(Resize4xAndBack, self).__init__(always_apply, p)
+
+     def apply(self, img, **params):
+         h, w = img.shape[:2]
+         scale = random.choice([2, 4])
+         img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
+         img = cv2.resize(img, (w, h),
+                          interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
+         return img
+
+
+ class RandomSizedCropNonEmptyMaskIfExists(DualTransform):
+
+     def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
+         super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)
+
+         self.min_max_height = min_max_height
+         self.w2h_ratio = w2h_ratio
+
+     def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
+         cropped = crop(img, x_min, y_min, x_max, y_max)
+         return cropped
+
+     @property
+     def targets_as_params(self):
+         return ["mask"]
+
+     def get_params_dependent_on_targets(self, params):
+         mask = params["mask"]
+         mask_height, mask_width = mask.shape[:2]
+         crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
+         w2h_ratio = random.uniform(*self.w2h_ratio)
+         crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
+         if mask.sum() == 0:
+             x_min = random.randint(0, mask_width - crop_width + 1)
+             y_min = random.randint(0, mask_height - crop_height + 1)
+         else:
+             mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
+             non_zero_yx = np.argwhere(mask)
+             y, x = random.choice(non_zero_yx)
+             x_min = x - random.randint(0, crop_width - 1)
+             y_min = y - random.randint(0, crop_height - 1)
+             x_min = np.clip(x_min, 0, mask_width - crop_width)
+             y_min = np.clip(y_min, 0, mask_height - crop_height)
+
+         x_max = x_min + crop_height
+         y_max = y_min + crop_width
+         y_max = min(mask_height, y_max)
+         x_max = min(mask_width, x_max)
+         return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
+
+     def get_transform_init_args_names(self):
+         return "min_max_height", "height", "width", "w2h_ratio"
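
A minimal sketch of how IsotropicResize composes with standard albumentations transforms; the target size and the zero-filled stand-in image below are illustrative:

    import cv2
    import numpy as np
    from albumentations import Compose, HorizontalFlip, PadIfNeeded
    from training.transforms.albu import IsotropicResize

    size = 380
    transforms = Compose([
        HorizontalFlip(p=0.5),
        IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
    ])

    face_crop = np.zeros((270, 480, 3), dtype=np.uint8)   # stand-in for a detected face crop
    out = transforms(image=face_crop)["image"]
    print(out.shape)                                       # (380, 380, 3): longest side resized, then padded square
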
training/zoo/__init__.py ADDED
File without changes
training/zoo/classifiers.py ADDED
@@ -0,0 +1,172 @@
+ from functools import partial
+
+ import numpy as np
+ import torch
+ from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
+     tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
+ from torch import nn
+ from torch.nn.modules.dropout import Dropout
+ from torch.nn.modules.linear import Linear
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
+
+ encoder_params = {
+     "tf_efficientnet_b3_ns": {
+         "features": 1536,
+         "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b2_ns": {
+         "features": 1408,
+         "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b4_ns": {
+         "features": 1792,
+         "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
+     },
+     "tf_efficientnet_b5_ns": {
+         "features": 2048,
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b4_ns_03d": {
+         "features": 1792,
+         "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
+     },
+     "tf_efficientnet_b5_ns_03d": {
+         "features": 2048,
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
+     },
+     "tf_efficientnet_b5_ns_04d": {
+         "features": 2048,
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
+     },
+     "tf_efficientnet_b6_ns": {
+         "features": 2304,
+         "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b7_ns": {
+         "features": 2560,
+         "init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b6_ns_04d": {
+         "features": 2304,
+         "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
+     },
+ }
+
+
+ def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
+     """Creates the SRM kernels for noise analysis."""
+     # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR 2018
+     srm_kernel = torch.from_numpy(np.array([
+         [  # srm 1/2 horiz
+             [0., 0., 0., 0., 0.],  # noqa: E241,E201
+             [0., 0., 0., 0., 0.],  # noqa: E241,E201
+             [0., 1., -2., 1., 0.],  # noqa: E241,E201
+             [0., 0., 0., 0., 0.],  # noqa: E241,E201
+             [0., 0., 0., 0., 0.],  # noqa: E241,E201
+         ], [  # srm 1/4
+             [0., 0., 0., 0., 0.],  # noqa: E241,E201
+             [0., -1., 2., -1., 0.],  # noqa: E241,E201
+             [0., 2., -4., 2., 0.],  # noqa: E241,E201
+             [0., -1., 2., -1., 0.],  # noqa: E241,E201
+             [0., 0., 0., 0., 0.],  # noqa: E241,E201
+         ], [  # srm 1/12
+             [-1., 2., -2., 2., -1.],  # noqa: E241,E201
+             [2., -6., 8., -6., 2.],  # noqa: E241,E201
+             [-2., 8., -12., 8., -2.],  # noqa: E241,E201
+             [2., -6., 8., -6., 2.],  # noqa: E241,E201
+             [-1., 2., -2., 2., -1.],  # noqa: E241,E201
+         ]
+     ])).float()
+     srm_kernel[0] /= 2
+     srm_kernel[1] /= 4
+     srm_kernel[2] /= 12
+     return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
+
+
+ def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
+     """Creates an SRM convolution layer for noise analysis."""
+     weights = setup_srm_weights(input_channels)
+     conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
+     with torch.no_grad():
+         conv.weight = torch.nn.Parameter(weights, requires_grad=False)
+     return conv
+
+
+ class DeepFakeClassifierSRM(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.5) -> None:
+         super().__init__()
+         self.encoder = encoder_params[encoder]["init_op"]()
+         self.avg_pool = AdaptiveAvgPool2d((1, 1))
+         self.srm_conv = setup_srm_layer(3)
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+     def forward(self, x):
+         noise = self.srm_conv(x)
+         x = self.encoder.forward_features(noise)
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x
+
+
+ class GlobalWeightedAvgPool2d(nn.Module):
+     """
+     Global Weighted Average Pooling from the paper "Global Weighted Average
+     Pooling Bridges Pixel-level Localization and Image-level Classification"
+     """
+
+     def __init__(self, features: int, flatten=False):
+         super().__init__()
+         self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
+         self.flatten = flatten
+
+     def fscore(self, x):
+         m = self.conv(x)
+         m = m.sigmoid().exp()
+         return m
+
+     def norm(self, x: torch.Tensor):
+         return x / x.sum(dim=[2, 3], keepdim=True)
+
+     def forward(self, x):
+         input_x = x
+         x = self.fscore(x)
+         x = self.norm(x)
+         x = x * input_x
+         x = x.sum(dim=[2, 3], keepdim=not self.flatten)
+         return x
+
+
+ class DeepFakeClassifier(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.0) -> None:
+         super().__init__()
+         self.encoder = encoder_params[encoder]["init_op"]()
+         self.avg_pool = AdaptiveAvgPool2d((1, 1))
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+     def forward(self, x):
+         x = self.encoder.forward_features(x)
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x
+
+
+ class DeepFakeClassifierGWAP(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.5) -> None:
+         super().__init__()
+         self.encoder = encoder_params[encoder]["init_op"]()
+         self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+     def forward(self, x):
+         x = self.encoder.forward_features(x)
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x
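
For reference, a minimal sketch of the classifier interface; the batch shape and input resolution below are illustrative, and constructing the model downloads the pretrained timm encoder weights on first use:

    import torch
    from training.zoo.classifiers import DeepFakeClassifier

    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").eval()
    with torch.no_grad():
        faces = torch.rand(2, 3, 380, 380)   # batch of normalized face crops (illustrative size)
        logits = model(faces)                # shape (2, 1): one raw logit per face
        probs = torch.sigmoid(logits)        # probability that each face is fake
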
training/zoo/unet.py ADDED
@@ -0,0 +1,151 @@
+ from functools import partial
+
+ import torch
+ from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
+ from torch import nn
+ from torch.nn import Dropout2d, Conv2d
+ from torch.nn.modules.dropout import Dropout
+ from torch.nn.modules.linear import Linear
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
+ from torch.nn.modules.upsampling import UpsamplingBilinear2d
+
+ encoder_params = {
+     "tf_efficientnet_b3_ns": {
+         "features": 1536,
+         "filters": [40, 32, 48, 136, 1536],
+         "decoder_filters": [64, 128, 256, 256],
+         "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
+     },
+     "tf_efficientnet_b5_ns": {
+         "features": 2048,
+         "filters": [48, 40, 64, 176, 2048],
+         "decoder_filters": [64, 128, 256, 256],
+         "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
+     },
+ }
+
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.layer = nn.Sequential(
+             nn.Upsample(scale_factor=2),
+             nn.Conv2d(in_channels, out_channels, 3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.layer(x)
+
+
+ class ConcatBottleneck(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.seq = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, dec, enc):
+         x = torch.cat([dec, enc], dim=1)
+         return self.seq(x)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, decoder_filters, filters, upsample_filters=None,
+                  decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
+         super().__init__()
+         self.decoder_filters = decoder_filters
+         self.filters = filters
+         self.decoder_block = decoder_block
+         self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
+         self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
+                                           for i, f in enumerate(reversed(decoder_filters))])
+         self.dropout = Dropout2d(dropout) if dropout > 0 else None
+         self.last_block = None
+         if upsample_filters:
+             self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
+         else:
+             self.last_block = UpsamplingBilinear2d(scale_factor=2)
+
+     def forward(self, encoder_results: list):
+         x = encoder_results[0]
+         bottlenecks = self.bottlenecks
+         for idx, bottleneck in enumerate(bottlenecks):
+             rev_idx = - (idx + 1)
+             x = self.decoder_stages[rev_idx](x)
+             x = bottleneck(x, encoder_results[-rev_idx])
+         if self.last_block:
+             x = self.last_block(x)
+         if self.dropout:
+             x = self.dropout(x)
+         return x
+
+     def _get_decoder(self, layer):
+         idx = layer + 1
+         if idx == len(self.decoder_filters):
+             in_channels = self.filters[idx]
+         else:
+             in_channels = self.decoder_filters[idx]
+         return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
+
+
+ def _initialize_weights(module):
+     for m in module.modules():
+         if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
+             m.weight.data = nn.init.kaiming_normal_(m.weight.data)
+             if m.bias is not None:
+                 m.bias.data.zero_()
+         elif isinstance(m, nn.BatchNorm2d):
+             m.weight.data.fill_(1)
+             m.bias.data.zero_()
+
+
+ class EfficientUnetClassifier(nn.Module):
+     def __init__(self, encoder, dropout_rate=0.5) -> None:
+         super().__init__()
+         self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
+                                filters=encoder_params[encoder]["filters"])
+         self.avg_pool = AdaptiveAvgPool2d((1, 1))
+         self.dropout = Dropout(dropout_rate)
+         self.fc = Linear(encoder_params[encoder]["features"], 1)
+         self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
+         _initialize_weights(self)
+         self.encoder = encoder_params[encoder]["init_op"]()
+
+     def get_encoder_features(self, x):
+         encoder_results = []
+         x = self.encoder.conv_stem(x)
+         x = self.encoder.bn1(x)
+         x = self.encoder.act1(x)
+         encoder_results.append(x)
+         x = self.encoder.blocks[:2](x)
+         encoder_results.append(x)
+         x = self.encoder.blocks[2:3](x)
+         encoder_results.append(x)
+         x = self.encoder.blocks[3:5](x)
+         encoder_results.append(x)
+         x = self.encoder.blocks[5:](x)
+         x = self.encoder.conv_head(x)
+         x = self.encoder.bn2(x)
+         x = self.encoder.act2(x)
+         encoder_results.append(x)
+         encoder_results = list(reversed(encoder_results))
+         return encoder_results
+
+     def forward(self, x):
+         encoder_results = self.get_encoder_features(x)
+         seg = self.final(self.decoder(encoder_results))
+         x = encoder_results[0]
+         x = self.avg_pool(x).flatten(1)
+         x = self.dropout(x)
+         x = self.fc(x)
+         return x, seg
+
+
+ if __name__ == '__main__':
+     model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
+     model.eval()
+     with torch.no_grad():
+         input = torch.rand(4, 3, 224, 224)
+         print(model(input))
weights/best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:849036dc211387420412feb68c7451e98a948072b94e213e1104e2f2bf7791ad
+ size 266910615
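
Note: weights/best.pth is committed as a Git LFS pointer, not as the checkpoint itself; the actual file (about 267 MB, per the size field above) is fetched only when the repo is cloned with Git LFS installed, or afterwards with `git lfs pull`.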