Sara Mandelli committed on
Commit
6bd8735
1 Parent(s): f6b58ff

Update detector

Browse files
gan_vs_real_detector.py CHANGED
@@ -11,39 +11,14 @@ torch.multiprocessing.set_sharing_strategy('file_system')
11
  import albumentations as A
12
  import albumentations.pytorch as Ap
13
  from utils import architectures
 
14
  from PIL import Image
15
 
16
 
17
  class Detector:
18
  def __init__(self):
19
 
20
- # model directory and path for detector A
21
- # model_A_dir = 'weights/method_A/net-EfficientNetB4_lr-0.001_img_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\', ' \
22
- # '\'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\', \'resize\', \'jpeg\']' \
23
- # '_img_aug_p-0.5_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
24
- #
25
- # # model directory and path for detector B
26
- # model_B_dir = 'weights/method_B/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\', ' \
27
- # '\'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
28
- # '_aug_p-0.5_jpeg_aug_p-0.7_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
29
- #
30
- # # model directory and path for detector C
31
- # model_C_dir = 'weights/method_C/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
32
- # ' \'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
33
- # '_aug_p-0.5_jpeg_aug_p-0_patch_size-128_patch_number-5_batch_size-50_num_classes-2'
34
- #
35
- # # model directory and path for detector D
36
- # model_D_dir = 'weights/method_D/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
37
- # '\'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
38
- # '_aug_p-0.5_jpeg_aug_p-0_patch_size-128_patch_number-10_batch_size-25_num_classes-2'
39
- #
40
- # # model directory for detector E
41
- # model_E_dir = 'weights/method_E/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
42
- # ' \'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
43
- # '_aug_p-0.5_jpeg_aug_p-0.7_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
44
-
45
- self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']
46
- # self.model_path = os.path.join(model_dir, 'bestval.pth')
47
 
48
  # GPU configuration if available
49
  self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@@ -72,17 +47,15 @@ class Detector:
72
  Ap.transforms.ToTensorV2()
73
  ]
74
  self.trans = A.Compose(transform)
75
-
76
  self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
77
-
78
  self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
79
 
80
- def synth_real_detector(self, img_path: str, n_patch: int = 50):
81
 
82
  # Load image:
83
  img = np.asarray(Image.open(img_path))
84
 
85
- # Optout if image is non conforming
86
  if img.shape == ():
87
  print('{} None dimension'.format(img_path))
88
  return None
@@ -96,47 +69,52 @@ class Detector:
96
  print('Omitting alpha channel')
97
  img = img[:, :, :3]
98
 
99
- # Extract test_N random patches from image:
100
- patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
101
 
102
- # Normalization
103
- transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
104
 
105
- # Compute scores
106
- transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
107
- with torch.no_grad():
108
- patch_scores = self.net(transf_patch_tensor)
109
- softmax_scores = torch.softmax(patch_scores, dim=1)
110
- predictions = torch.argmax(softmax_scores, dim=1)
111
 
112
- # Majority voting on patches
113
- if sum(predictions) > len(predictions) // 2:
114
- majority_voting = 1
115
- else:
116
- majority_voting = 0
 
 
117
 
118
- # get an output score from softmax scores:
119
- # LLR < 0: real
120
- # LLR > 0: synthetic
121
 
122
- sign_predictions = majority_voting * 2 - 1
123
- # select only the scores associated with the estimated class (by majority voting)
124
- softmax_scores = softmax_scores[:, majority_voting]
125
- normalized_prediction = torch.max(softmax_scores).item() * sign_predictions
 
126
 
127
- return normalized_prediction
 
 
 
 
 
 
 
128
 
129
 
130
  def main():
131
- # img_path
132
- img_path = "/nas/public/exchange/semafor/eval1/stylegan2/100k-generated-images/car-512x384_cropped/stylegan2-" \
133
- "config-f-psi-0.5/097000/097001.png"
134
 
135
- # number of random patches to extract from images
136
- test_N = 50
137
 
138
  detector = Detector()
139
- detector.synth_real_detector(img_path, test_N)
 
 
140
 
141
  return 0
142
 
 
11
  import albumentations as A
12
  import albumentations.pytorch as Ap
13
  from utils import architectures
14
+ from utils.python_patch_extractor.PatchExtractor import PatchExtractor
15
  from PIL import Image
16
 
17
 
18
  class Detector:
19
  def __init__(self):
20
 
21
+ self.weights_path_list = [os.path.join('/nas/home/nbonettini/projects/StyleGAN3-detection/weights', f'method_{x}.pth') for x in 'ABCDE']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # GPU configuration if available
24
  self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
 
47
  Ap.transforms.ToTensorV2()
48
  ]
49
  self.trans = A.Compose(transform)
 
50
  self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
 
51
  self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
52
 
53
+ def synth_real_detector(self, img_path: str, n_patch: int = 200):
54
 
55
  # Load image:
56
  img = np.asarray(Image.open(img_path))
57
 
58
+ # Opt-out if image is non conforming
59
  if img.shape == ():
60
  print('{} None dimension'.format(img_path))
61
  return None
 
69
  print('Omitting alpha channel')
70
  img = img[:, :, :3]
71
 
72
+ img_net_scores = []
73
+ for net_idx, net in enumerate(self.nets):
74
 
75
+ if net_idx == 0:
 
76
 
77
+ # only for detector A, extract N = 200 random patches per image
78
+ patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
79
+
80
+ else:
 
 
81
 
82
+ # for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid:
83
+ # we want more or less 200 patches per img
84
+ stride_0 = ((((img.shape[0] - 128) // 20) + 7) // 8) * 8
85
+ stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
86
+ pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
87
+ patches = pe.extract(img)
88
+ patch_list = list(patches.reshape((patches.shape[0]*patches.shape[1], 128, 128, 3)))
89
 
90
+ # Normalization
91
+ transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
 
92
 
93
+ # Compute scores
94
+ transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
95
+ with torch.no_grad():
96
+ patch_scores = net(transf_patch_tensor).cpu().numpy()
97
+ patch_predictions = np.argmax(patch_scores, axis=1)
98
 
99
+ maj_voting = np.any(patch_predictions).astype(int)
100
+ scores_maj_voting = patch_scores[:, maj_voting]
101
+ img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))
102
+
103
+ # final score is the average among the 5 scores returned by the detectors
104
+ img_score = np.mean(img_net_scores)
105
+
106
+ return img_score
107
 
108
 
109
  def main():
 
 
 
110
 
111
+ # img_path on fermi:
112
+ img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
113
 
114
  detector = Detector()
115
+ score = detector.synth_real_detector(img_path)
116
+
117
+ print('Image Score: {}'.format(score))
118
 
119
  return 0
120
 
utils/architectures.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from efficientnet_pytorch import EfficientNet
7
+ from efficientnet_pytorch.utils import (
8
+ round_filters,
9
+ round_repeats,
10
+ drop_connect,
11
+ get_same_padding_conv2d,
12
+ get_model_params,
13
+ efficientnet_params,
14
+ load_pretrained_weights,
15
+ Swish,
16
+ MemoryEfficientSwish,
17
+ )
18
+ from efficientnet_pytorch.model import MBConvBlock
19
+ from torchvision.models import resnet
20
+ from pytorchcv.model_provider import get_model
21
+
22
+
23
class Head(nn.Module):
    """
    Fully-connected classification head.

    Pipeline: Flatten -> BatchNorm1d(in_f) -> Dropout -> Linear(in_f, 512) -> ReLU
    -> BatchNorm1d(512) -> Dropout -> Linear(512, out_f).
    """

    def __init__(self, in_f, out_f):
        """
        :param in_f: number of features per sample once the input is flattened
        :param out_f: number of output scores per sample
        """
        super().__init__()

        # NOTE: attribute names and creation order define the state_dict keys; keep them stable
        self.f = nn.Flatten()
        self.l = nn.Linear(in_f, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, out_f)
        self.b1 = nn.BatchNorm1d(in_f)
        self.b2 = nn.BatchNorm1d(512)
        self.r = nn.ReLU()

    def forward(self, x):
        # input stage: flatten, normalize, drop (the same Dropout module is used twice)
        feats = self.d(self.b1(self.f(x)))
        # hidden stage: project to 512, ReLU, normalize, drop
        hidden = self.d(self.b2(self.r(self.l(feats))))
        return self.o(hidden)
47
+
48
+
49
class FCN(nn.Module):
    """
    Wraps a feature-extraction module (`base`) and feeds its output to a `Head` classifier.
    """

    def __init__(self, base, in_f, out_f):
        """
        :param base: module producing the features consumed by the head
        :param in_f: number of flattened features produced by `base` per sample
        :param out_f: number of output scores per sample
        """
        super().__init__()
        self.base = base
        self.h1 = Head(in_f, out_f)

    def forward(self, x):
        features = self.base(x)
        return self.h1(features)
58
+
59
+
60
class BaseFCN(nn.Module):
    """
    Two-layer linear classifier on a 625-feature flattened input:
    Flatten -> Linear(625, 256) -> Dropout -> Linear(256, n_classes).
    There is no non-linearity between the two linear layers.
    """

    def __init__(self, n_classes: int):
        """
        :param n_classes: number of output scores per sample
        """
        super().__init__()

        self.f = nn.Flatten()
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        hidden = self.d(self.l(self.f(x)))
        return self.o(hidden)

    def get_trainable_parameters_cooccur(self):
        """Return an iterator over all the parameters of this module."""
        return self.parameters()
78
+
79
+
80
class BaseFCNHigh(nn.Module):
    """
    Same layout as a two-layer linear classifier on 625 flattened features, with a wider
    hidden layer: Flatten -> Linear(625, 512) -> Dropout -> Linear(512, n_classes).
    """

    def __init__(self, n_classes: int):
        """
        :param n_classes: number of output scores per sample
        """
        super().__init__()

        self.f = nn.Flatten()
        self.l = nn.Linear(625, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, n_classes)

    def forward(self, x):
        hidden = self.d(self.l(self.f(x)))
        return self.o(hidden)

    def get_trainable_parameters_cooccur(self):
        """Return an iterator over all the parameters of this module."""
        return self.parameters()
98
+
99
+
100
class BaseFCN4(nn.Module):
    """
    Four-layer linear classifier on 625 flattened features:
    Flatten -> Linear(625, 512) -> Dropout -> Linear(512, 384) -> Dropout
    -> Linear(384, 256) -> Dropout -> Linear(256, n_classes).
    No non-linearity is applied between the linear layers.
    """

    def __init__(self, n_classes: int):
        """
        :param n_classes: number of output scores per sample
        """
        super().__init__()

        self.f = nn.Flatten()
        self.l1 = nn.Linear(625, 512)
        self.l2 = nn.Linear(512, 384)
        self.l3 = nn.Linear(384, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        x = self.f(x)
        # each hidden projection is followed by the shared dropout module
        for hidden_layer in (self.l1, self.l2, self.l3):
            x = self.d(hidden_layer(x))
        return self.o(x)

    def get_trainable_parameters_cooccur(self):
        """Return an iterator over all the parameters of this module."""
        return self.parameters()
124
+
125
+
126
class BaseFCNBnR(nn.Module):
    """
    Two-layer classifier on 625 flattened features with batch normalization and ReLU:
    Flatten -> BatchNorm1d(625) -> Dropout -> Linear(625, 256) -> ReLU
    -> BatchNorm1d(256) -> Dropout -> Linear(256, n_classes).
    """

    def __init__(self, n_classes: int):
        """
        :param n_classes: number of output scores per sample
        """
        super().__init__()

        # NOTE: attribute names and creation order define the state_dict keys; keep them stable
        self.f = nn.Flatten()
        self.b1 = nn.BatchNorm1d(625)
        self.b2 = nn.BatchNorm1d(256)
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)
        self.r = nn.ReLU()

    def forward(self, x):
        # input stage: flatten, normalize, drop
        feats = self.d(self.b1(self.f(x)))
        # hidden stage: project to 256, ReLU, normalize, drop
        hidden = self.d(self.b2(self.r(self.l(feats))))
        return self.o(hidden)

    def get_trainable_parameters_cooccur(self):
        """Return an iterator over all the parameters of this module."""
        return self.parameters()
151
+
152
+
153
def forward_resnet_conv(net, x, upto: int = 4):
    """
    Forward ResNet only in its convolutional part
    :param net: object exposing conv1, bn1, relu, maxpool and layer1..layer4 callables
    :param x: input passed to net.conv1
    :param upto: index (1..4) of the last residual stage to run; values below 1 run the stem only
    :return: output of the last stage that was run
    """
    # stem: conv1 brings the resolution to N / 2, maxpool to N / 4
    for stem_name in ('conv1', 'bn1', 'relu', 'maxpool'):
        x = getattr(net, stem_name)(x)

    # residual stages: layer1 -> N / 4, layer2 -> N / 8, layer3 -> N / 16, layer4 -> N / 32
    for stage_idx in (1, 2, 3, 4):
        if upto >= stage_idx:
            x = getattr(net, 'layer{}'.format(stage_idx))(x)
    return x
175
+
176
+
177
class FeatureExtractor(nn.Module):
    """
    Base class for networks that expose a feature-extraction stage.
    Subclasses must override `features`; the base class also provides the trainable
    parameters and the input normalizer for 3-channel inputs.
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the features of `x`; must be implemented by subclasses."""
        raise NotImplementedError

    def get_trainable_parameters(self):
        """Return an iterator over all the parameters of this module."""
        return self.parameters()

    @staticmethod
    def get_normalizer():
        """Return the per-channel normalization transform (three channels)."""
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        return transforms.Normalize(mean=channel_mean, std=channel_std)
192
+
193
+
194
class FeatureExtractorGray(nn.Module):
    """
    Base class for networks that expose a feature-extraction stage on single-channel inputs.
    Subclasses must override `features`; the base class also provides the trainable
    parameters and the input normalizer for 1-channel inputs.
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the features of `x`; must be implemented by subclasses."""
        raise NotImplementedError

    def get_trainable_parameters(self):
        """Return an iterator over all the parameters of this module."""
        return self.parameters()

    @staticmethod
    def get_normalizer():
        """Return the normalization transform (one channel)."""
        channel_mean = [0.479]
        channel_std = [0.226]
        return transforms.Normalize(mean=channel_mean, std=channel_std)
209
+
210
+
211
class EfficientNetGen(FeatureExtractor):
    """
    EfficientNet backbone whose library `_fc` layer is removed and replaced by an own
    linear `classifier`.
    """

    def __init__(self, model: str, n_classes: int, pretrained: bool):
        """
        :param model: EfficientNet variant name (e.g. 'efficientnet-b0')
        :param n_classes: number of output scores per sample
        :param pretrained: build with EfficientNet.from_pretrained if True, else EfficientNet.from_name
        """
        super().__init__()

        builder = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
        self.efficientnet = builder(model)

        head_channels = self.efficientnet._conv_head.out_channels
        self.classifier = nn.Linear(head_channels, n_classes)
        # the backbone's own fully-connected layer is not used: `classifier` takes its place
        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled backbone features, flattened to one vector per sample."""
        feature_maps = self.efficientnet.extract_features(x)
        pooled = self.efficientnet._avg_pooling(feature_maps)
        return pooled.flatten(start_dim=1)

    def forward(self, x):
        embedding = self.efficientnet._dropout(self.features(x))
        # raw classifier scores are returned: no softmax is applied here
        return self.classifier(embedding)
235
+
236
+
237
class EfficientNetB0(EfficientNetGen):
    """EfficientNetGen fixed to the 'efficientnet-b0' variant."""

    def __init__(self, n_classes: int, pretrained: bool):
        super().__init__(model='efficientnet-b0', n_classes=n_classes, pretrained=pretrained)
240
+
241
+
242
class EfficientNetB4(EfficientNetGen):
    """EfficientNetGen fixed to the 'efficientnet-b4' variant."""

    def __init__(self, n_classes: int, pretrained: bool):
        super().__init__(model='efficientnet-b4', n_classes=n_classes, pretrained=pretrained)
245
+
246
+
247
class EfficientNetGenPostStem(FeatureExtractor):
    """
    EfficientNet variant with a stride-1 stem followed by extra MBConv blocks placed
    between the stem and the standard EfficientNet blocks.
    """

    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        """
        :param model: EfficientNet variant name (e.g. 'efficientnet-b0')
        :param n_classes: number of output scores per sample
        :param pretrained: build with EfficientNet.from_pretrained if True, else EfficientNet.from_name
        :param n_ir_blocks: number of extra block applications after the stem
                            (n_ir_blocks - 1 passes through `init_block`, then one through `last_block`)
        """
        super().__init__()

        builder = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
        self.efficientnet = builder(model)

        self.n_ir_blocks = n_ir_blocks
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)

        global_params = self.efficientnet._global_params
        first_block_args = self.efficientnet._blocks_args[0]

        # modify STEM: same 3x3 convolution on rgb input, but with stride 1
        stem_out_channels = round_filters(32, global_params)
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(3, stem_out_channels, kernel_size=3, stride=1, bias=False)

        # extra blocks, derived from the arguments of the first standard block
        self.init_blocks_args = first_block_args._replace(output_filters=32)
        self.init_block = MBConvBlock(self.init_blocks_args, global_params)

        self.last_block_args = first_block_args._replace(output_filters=32, stride=2)
        self.last_block = MBConvBlock(self.last_block_args, global_params)

        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled features, flattened to one vector per sample."""
        net = self.efficientnet

        x = net._swish(net._bn0(net._conv_stem(x)))

        # init blocks: the same block module is applied n_ir_blocks - 1 times
        for _ in range(self.n_ir_blocks - 1):
            x = self.init_block(x, drop_connect_rate=0)

        # last block
        x = self.last_block(x, drop_connect_rate=0)

        # standard blocks efficientNet: the drop-connect rate is scaled with the block position
        base_rate = net._global_params.drop_connect_rate
        n_blocks = len(net._blocks)
        for idx, block in enumerate(net._blocks):
            rate = base_rate
            if rate:
                rate = rate * (float(idx) / n_blocks)
            x = block(x, drop_connect_rate=rate)

        x = net._swish(net._bn1(net._conv_head(x)))

        pooled = net._avg_pooling(x)
        return pooled.flatten(start_dim=1)

    def forward(self, x):
        embedding = self.efficientnet._dropout(self.features(x))
        # raw classifier scores are returned: no softmax is applied here
        return self.classifier(embedding)
306
+
307
+
308
class EfficientNetB0PostStemIR(EfficientNetGenPostStem):
    """EfficientNetGenPostStem fixed to the 'efficientnet-b0' variant."""

    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super().__init__(
            model='efficientnet-b0',
            n_classes=n_classes,
            pretrained=pretrained,
            n_ir_blocks=n_ir_blocks,
        )
312
+
313
+
314
class EfficientNetGenPreStem(FeatureExtractor):
    """
    EfficientNet variant with extra MBConv blocks placed before the stem; the stem is
    rebuilt to accept the 32-channel output of those blocks.
    """

    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        """
        :param model: EfficientNet variant name (e.g. 'efficientnet-b0')
        :param n_classes: number of output scores per sample
        :param pretrained: build with EfficientNet.from_pretrained if True, else EfficientNet.from_name
        :param n_ir_blocks: number of extra block applications before the stem
                            (one pass through `init_block`, then n_ir_blocks - 1 through `last_block`)
        """
        super().__init__()

        builder = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
        self.efficientnet = builder(model)

        self.n_ir_blocks = n_ir_blocks
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)

        global_params = self.efficientnet._global_params
        first_block_args = self.efficientnet._blocks_args[0]

        # extra blocks, derived from the arguments of the first standard block;
        # the first one takes the 3-channel input directly
        self.init_block_args = first_block_args._replace(input_filters=3, output_filters=32)
        self.init_block = MBConvBlock(self.init_block_args, global_params)

        self.last_blocks_args = first_block_args._replace(output_filters=32)
        self.last_block = MBConvBlock(self.last_blocks_args, global_params)

        # modify STEM: it now receives the 32 channels produced by the extra blocks
        stem_out_channels = round_filters(32, global_params)
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(32, stem_out_channels, kernel_size=3, stride=2, bias=False)

        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled features, flattened to one vector per sample."""
        net = self.efficientnet

        # init block
        x = self.init_block(x, drop_connect_rate=0)

        # other blocks: the same block module is applied n_ir_blocks - 1 times
        for _ in range(self.n_ir_blocks - 1):
            x = self.last_block(x, drop_connect_rate=0)

        # standard stem efficientNet:
        x = net._swish(net._bn0(net._conv_stem(x)))

        # standard blocks efficientNet: the drop-connect rate is scaled with the block position
        base_rate = net._global_params.drop_connect_rate
        n_blocks = len(net._blocks)
        for idx, block in enumerate(net._blocks):
            rate = base_rate
            if rate:
                rate = rate * (float(idx) / n_blocks)
            x = block(x, drop_connect_rate=rate)

        x = net._swish(net._bn1(net._conv_head(x)))

        pooled = net._avg_pooling(x)
        return pooled.flatten(start_dim=1)

    def forward(self, x):
        embedding = self.efficientnet._dropout(self.features(x))
        # raw classifier scores are returned: no softmax is applied here
        return self.classifier(embedding)
373
+
374
+
375
class EfficientNetB0PreStemIR(EfficientNetGenPreStem):
    """EfficientNetGenPreStem fixed to the 'efficientnet-b0' variant."""

    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super().__init__(
            model='efficientnet-b0',
            n_classes=n_classes,
            pretrained=pretrained,
            n_ir_blocks=n_ir_blocks,
        )
379
+
380
+
381
class ResNet50(FeatureExtractor):
    """
    ResNet-50 backbone whose original `fc` layer is removed and replaced by an own
    linear layer producing `n_classes` scores.
    """

    def __init__(self, n_classes: int, pretrained: bool):
        """
        :param n_classes: number of output scores per sample
        :param pretrained: forwarded to torchvision's resnet50 constructor
        """
        super().__init__()
        self.resnet = resnet.resnet50(pretrained=pretrained)
        backbone_features = self.resnet.fc.in_features
        self.fc = nn.Linear(in_features=backbone_features, out_features=n_classes)
        # the backbone's own classifier is not used: `self.fc` takes its place
        del self.resnet.fc

    def features(self, x):
        """Return the average-pooled convolutional features, flattened to one vector per sample."""
        conv_maps = forward_resnet_conv(self.resnet, x)
        return self.resnet.avgpool(conv_maps).flatten(start_dim=1)

    def forward(self, x):
        return self.fc(self.features(x))
397
+
398
+
399
+ """
400
+ Xception from Kaggle
401
+ """
402
+
403
+
404
class XceptionWeiHao(FeatureExtractor):
    """
    Xception backbone (from pytorchcv) without its original output layer, followed by
    a `Head` classifier through the `FCN` wrapper.
    """

    def __init__(self, n_classes: int, pretrained: bool):
        """
        :param n_classes: number of output scores per sample
        :param pretrained: forwarded to pytorchcv's get_model
        """
        super().__init__()

        full_model = get_model("xception", pretrained=pretrained)
        # Remove original output layer
        backbone = nn.Sequential(*list(full_model.children())[:-1])
        # pool the final block to a 1 x 1 map whatever the spatial size
        backbone[0].final_block.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)))
        self.model = FCN(backbone, 2048, n_classes)

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output, before the classification head."""
        return self.model.base(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        backbone_out = self.features(x)
        return self.model.h1(backbone_out)
420
+
421
+
422
+
utils/python_patch_extractor/PatchExtractor.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Author: Nicolo' Bonettini
3
+ @Author: Luca Bondi
4
+ @Author: Francesco Picetti
5
+ """
6
+ import random
7
+ import numpy as np
8
+ from skimage.util import view_as_windows, view_as_blocks
9
+
10
+
11
+ # Score functions ---
12
+
13
def mid_intensity_high_texture(in_content):
    """
    Empirical patch quality score: highest for patches with mid-range mean intensity and
    a high standard deviation (texture).
    :type in_content: ndarray
    :param in_content : 2D or 3D ndarray. Values are expected in [0,1] if in_content is float, in [0,255] if in_content is uint8
    :return score: float
        score in [0,1].
    """
    # bring uint8 content to the [0, 1] range
    values = in_content / 255. if in_content.dtype == np.uint8 else in_content
    values = values.flatten()

    mean_val = values.mean()
    std_val = values.std()

    # parabola with value 0 at mean 0 and 1, and maximum 1 at mean 0.5
    ch_mean_score = -4 * mean_val ** 2 + 4 * mean_val
    # saturating score: 0 for flat content, growing towards 1 with the standard deviation
    ch_std_score = 1 - np.exp(-2 * np.log(10) * std_val)

    # weighted sum of the two scores (mean score weighs .7)
    mean_std_weight = .7
    return mean_std_weight * ch_mean_score + (1 - mean_std_weight) * ch_std_score
37
+
38
+
39
def count_patches(in_size, patch_size, patch_stride):
    """
    Compute the number of patches
    :param in_size: size of the content along each axis
    :param patch_size: size of a patch along each axis
    :param patch_stride: step between consecutive patches along each axis
    :return: total number of patches, as an int
    """
    size, patch, stride = (np.array(v) for v in (in_size, patch_size, patch_stride))
    # number of patch positions along each axis
    positions_per_axis = (size - patch) // stride + 1
    return int(np.prod(positions_per_axis))
50
+
51
+
52
class PatchExtractor:
    """
    N-dimensional patch extractor.

    An instance is configured once with the patch geometry and selection options; `extract`
    then splits an ndarray into patches and `reconstruct` rebuilds the (cropped) content
    from the full patch grid returned by the last `extract` call.
    """

    def __init__(self, dim, offset=None, stride=None, rand=None, function=None, threshold=None,
                 num=None, indexes=None):

        """
        N-dimensional patch extractor
        Args:
        :param dim : tuple
            patch_array dimensions as a tuple of ndim elements

        Named args:
        :param offset : tuple
            the offsets along each axis as a tuple of ndim elements (default: all zeros)

        :param stride : tuple
            the stride of each axis as a tuple of ndim elements (default: dim, i.e. no overlap)

        :param rand : bool
            randomize patch_array order. Mutually exclusive with function

        :param function : function
            patch quality function handler. Mutually exclusive with rand

        :param threshold: float
            minimum quality threshold (default: 0.0)

        :param num : int
            maximum number of returned patch_array. Mutually exclusive with indexes

        :param indexes : list|ndarray
            explicitly return corresponding patch indexes (function_handler or C order used to index patch_array).
            Mutually exclusive with num

        :raises ValueError: if an argument has the wrong type or length, or if mutually
            exclusive arguments are set together
        """

        # Arguments parser ---
        if not isinstance(dim, tuple):
            raise ValueError('dim must be a tuple')
        self.dim = dim

        ndim = len(dim)
        self.ndim = ndim

        if offset is None:
            offset = tuple([0] * ndim)
        if not isinstance(offset, tuple):
            raise ValueError('offset must be a tuple')
        if len(offset) != ndim:
            raise ValueError('offset must a tuple of length {:d}'.format(ndim))
        self.offset = offset

        if stride is None:
            stride = dim
        if not isinstance(stride, tuple):
            raise ValueError('stride must be a tuple')
        if len(stride) != ndim:
            raise ValueError('stride must a tuple of length {:d}'.format(ndim))
        self.stride = stride

        if rand is not None and function is not None:
            raise ValueError('rand and function cannot be set at the same time')

        if rand is None:
            rand = False
        if not isinstance(rand, bool):
            raise ValueError('rand must be a boolean')
        self.rand = rand

        if function is not None and not callable(function):
            raise ValueError('function must be a function handler')
        self.function_handler = function

        if threshold is None:
            threshold = 0.0
        if not isinstance(threshold, float):
            raise ValueError('threshold must be a float')
        self.threshold = threshold

        if num is not None and indexes is not None:
            raise ValueError('num and indexes cannot be set at the same time')

        if num is not None and not isinstance(num, int):
            raise ValueError('num must be an int')
        self.num = num

        if indexes is not None and not isinstance(indexes, list) and not isinstance(indexes, np.ndarray):
            raise ValueError('indexes must be an list or a 1d ndarray')
        if indexes is not None:
            indexes = np.array(indexes).flatten()
        self.indexes = indexes

        # filled in by extract(); reconstruct() relies on the cropped shape
        self.in_content_original_shape = None
        self.in_content_cropped_shape = None

    @staticmethod
    def _shuffle_patches(patch_array):
        """Return the patches of a flat (num_patches, ...) array in random order."""
        # random.shuffle() applied directly to a multi-dimensional ndarray swaps rows through
        # views, which overwrites some patches with copies of others. Shuffling plain indices
        # and gathering keeps every patch exactly once (and still honours random.seed()).
        order = list(range(patch_array.shape[0]))
        random.shuffle(order)
        return patch_array[order]

    def extract(self, in_content):
        """
        Extract patches from in_content according to the configuration of this instance
        :param in_content : ndarray
            the content to process as a numpy array of ndim dimensions
        :return ndarray: patch_array
            array of patch_array
            if rand==False and function_handler==None and num==None and indexes==None:
                patch_array.ndim = 2 * in_content.ndim
            else:
                patch_array.ndim = 1 + in_content.ndim
        :raises ValueError: if in_content is not an ndarray with ndim dimensions
        """

        if not isinstance(in_content, np.ndarray):
            raise ValueError('in_content must be of type: ' + str(np.ndarray))

        if in_content.ndim != self.ndim:
            raise ValueError('in_content shape must a tuple of length {:d}'.format(self.ndim))

        self.in_content_original_shape = in_content.shape

        # Offset ---
        for dim_idx, dim_offset in enumerate(self.offset):
            dim_max = in_content.shape[dim_idx]
            in_content = in_content.take(range(dim_offset, dim_max), axis=dim_idx)

        # Patch list ---
        if self.dim == self.stride:
            # non-overlapping grid: crop each axis to a multiple of the patch size first
            in_content_crop = in_content
            for dim_idx in range(self.ndim):
                dim_max = (in_content.shape[dim_idx] // self.dim[dim_idx]) * self.dim[dim_idx]
                in_content_crop = in_content_crop.take(range(0, dim_max), axis=dim_idx)
            patch_array = view_as_blocks(in_content_crop, self.dim)
        else:
            patch_array = view_as_windows(in_content, self.dim, self.stride)

        patch_array = np.ascontiguousarray(patch_array)

        # shape of the content region actually covered by the patch grid
        patch_idx = patch_array.shape[:self.ndim]
        self.in_content_cropped_shape = tuple((np.asarray(patch_idx) - 1) * np.asarray(self.stride) + np.asarray(self.dim))

        flat_shape = (-1,) + self.dim

        # Evaluate patch_array or rand sort ---
        if self.rand:
            patch_array = self._shuffle_patches(patch_array.reshape(flat_shape))
        elif self.function_handler is not None:
            patch_array = patch_array.reshape(flat_shape)
            patch_scores = np.asarray(list(map(self.function_handler, patch_array)))
            # best patches first, then drop those below the quality threshold
            sort_idxs = np.argsort(patch_scores)[::-1]
            patch_scores = patch_scores[sort_idxs]
            patch_array = patch_array[sort_idxs]
            patch_array = patch_array[patch_scores >= self.threshold]

        if self.num is not None:
            patch_array = patch_array.reshape(flat_shape)[:self.num]

        if self.indexes is not None:
            patch_array = patch_array.reshape(flat_shape)[self.indexes]

        return patch_array

    def extract_call(self, args):  # TODO: verify
        """
        Dictionary-based entry point to `extract`
        :param args: dict with the keys 'in_content' and 'dim'. Both keys are removed from
            the dict; only 'in_content' is used (the patch size is the one of this instance)
        :return: same as extract(in_content)
        """
        in_content = args.pop('in_content')
        args.pop('dim')

        return self.extract(in_content)

    def reconstruct(self, patch_array):
        """
        Reconstruct the N-dim image from the patch_array that has been extracted previously
        :param patch_array: array of patches as output of patch_extractor, in its full grid
            form (patch_array.ndim = 2 * image ndim)
        :return: float ndarray with the shape of the cropped content; overlapping patches
            are averaged
        :raises ValueError: if patch_array is not an ndarray or its shape does not match
            the geometry of the last extraction
        """
        # Arguments parser ---
        if not isinstance(patch_array, np.ndarray):
            raise ValueError('patch_array must be of type: ' + str(np.ndarray))

        ndim = patch_array.ndim // 2

        patch_stride = self.stride
        image_shape = self.in_content_cropped_shape

        patch_shape = patch_array.shape[-ndim:]
        patch_idx = patch_array.shape[:ndim]
        image_shape_computed = tuple((np.array(patch_idx) - 1) * np.array(patch_stride) + np.array(patch_shape))
        if not image_shape == image_shape_computed:
            raise ValueError('There is something wrong with the dimensions!')

        image_recon = np.zeros(image_shape)
        norm_mask = np.zeros(image_shape)

        # Visit the patch grid in C order, for any number of dimensions: the patch at grid
        # position g starts at g * stride along each axis.
        for grid_pos in np.ndindex(*patch_idx):
            window = tuple(slice(g * s, g * s + p)
                           for g, s, p in zip(grid_pos, patch_stride, patch_shape))
            image_recon[window] += patch_array[grid_pos]
            norm_mask[window] += 1

        # average the samples covered by more than one patch
        image_recon /= norm_mask

        return image_recon
283
+
284
+
285
def main():
    """Smoke test: extract patches from a random uint8 image and rebuild it from them."""
    in_shape = (644, 481, 3)
    dim = (120, 120, 3)
    in_content = np.random.randint(256, size=in_shape).astype(np.uint8)

    # only the patch size is given: offset and stride take their defaults
    pe = PatchExtractor(dim)
    patch_array = pe.extract(in_content)
    print('patch_array.shape = ' + str(patch_array.shape))
    img_recon = pe.reconstruct(patch_array)
    print('img_recon.shape = ' + str(img_recon.shape))
303
+
304
+
305
+ if __name__ == "__main__":
306
+ main()
utils/python_patch_extractor/__init__.py ADDED
File without changes