Luis Oala committed
Commit d9c7582 · 0 Parent(s)

fix aws access

.gitattributes ADDED
@@ -0,0 +1 @@
1
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
ABtesting.py ADDED
@@ -0,0 +1,806 @@
1
+ import os
2
+ import argparse
3
+ import json
5
+
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from torchvision.transforms import Compose, Normalize
9
+ import torch.nn.functional as F
10
+
11
+ from utils.dataset import get_dataset, Subset
12
+ from utils.base import get_mlflow_model_by_name, SmartFormatter
13
+ from processingpipeline.pipeline import RawProcessingPipeline
14
+
15
+ from utils.Cperturb import Distortions
16
+
17
+ import segmentation_models_pytorch as smp
18
+
19
+ import matplotlib.pyplot as plt
20
+
21
+ parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
22
+
23
+ #Select experiment
24
+ parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
25
+ help='R|Choose operation to compute. \n'
26
+ 'A) Lens2Logit image generation: \n '
27
+ 'ABMakeTable: Compute cross-validation metrics results \n '
28
+ 'ABShowTable: Plot cross-validation results on a table \n '
29
+ 'ABShowImages: Choose a training and testing image to compare different pipelines \n '
30
+ 'ABShowAllImages: Plot all possible pipelines \n'
31
+ 'B) Hendrycks Perturbations, C-type dataset: \n '
32
+ 'CMakeTable: For each pipeline, it computes cross-validation metrics for different perturbations \n '
33
+ 'CShowTable: Plot metrics for different pipelines and perturbations \n '
34
+ 'CShowImages: Plot an image with a selected a pipeline and perturbation\n '
35
+ 'CShowAllImages: Plot all possible perturbations for a fixed pipeline' )
36
+
37
+ parser.add_argument("--dataset_name", type=str, default='Microscopy', choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
38
+ parser.add_argument("--augmentation", type=str, default='weak', choices=['none','weak','strong'], help='Choose augmentation')
39
+ parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
40
+ parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
41
+
42
+ #Select pipelines
43
+ parser.add_argument("--dm_train", type=str, default='bilinear', choices= ('bilinear', 'malvar2004', 'menon2007'), help='Choose demosaicing for training processing model')
44
+ parser.add_argument("--s_train", type=str, default='sharpening_filter', choices= ('sharpening_filter', 'unsharp_masking'), help='Choose sharpening for training processing model')
45
+ parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices= ('gaussian_denoising', 'median_denoising'), help='Choose denoising for training processing model')
46
+ parser.add_argument("--dm_test", type=str, default='bilinear', choices= ('bilinear', 'malvar2004', 'menon2007'), help='Choose demosaicing for testing processing model')
47
+ parser.add_argument("--s_test", type=str, default='sharpening_filter', choices= ('sharpening_filter', 'unsharp_masking'), help='Choose sharpening for testing processing model')
48
+ parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices= ('gaussian_denoising', 'median_denoising'), help='Choose denoising for testing processing model')
49
+
50
+ #Select Ctest parameters
51
+ parser.add_argument("--transform", type=str, default='identity', choices= ('identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
52
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
53
+ parser.add_argument("--severity", type=int, default=1, choices= (1,2,3,4,5), help='Choose severity for Ctesting')
54
+
55
+ args = parser.parse_args()
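+
+ # Usage sketch (added for illustration, not part of the original script): example
+ # invocations with the flags defined above, assuming trained models are reachable
+ # through utils.base.get_mlflow_model_by_name on the MLflow server.
+ # python ABtesting.py --mode ABMakeTable --dataset_name Microscopy --augmentation weak --N_runs 5
+ # python ABtesting.py --mode ABShowImages --dataset_name Microscopy --dm_train bilinear --dm_test menon2007
+ # python ABtesting.py --mode CMakeTable --dataset_name Drone --severity 3 --download_model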
56
+
57
+ class metrics:
58
+ def __init__(self, confusion_matrix):
59
+ self.cm = confusion_matrix
60
+ self.N_classes = len(confusion_matrix)
61
+
62
+ def accuracy(self):
63
+ Tp = torch.diagonal(self.cm,0).sum()
64
+ N_elements = torch.sum(self.cm)
65
+ return Tp/N_elements
66
+
67
+ def precision(self):
68
+ Tp_Fp = torch.sum(self.cm, 0) # TP + FP: column sums, i.e. counts predicted per class (rows hold the true labels)
69
+ Tp_Fp[Tp_Fp == 0] = 1
70
+ return torch.diagonal(self.cm,0) / Tp_Fp
71
+
72
+ def recall(self):
73
+ Tp_Fn = torch.sum(self.cm, 1) # TP + FN: row sums, i.e. counts of the true labels per class
74
+ Tp_Fn[Tp_Fn == 0] = 1
75
+ return torch.diagonal(self.cm,0) / Tp_Fn
76
+
77
+ def f1_score(self):
78
+ prod = (self.precision()*self.recall())
79
+ sum = (self.precision() + self.recall())
80
+ sum[sum == 0.] = 1.
81
+ return 2*( prod / sum )
82
+
83
+ def over_N_runs(ms, N_runs):
84
+ m, m2 = 0, 0
85
+
86
+ for i in ms:
87
+ m += i
88
+ mu = m/N_runs
89
+
90
+ for i in ms:
91
+ m2 += (i-mu)**2
92
+
93
+ sigma = torch.sqrt( m2 / (N_runs-1) )
94
+
95
+ return mu.tolist(), sigma.tolist()
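+
+ # Illustrative sketch (not part of the original commit): with true labels on the rows
+ # and predictions on the columns, e.g. cm = torch.tensor([[50., 3.], [5., 42.]]),
+ # metrics(cm).accuracy() returns (50 + 42) / 100, and over_N_runs reduces per-fold
+ # scores to their mean and sample standard deviation sqrt(sum_i (x_i - mu)**2 / (N_runs - 1)).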
96
+
97
+ class ABtesting:
98
+ def __init__(self,
99
+ dataset_name: str,
100
+ augmentation: str,
101
+ dm_train: str,
102
+ s_train: str,
103
+ dn_train: str,
104
+ dm_test: str,
105
+ s_test: str,
106
+ dn_test: str,
107
+ N_runs: int,
108
+ severity=1,
109
+ transform='identity',
110
+ download_model=False):
111
+ self.experiment_name = 'ABtesting'
112
+ self.dataset_name = dataset_name
113
+ self.augmentation = augmentation
114
+ self.dm_train = dm_train
115
+ self.s_train = s_train
116
+ self.dn_train = dn_train
117
+ self.dm_test = dm_test
118
+ self.s_test = s_test
119
+ self.dn_test = dn_test
120
+ self.N_runs = N_runs
121
+ self.severity = severity
122
+ self.transform = transform
123
+ self.download_model = download_model
124
+
125
+ def static_pip_val(self, debayer=None, sharpening=None, denoising=None, severity=None, transform=None, plot_mode=False):
126
+
127
+ if debayer == None:
128
+ debayer = self.dm_test
129
+ if sharpening == None:
130
+ sharpening = self.s_test
131
+ if denoising == None:
132
+ denoising = self.dn_test
133
+ if severity == None:
134
+ severity = self.severity
135
+ if transform == None:
136
+ transform = self.transform
137
+
138
+ dataset = get_dataset(self.dataset_name)
139
+
140
+ if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
141
+ mean = torch.tensor([0.35, 0.36, 0.35])
142
+ std = torch.tensor([0.12, 0.11, 0.12])
143
+ elif self.dataset_name == "Microscopy":
144
+ mean = torch.tensor([0.91, 0.84, 0.94])
145
+ std = torch.tensor([0.08, 0.12, 0.05])
146
+
147
+ if not plot_mode:
148
+ dataset.transform = Compose([RawProcessingPipeline(
149
+ camera_parameters=dataset.camera_parameters,
150
+ debayer=debayer,
151
+ sharpening=sharpening,
152
+ denoising=denoising,
153
+ ), Distortions(severity=severity, transform=transform),
154
+ Normalize(mean, std)])
155
+ else:
156
+ dataset.transform = Compose([RawProcessingPipeline(
157
+ camera_parameters=dataset.camera_parameters,
158
+ debayer=debayer,
159
+ sharpening=sharpening,
160
+ denoising=denoising,
161
+ ), Distortions(severity=severity, transform=transform)])
162
+
163
+ return dataset
164
+
165
+ def ABclassification(self):
166
+
167
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
168
+
169
+ parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
170
+
171
+ print(f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
172
+ print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
173
+
174
+ accuracies, precisions, recalls, f1_scores = [],[],[],[]
175
+
176
+ os.system('rm -r /tmp/py*')
177
+
178
+ for N_run in range(self.N_runs):
179
+
180
+ print(f"Evaluating Run {N_run}")
181
+
182
+ run_name = parent_run_name+'_'+str(N_run)
183
+
184
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
185
+ download_model=self.download_model)
186
+
187
+ dataset = self.static_pip_val()
188
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
189
+ valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
190
+
191
+ model.eval()
192
+
193
+ len_classes = len(dataset.classes)
194
+ confusion_matrix = torch.zeros((len_classes, len_classes))
195
+
196
+ for img, label in valid_loader:
197
+
198
+ prediction = model(img.to(DEVICE)).detach().cpu()
199
+ prediction = torch.argmax(prediction, dim=1)
200
+ confusion_matrix[label,prediction] += 1 # Real value rows, Declared columns
201
+
202
+ m = metrics(confusion_matrix)
203
+
204
+ accuracies.append(m.accuracy())
205
+ precisions.append(m.precision())
206
+ recalls.append(m.recall())
207
+ f1_scores.append(m.f1_score())
208
+
209
+ os.system('rm -r /tmp/t*')
210
+
211
+ accuracy = metrics.over_N_runs(accuracies, self.N_runs)
212
+ precision = metrics.over_N_runs(precisions, self.N_runs)
213
+ recall = metrics.over_N_runs(recalls, self.N_runs)
214
+ f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
215
+ return dataset.classes, accuracy, precision, recall, f1_score
216
+
217
+ def ABsegmentation(self):
218
+
219
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
220
+
221
+ parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
222
+
223
+ print(f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
224
+ print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
225
+
226
+ IoUs = []
227
+
228
+ os.system('rm -r /tmp/py*')
229
+
230
+ for N_run in range(self.N_runs):
231
+
232
+ print(f"Evaluating Run {N_run}")
233
+
234
+ run_name = parent_run_name+'_'+str(N_run)
235
+
236
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
237
+ download_model=self.download_model)
238
+
239
+ dataset = self.static_pip_val()
240
+
241
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
242
+ valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
243
+
244
+ model.eval()
245
+
246
+ IoU=0
247
+
248
+ for img, label in valid_loader:
249
+
250
+ prediction = model(img.to(DEVICE)).detach().cpu()
251
+ prediction = F.logsigmoid(prediction).exp().squeeze()
252
+ IoU += smp.utils.metrics.IoU()(prediction,label)
253
+
254
+ IoU = IoU/len(valid_loader)
255
+ IoUs.append(IoU.item())
256
+
257
+ os.system('rm -r /tmp/t*')
258
+
259
+ IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
260
+ return IoU
261
+
262
+ def ABShowImages(self):
263
+
264
+ path = 'results/ABtesting/imgs/'
265
+ if not os.path.exists(path):
266
+ os.makedirs(path)
267
+
268
+ path = os.path.join(path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')
269
+
270
+ if not os.path.exists(path):
271
+ os.makedirs(path)
272
+
273
+ run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"+'_'+str(0)
274
+
275
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
276
+
277
+ model.augmentation = None
278
+
279
+ for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
280
+ [self.dm_test, self.s_test, self.dn_test, 'test_img']):
281
+
282
+ debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
283
+
284
+ dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
285
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
286
+
287
+ img, _ = next(iter(valid_set))
288
+
289
+ plt.figure()
290
+ plt.imshow(img.permute(1,2,0))
291
+ if img_type == 'train_img':
292
+ plt.title('Train Image')
293
+ plt.savefig(os.path.join(path, f'img_train.png'))
294
+ imgA = img
295
+ else:
296
+ plt.title('Test Image')
297
+ plt.savefig(os.path.join(path,f'img_test.png'))
298
+
299
+ for c, color in enumerate(['Red','Green','Blue']):
300
+ diff = torch.abs(imgA-img)
301
+ plt.figure()
302
+ # plt.imshow(diff.permute(1,2,0))
303
+ plt.imshow(diff[c,50:200,50:200], cmap=f'{color}s')
304
+ plt.title(f'|Train Image - Test Image| - {color}')
305
+ plt.colorbar()
306
+ plt.savefig(os.path.join(path, f'diff_{color}.png'))
307
+ plt.figure()
308
+ diff[diff == 0.]= 1e-5
309
+ # plt.imshow(torch.log(diff.permute(1,2,0)))
310
+ plt.imshow(torch.log(diff)[c])
311
+ plt.title(f'log(|Train Image - Test Image|) - {color}')
312
+ plt.colorbar()
313
+ plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
314
+
315
+ if self.dataset_name == 'DroneSegmentation':
316
+ plt.figure()
317
+ plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
318
+ if img_type == 'train_img':
319
+ plt.savefig(os.path.join(path, f'mask_train.png'))
320
+ else:
321
+ plt.savefig(os.path.join(path,f'mask_test.png'))
322
+
323
+ def ABShowAllImages(self):
324
+ if not os.path.exists('results/ABtesting'):
325
+ os.makedirs('results/ABtesting')
326
+
327
+ demosaicings=['bilinear','malvar2004', 'menon2007']
328
+ sharpenings=['sharpening_filter', 'unsharp_masking']
329
+ denoisings=['median_denoising', 'gaussian_denoising']
330
+
331
+ fig = plt.figure()
332
+ columns=4
333
+ rows=3
334
+
335
+ i=1
336
+
337
+ for dm in demosaicings:
338
+ for s in sharpenings:
339
+ for dn in denoisings:
340
+
341
+ dataset = self.static_pip_val(dm, s, dn, plot_mode=True) # use the pipeline of the current loop iteration, not the fixed test pipeline
343
+
344
+ img,_ = dataset[0]
345
+
346
+ fig.add_subplot(rows, columns, i)
347
+ plt.imshow(img.permute(1,2,0))
348
+ plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
349
+ plt.xticks([])
350
+ plt.yticks([])
351
+ plt.tight_layout()
352
+
353
+ i+=1
354
+
355
+ plt.show()
356
+ plt.savefig(f'results/ABtesting/ABpipelines.png')
357
+
358
+ def CShowImages(self):
359
+
360
+ path = 'results/Ctesting/imgs/'
361
+ if not os.path.exists(path):
362
+ os.makedirs(path)
363
+
364
+ run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}"+'_'+str(0)
365
+
366
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)
367
+
368
+ model.augmentation = None
369
+
370
+ dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test, self.severity, self.transform, plot_mode=True)
371
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
372
+
373
+ img, _ = next(iter(valid_set))
374
+
375
+ plt.figure()
376
+ plt.imshow(img.permute(1,2,0))
377
+ plt.savefig(os.path.join(path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}.png'))
378
+
379
+ def CShowAllImages(self):
380
+ if not os.path.exists('results/Cimages'):
381
+ os.makedirs('results/Cimages')
382
+
383
+ transforms = ['identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
384
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
385
+
386
+ for i,t in enumerate(transforms):
387
+
388
+ fig = plt.figure(figsize=(10,6))
389
+ columns = 5
390
+ rows = 1
391
+
392
+ for sev in range(1,6):
393
+
394
+ dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
395
+
396
+ img,_ = dataset[0]
397
+
398
+ fig.add_subplot(rows, columns, sev)
399
+ plt.imshow(img.permute(1,2,0))
400
+ plt.title(f'Severity: {sev}')
401
+ plt.xticks([])
402
+ plt.yticks([])
403
+ plt.tight_layout()
404
+
405
+ if '_' in t:
406
+ t=t.replace('_', ' ')
407
+ t=t[0].upper()+t[1:]
408
+
409
+ fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
410
+ plt.show()
411
+ plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
412
+
413
+ def ABMakeTable(dataset_name:str, augmentation: str,
414
+ N_runs: int, download_model: bool):
415
+
416
+ demosaicings=['bilinear','malvar2004', 'menon2007']
417
+ sharpenings=['sharpening_filter', 'unsharp_masking']
418
+ denoisings=['median_denoising', 'gaussian_denoising']
419
+
420
+ path='results/ABtesting/tables'
421
+ if not os.path.exists(path):
422
+ os.makedirs(path)
423
+
424
+ runs={}
425
+ i=0
426
+
427
+ for dm_train in demosaicings:
428
+ for s_train in sharpenings:
429
+ for dn_train in denoisings:
430
+ for dm_test in demosaicings:
431
+ for s_test in sharpenings:
432
+ for dn_test in denoisings:
433
+ train_pip = [dm_train, s_train, dn_train]
434
+ test_pip = [dm_test, s_test, dn_test]
435
+ runs[f'run{i}'] = {
436
+ 'dataset': dataset_name,
437
+ 'augmentation': augmentation,
438
+ 'train_pip': train_pip,
439
+ 'test_pip': test_pip,
440
+ 'N_runs': N_runs
441
+ }
442
+ ABclass = ABtesting(
443
+ dataset_name=dataset_name,
444
+ augmentation=augmentation,
445
+ dm_train = dm_train,
446
+ s_train = s_train,
447
+ dn_train = dn_train,
448
+ dm_test = dm_test,
449
+ s_test = s_test,
450
+ dn_test = dn_test,
451
+ N_runs=N_runs,
452
+ download_model=download_model
453
+ )
454
+
455
+ if dataset_name == 'DroneSegmentation':
456
+ IoU = ABclass.ABsegmentation()
457
+ runs[f'run{i}']['IoU'] = IoU
458
+ else:
459
+ classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
460
+ runs[f'run{i}']['classes'] = classes
461
+ runs[f'run{i}']['accuracy'] = accuracy
462
+ runs[f'run{i}']['precision'] = precision
463
+ runs[f'run{i}']['recall'] = recall
464
+ runs[f'run{i}']['f1_score'] = f1_score
465
+
466
+ with open(os.path.join(path,f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
467
+ json.dump(runs, outfile)
468
+
469
+ i+=1
470
+
471
+ def ABShowTable(dataset_name: str, augmentation: str):
472
+
473
+ path='results/ABtesting/tables'
474
+ assert os.path.exists(path), 'No tables to plot'
475
+
476
+ json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
477
+
478
+ with open(json_file, 'r') as run_file:
479
+ runs = json.load(run_file)
480
+
481
+ metrics=torch.zeros((2,12,12))
482
+ classes=[]
483
+
484
+ i,j=0,0
485
+
486
+ for r in range(len(runs)):
487
+
488
+ run = runs['run'+str(r)]
489
+ if dataset_name == 'DroneSegmentation':
490
+ acc = run['IoU']
491
+ else:
492
+ acc = run['accuracy']
493
+ if len(classes) < 12:
494
+ class_list = run['test_pip']
495
+ class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
496
+ classes.append(class_name)
497
+ mu,sigma = round(acc[0],4),round(acc[1],4)
498
+
499
+ metrics[0,j,i] = mu
500
+ metrics[1,j,i] = sigma
501
+
502
+ i+=1
503
+
504
+ if i == 12:
505
+ i=0
506
+ j+=1
507
+
508
+ differences = torch.zeros_like(metrics)
509
+
510
+ diag_mu = torch.diagonal(metrics[0],0)
511
+ diag_sigma = torch.diagonal(metrics[1],0)
512
+
513
+ for r in range(len(metrics[0])):
514
+ differences[0,r] = diag_mu[r] - metrics[0,r]
515
+ differences[1,r] = torch.sqrt(metrics[1,r]**2 + diag_sigma[r]**2)
516
+
517
+ # Plot with scatter
518
+
519
+ for i,img in enumerate([metrics, differences]):
520
+
521
+ x, y = torch.arange(12), torch.arange(12)
522
+ x, y = torch.meshgrid(x, y)
523
+
524
+ if i == 0:
525
+ vmin = max(0.65, round(img[0].min().item(),2))
526
+ vmax = round(img[0].max().item(),2)
527
+ step = 0.02
528
+ elif i == 1:
529
+ vmin = round(img[0].min().item(),2)
530
+ if augmentation == 'none':
531
+ vmax = min(0.15, round(img[0].max().item(),2))
532
+ if augmentation == 'weak':
533
+ vmax = min(0.08, round(img[0].max().item(),2))
534
+ if augmentation == 'strong':
535
+ vmax = min(0.05, round(img[0].max().item(),2))
536
+ step = 0.01
537
+
538
+ vmin = int(vmin/step)*step
539
+ vmax = int(vmax/step)*step
540
+
541
+ fig = plt.figure(figsize=(10,6.2))
542
+ ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
543
+ marker_size=350
544
+ plt.scatter(x, y, c=torch.rot90(img[1][x,y],-1,[0,1]), vmin = 0., vmax = img[1].max(), cmap='viridis', s=marker_size*2, marker='s')
545
+ ticks = torch.arange(0.,img[1].max(),0.03).tolist()
546
+ ticks = [round(tick,2) for tick in ticks]
547
+ cba = plt.colorbar(pad=0.06)
548
+ cba.set_ticks(ticks)
549
+ cba.ax.set_yticklabels(ticks)
550
+ # cmap = plt.cm.get_cmap('tab20c').reversed()
551
+ cmap = plt.cm.get_cmap('Reds')
552
+ plt.scatter(x,y, c=torch.rot90(img[0][x,y],-1,[0,1]), vmin = vmin, vmax = vmax, cmap=cmap, s=marker_size, marker='s')
553
+ ticks = torch.arange(vmin, vmax, step).tolist()
554
+ ticks = [round(tick,2) for tick in ticks]
555
+ if ticks[-1] != vmax:
556
+ ticks.append(vmax)
557
+ cbb = plt.colorbar(pad=0.06)
558
+ cbb.set_ticks(ticks)
559
+ if i == 0:
560
+ ticks[0] = f'<{str(ticks[0])}'
561
+ elif i == 1:
562
+ ticks[-1] = f'>{str(ticks[-1])}'
563
+ cbb.ax.set_yticklabels(ticks)
564
+ for x in range(12):
565
+ for y in range(12):
566
+ txt = round(torch.rot90(img[0],-1,[0,1])[x,y].item(),2)
567
+ if str(txt) == '-0.0':
568
+ txt = '0.00'
569
+ elif str(txt) == '0.0':
570
+ txt = '0.00'
571
+ elif len(str(txt)) == 3:
572
+ txt = str(txt)+'0'
573
+ else:
574
+ txt = str(txt)
575
+
576
+ plt.text(x-0.25,y-0.1,txt, color='black', fontsize='x-small')
577
+
578
+ ax.set_xticks(torch.linspace(0,11,12))
579
+ ax.set_xticklabels(classes)
580
+ ax.set_yticks(torch.linspace(0,11,12))
581
+ classes.reverse()
582
+ ax.set_yticklabels(classes)
583
+ classes.reverse()
584
+ plt.xticks(rotation = 45)
585
+ plt.yticks(rotation = 45)
586
+ cba.set_label('Standard Deviation')
587
+ plt.xlabel("Test pipelines")
588
+ plt.ylabel("Train pipelines")
589
+ plt.title(f'Dataset: {dataset_name}, Augmentation: {augmentation}')
590
+ if i == 0:
591
+ if dataset_name == 'DroneSegmentation':
592
+ cbb.set_label('IoU')
593
+ plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_IoU.png"))
594
+ else:
595
+ cbb.set_label('Accuracy')
596
+ plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_accuracies.png"))
597
+ elif i == 1:
598
+ if dataset_name == 'DroneSegmentation':
599
+ cbb.set_label('IoU_d-IoU')
600
+ else:
601
+ cbb.set_label('Accuracy_d - Accuracy')
602
+ plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_differences.png"))
603
+
604
+ def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
605
+
606
+ path='results/Ctesting/tables'
607
+ if not os.path.exists(path):
608
+ os.makedirs(path)
609
+
610
+ demosaicings=['bilinear','malvar2004', 'menon2007']
611
+ sharpenings=['sharpening_filter', 'unsharp_masking']
612
+ denoisings=['median_denoising', 'gaussian_denoising']
613
+
614
+ transformations = ['identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
615
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
616
+
617
+ runs={}
618
+ i=0
619
+
620
+ for dm in demosaicings:
621
+ for s in sharpenings:
622
+ for dn in denoisings:
623
+ for t in transformations:
624
+ pip = [dm,s,dn]
625
+ runs[f'run{i}'] = {
626
+ 'dataset': dataset_name,
627
+ 'augmentation': augmentation,
628
+ 'pipeline': pip,
629
+ 'N_runs': N_runs,
630
+ 'transform': t,
631
+ 'severity': severity,
632
+ }
633
+ ABclass = ABtesting(
634
+ dataset_name=dataset_name,
635
+ augmentation=augmentation,
636
+ dm_train = dm,
637
+ s_train = s,
638
+ dn_train = dn,
639
+ dm_test = dm,
640
+ s_test = s,
641
+ dn_test = dn,
642
+ severity=severity,
643
+ transform=t,
644
+ N_runs=N_runs,
645
+ download_model=download_model
646
+ )
647
+
648
+ if dataset_name == 'DroneSegmentation':
649
+ IoU = ABclass.ABsegmentation()
650
+ runs[f'run{i}']['IoU'] = IoU
651
+ else:
652
+ classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
653
+ runs[f'run{i}']['classes'] = classes
654
+ runs[f'run{i}']['accuracy'] = accuracy
655
+ runs[f'run{i}']['precision'] = precision
656
+ runs[f'run{i}']['recall'] = recall
657
+ runs[f'run{i}']['f1_score'] = f1_score
658
+
659
+ with open(os.path.join(path,f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
660
+ json.dump(runs, outfile)
661
+
662
+ i+=1
663
+
664
+ def CShowTable(dataset_name, augmentation):
665
+
666
+ path='results/Ctesting/tables'
667
+ assert os.path.exists(path), 'No tables to plot'
668
+
669
+ json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.json') # CMakeTable writes .json files
670
+
671
+ transforms = ['identity','gauss_noise', 'shot', 'impulse', 'speckle',
672
+ 'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']
673
+
674
+ pip = []
675
+
676
+ demosaicings=['bilinear','malvar2004', 'menon2007']
677
+ sharpenings=['sharpening_filter', 'unsharp_masking']
678
+ denoisings=['median_denoising', 'gaussian_denoising']
679
+
680
+ for dm in demosaicings:
681
+ for s in sharpenings:
682
+ for dn in denoisings:
683
+ pip.append(f'{dm[:2]},{s[0]},{dn[:2]}')
684
+
685
+ with open(json_file, 'r') as run_file:
686
+ runs = json.load(run_file)
687
+
688
+ metrics=torch.zeros((2,len(pip),len(transforms)))
689
+
690
+ i,j=0,0
691
+
692
+ for r in range(len(runs)):
693
+
694
+ run = runs['run'+str(r)]
695
+ if dataset_name == 'DroneSegmentation':
696
+ acc = run['IoU']
697
+ else:
698
+ acc = run['accuracy']
699
+ mu,sigma = round(acc[0],4),round(acc[1],4)
700
+
701
+ metrics[0,j,i] = mu
702
+ metrics[1,j,i] = sigma
703
+
704
+ i+=1
705
+
706
+ if i == len(transforms):
707
+ i=0
708
+ j+=1
709
+
710
+ # Plot with scatter
711
+
712
+ img = metrics
713
+
714
+ vmin=0.
715
+ vmax=1.
+ step=0.1 # colorbar tick spacing; 'step' was otherwise undefined in this function
716
+
717
+ x, y = torch.arange(12), torch.arange(11)
718
+ x, y = torch.meshgrid(x, y)
719
+
720
+ fig = plt.figure(figsize=(10,6.2))
721
+ ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
722
+ marker_size=350
723
+ plt.scatter(x, y, c=torch.rot90(img[1][x,y],-1,[0,1]), vmin = 0., vmax = img[1].max(), cmap='viridis', s=marker_size*2, marker='s')
724
+ ticks = torch.arange(0.,img[1].max(),0.03).tolist()
725
+ ticks = [round(tick,2) for tick in ticks]
726
+ cba = plt.colorbar(pad=0.06)
727
+ cba.set_ticks(ticks)
728
+ cba.ax.set_yticklabels(ticks)
729
+ # cmap = plt.cm.get_cmap('tab20c').reversed()
730
+ cmap = plt.cm.get_cmap('Reds')
731
+ plt.scatter(x,y, c=torch.rot90(img[0][x,y],-1,[0,1]), vmin=vmin, vmax=vmax, cmap=cmap, s=marker_size, marker='s')
732
+ ticks = torch.arange(vmin, vmax, step).tolist()
733
+ ticks = [round(tick,2) for tick in ticks]
734
+ if ticks[-1] != vmax:
735
+ ticks.append(vmax)
736
+ cbb = plt.colorbar(pad=0.06)
737
+ cbb.set_ticks(ticks)
738
+ if i == 0:
739
+ ticks[0] = f'<{str(ticks[0])}'
740
+ elif i == 1:
741
+ ticks[-1] = f'>{str(ticks[-1])}'
742
+ cbb.ax.set_yticklabels(ticks)
743
+ for x in range(12):
744
+ for y in range(12):
745
+ txt = round(torch.rot90(img[0],-1,[0,1])[x,y].item(),2)
746
+ if str(txt) == '-0.0':
747
+ txt = '0.00'
748
+ elif str(txt) == '0.0':
749
+ txt = '0.00'
750
+ elif len(str(txt)) == 3:
751
+ txt = str(txt)+'0'
752
+ else:
753
+ txt = str(txt)
754
+
755
+ plt.text(x-0.25,y-0.1,txt, color='black', fontsize='x-small')
756
+
757
+ ax.set_xticks(torch.linspace(0,11,12))
758
+ ax.set_xticklabels(transforms)
759
+ ax.set_yticks(torch.linspace(0,11,12))
760
+ pip.reverse()
761
+ ax.set_yticklabels(pip)
762
+ pip.reverse()
763
+ plt.xticks(rotation = 45)
764
+ plt.yticks(rotation = 45)
765
+ cba.set_label('Standard Deviation')
766
+ plt.xlabel("Pipelines")
767
+ plt.ylabel("Distortions")
768
+ if dataset_name == 'DroneSegmentation':
769
+ cbb.set_label('IoU')
770
+ plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_IoU.png"))
771
+ else:
772
+ cbb.set_label('Accuracy')
773
+ plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_accuracies.png"))
774
+
775
+ if __name__ == '__main__':
776
+
777
+ if args.mode == 'ABMakeTable':
778
+ ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
779
+ elif args.mode == 'ABShowTable':
780
+ ABShowTable(args.dataset_name, args.augmentation)
781
+ elif args.mode == 'ABShowImages':
782
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
783
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
784
+ args.dn_test, args.N_runs, download_model=args.download_model)
785
+ ABclass.ABShowImages()
786
+ elif args.mode == 'ABShowAllImages':
787
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
788
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
789
+ args.dn_test, args.N_runs, download_model=args.download_model)
790
+ ABclass.ABShowAllImages()
791
+ elif args.mode == 'CMakeTable':
792
+ CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
793
+ elif args.mode == 'CShowTable': # TODO test it
794
+ CShowTable(args.dataset_name, args.augmentation) # CShowTable only takes dataset_name and augmentation
795
+ elif args.mode == 'CShowImages':
796
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
797
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
798
+ args.dn_test, args.N_runs, args.severity, args.transform,
799
+ download_model=args.download_model)
800
+ ABclass.CShowImages()
801
+ elif args.mode == 'CShowAllImages':
802
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
803
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
804
+ args.dn_test, args.N_runs, args.severity, args.transform,
805
+ download_model=args.download_model)
806
+ ABclass.CShowAllImages()
README.md ADDED
@@ -0,0 +1,26 @@
1
+ # Perturbed Minds
2
+
3
+ ## Conda environment and dependencies
4
+
5
+ To make running this code easier, you can install the latest conda environment for this project, which is stored in `perturbed-environment.yml`.
6
+
7
+ ### Install environment from `perturbed-environment.yml`
8
+
9
+ If you want to install the latest conda environment, run:
10
+
11
+ `conda env create -f perturbed-environment.yml`
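+
+ The environment defined in the yml is named `perturbed`, so afterwards it can be activated with (assuming a standard conda setup):
+
+ `conda activate perturbed`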
12
+
13
+ ### Install segmentation_models_pytorch newest version
14
+
15
+ The PyPI version is not up to date with the GitHub version and lacks features:
16
+
17
+ `python -m pip install git+https://github.com/qubvel/segmentation_models.pytorch`
18
+
19
+ ### Update `perturbed-environment.yml`
20
+
21
+ If you add code that requires new packages, run the following inside your perturbed-minds conda environment:
22
+
23
+ `conda env export > perturbed-environment.yml`
24
+
25
+ ## Walk-through
26
+ Link to the repository structure we sketched in Miro: https://miro.com/app/board/o9J_lQdgyf8=/
figure1.sh ADDED
@@ -0,0 +1,7 @@
1
+ python figures.py \
2
+ --experiment_name track-test \
3
+ --run_name track-all \
4
+ --representation gradients \
5
+ --step gamma_correct \
6
+ --gif_name gradient \
7
+ --output gif
figure2.sh ADDED
@@ -0,0 +1,4 @@
1
+ python figures.py \
2
+ --experiment_name track-test \
3
+ --run_name track-all \
4
+ --output train_vs_val_loss
figures.py ADDED
@@ -0,0 +1,92 @@
1
+ import mlflow
2
+ from mlflow.tracking import MlflowClient
3
+ from mlflow.entities import ViewType
4
+ import argparse
5
+ #gif
6
+ import os
7
+ import pathlib
8
+ import shutil
9
+ import imageio
10
+ #plot
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+
14
+ # -1. parse args
15
+ parser = argparse.ArgumentParser(description="results_analysis")
16
+ parser.add_argument("--tracking_uri", type=str,
17
+ default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
18
+ parser.add_argument("--experiment_name", type=str, default=None,
19
+ help='Name of the experiment on the mlflow server, e.g. "processing_comparison"')
20
+ parser.add_argument("--run_name", type=str, default=None,
21
+ help='Name of the run on the mlflow server, e.g. "proc_nn"')
22
+ parser.add_argument("--representation", type=str, default=None,
23
+ choices=["processing", "gradients"], help='The representation form you want to retrieve ("processing" or "gradients")')
24
+ parser.add_argument("--step", type=str, default=None,
25
+ choices=["pre_debayer", "demosaic", "color_correct", "sharpening", "gaussian", "clipped", "gamma_correct", "rgb"],
26
+ help='The processing step you want to track ("pre_debayer" or "rgb")') #TODO: include predictions and ground truths
27
+ parser.add_argument("--gif_name", type=str, default=None,
28
+ help='Name of the gif that will be saved. Note: .gif will be added later by script') #TODO: option to include filepath where result should be written
29
+ #TODO: option to write results to existing run on mlflow
30
+ parser.add_argument("--local_dir", type=str, default=None,
31
+ help='Name of the local dir to be created to store mlflow data')
32
+ parser.add_argument("--cleanup", type=lambda s: str(s).lower() not in ('false', '0', 'no'), default=True, # argparse's type=bool would treat any non-empty string (even "False") as True
33
+ help='Whether to delete the local dir again after the script was run')
34
+ parser.add_argument("--output", type=str, default=None,
35
+ choices=["gif", "train_vs_val_loss"],
36
+ help='Which output to generate') #TODO: make this cleaner, atm it is confusing because each figure may need different set of args and it is not clear how to manage that
37
+ #TODO: idea -> fix the types of args for each figure which define the figure type but parametrize those things that can reasonably vary
38
+ args = parser.parse_args()
39
+
40
+ # 0. mlflow basics
41
+ mlflow.set_tracking_uri(args.tracking_uri)
42
+
43
+ # 1. specify experiment_name, run_name, representation and step
44
+ #is done via parse_args
45
+
46
+ # 2. use get_experiment_by_name to get experiment object
47
+ experiment = mlflow.get_experiment_by_name(args.experiment_name)
48
+
49
+ # 3. extract experiment_id
50
+ #experiment.experiment_id
51
+
52
+ # 4. use search_runs with experiment_id and run_name for string search query
53
+ filter_string = "tags.mlflow.runName = '{}'".format(args.run_name) # create the filter string using the runName tag to query mlflow
54
+ runs = mlflow.search_runs(experiment.experiment_id, filter_string=filter_string) #returns a pandas data frame where each row is a run (if several exist under that name)
55
+ client = MlflowClient() #TODO: look more into the options of client
56
+
57
+ if args.output == "gif": #TODO: outsource these options to functions which are then loaded and can be called
58
+ # 5. extract run from list
59
+ #TODO: parent run and cv option for analysis
60
+ if args.local_dir:
61
+ local_dir = args.local_dir+"/artifacts"
62
+ else: #use the current working dir and make a subdir "artifacts" to store the data from mlflow
63
+ local_dir = str(pathlib.Path().resolve())+"/artifacts"
64
+ if not os.path.isdir(local_dir): # check the actual target dir, not the hard-coded relative 'artifacts'
65
+ os.mkdir(local_dir) #create the local_dir if it does not exist, yet #TODO: more advanced catching of existing files etc
66
+ dir = client.download_artifacts(runs["run_id"][0], "results", local_dir) #TODO: parametrize this number [0] so the right run is selected
67
+
68
+ # 6. get filenames in chronological sequence and write them to gif
69
+ dirs = [x[0] for x in os.walk(dir)]
70
+ dirs = sorted(dirs, key=str.lower)[1:] #sort chronologically and remove parent dir from list
71
+
72
+ with imageio.get_writer(args.gif_name+'.gif', mode='I') as writer: #https://imageio.readthedocs.io/en/stable/index.html#
73
+ for epoch in dirs: #extract the right file from each epoch
74
+ for _, _, files in os.walk(epoch): #
75
+ for name in files:
76
+ if args.representation in name and args.step in name and "png" in name:
77
+ image = imageio.imread(epoch+"/"+name)
78
+ writer.append_data(image)
79
+
80
+ # 7. cleanup the downloaded artifacts from client file system
81
+ if args.cleanup:
82
+ shutil.rmtree(local_dir) #delete the files downloaded from mlflow
83
+
84
+ elif args.output == "train_vs_val_loss":
85
+ train_loss = client.get_metric_history(runs["run_id"][0], "train_loss") #returns a list of metric entities https://www.mlflow.org/docs/latest/_modules/mlflow/entities/metric.html
86
+ val_loss = client.get_metric_history(runs["run_id"][0], "val_loss") #TODO: parametrize this number [0] so the right run is selected
87
+ train_loss = sorted(train_loss, key=lambda m: m.step) #sort the metric objects in list according to step property
88
+ val_loss = sorted(val_loss, key=lambda m: m.step)
89
+ plt.figure()
90
+ for m_train, m_val in zip(train_loss, val_loss):
91
+ plt.scatter(m_train.value, m_val.value, alpha=1/(m_train.step+1), color='blue')
92
+ plt.savefig("scatter.png") #TODO: parametrize filename
models/classifier.py ADDED
@@ -0,0 +1,281 @@
1
+ import os
2
+ from collections import defaultdict
3
+
4
+ import torch
5
+ import torch.optim
6
+ from torchvision.models import resnet18
7
+ from torchvision.utils import make_grid, save_image
8
+ import torch.nn.functional as F
9
+
10
+ import pytorch_lightning as pl
11
+
12
+ import mlflow.pytorch
13
+
14
+
15
+ def resnet_model(model=resnet18, pretrained=True, in_channels=3, fc_out_features=2):
16
+ resnet = model(pretrained=pretrained)
17
+ # if not pretrained: # TODO: add case for in_channels=4
18
+ # resnet.conv1 = torch.nn.Conv2d(channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
19
+ resnet.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
20
+ return resnet
21
+
22
+
23
+ class LitModel(pl.LightningModule):
24
+
25
+ def __init__(self,
26
+ classifier,
27
+ loss,
28
+ lr=1e-3,
29
+ weight_decay=0,
30
+ loss_aux=None,
31
+ adv_training=False,
32
+ metrics=None,
33
+ processor=None,
34
+ augmentation=None,
35
+ is_segmentation_task=False,
36
+ augmentation_on_eval=False,
37
+ metrics_on_training=True,
38
+ freeze_classifier=False,
39
+ freeze_processor=False,
40
+ ):
41
+ super().__init__()
42
+
43
+ self.classifier = classifier
44
+ self.processor = processor
45
+
46
+ self.lr = lr
47
+ self.weight_decay = weight_decay
48
+ self.loss_fn = loss
49
+ self.loss_aux_fn = loss_aux
50
+ self.adv_training = adv_training
51
+ self.metrics = metrics
52
+ self.augmentation = augmentation
53
+ self.is_segmentation_task = is_segmentation_task
54
+ self.augmentation_on_eval = augmentation_on_eval
55
+ self.metrics_on_training = metrics_on_training
56
+
57
+ self.freeze_classifier = freeze_classifier
58
+ self.freeze_processor = freeze_processor
59
+
60
+ if freeze_classifier:
61
+ pl.LightningModule.freeze(self.classifier)
62
+ if freeze_processor:
63
+ pl.LightningModule.freeze(self.processor)
64
+
65
+ def forward(self, x):
66
+ x = self.processor(x)
67
+ apply_augmentation_step = self.training or self.augmentation_on_eval
68
+ if self.augmentation is not None and apply_augmentation_step:
69
+ x = self.augmentation(x, retain_state=self.is_segmentation_task)
70
+ x = self.classifier(x)
71
+ return x
72
+
73
+ def update_step(self, batch, step_name):
74
+ x, y = batch
75
+ # debug(self.processor)
76
+ # debug(self.processor.parameters())
77
+ # debug.pause()
78
+ # print('type', type(self.processor).__name__)
79
+
80
+ logits = self(x)
81
+
82
+ apply_augmentation_mask = self.is_segmentation_task and (self.training or self.augmentation_on_eval)
83
+ if self.augmentation is not None and apply_augmentation_mask:
84
+ y = self.augmentation(y, mask_transform=True).contiguous()
85
+
86
+ loss = self.loss_fn(logits, y)
87
+
88
+ if self.loss_aux_fn is not None:
89
+ loss_aux = self.loss_aux_fn(x)
90
+ loss += loss_aux
91
+
92
+ self.log(f'{step_name}_loss', loss, on_step=False, on_epoch=True)
93
+ if self.loss_aux_fn is not None:
94
+ self.log(f'{step_name}_loss_aux', loss_aux, on_step=False, on_epoch=True)
95
+
96
+ if self.is_segmentation_task:
97
+ y_hat = F.logsigmoid(logits).exp().squeeze()
98
+ else:
99
+ y_hat = torch.argmax(logits, dim=1)
100
+
101
+
102
+ if self.metrics is not None:
103
+ for metric in self.metrics:
104
+ metric_name = metric.__name__ if hasattr(metric, '__name__') else type(metric).__name__
105
+ if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
106
+ m = metric(y_hat.cpu().detach(), y.cpu())
107
+ self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
108
+ prog_bar=self.training or metric_name == 'accuracy')
109
+ if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
110
+ m = metric(y_hat.cpu().detach(), y.cpu())
111
+ self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
112
+ prog_bar=self.training or metric_name == 'iou_score')
113
+
114
+ return loss
115
+
116
+ def training_step(self, batch, batch_idx):
117
+ return self.update_step(batch, 'train')
118
+
119
+ def validation_step(self, batch, batch_idx):
120
+ return self.update_step(batch, 'val')
121
+
122
+ def test_step(self, batch, batch_idx):
123
+ return self.update_step(batch, 'test')
124
+
125
+ def train(self, mode=True):
126
+ self.training = mode
127
+ # self.processor.train(False)
128
+ self.processor.train(mode=mode and not self.freeze_processor)
129
+ self.classifier.train(mode=mode and not self.freeze_classifier)
130
+ if self.adv_training and self.processor.batch_norm is not None: # don't update batchnorm in adversarial training
131
+ self.processor.batch_norm.track_running_stats = False
132
+ return self
133
+
134
+ def configure_optimizers(self):
135
+ self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
136
+ # parameters = [self.processor.additive_layer]
137
+ # self.optimizer = torch.optim.Adam(parameters, self.lr, weight_decay=self.weight_decay)
138
+ return self.optimizer
139
+ # self.scheduler = {
140
+ # 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
141
+ # self.optimizer, mode='min', factor=0.2, patience=2, min_lr=1e-6, verbose=True,
142
+ # ),
143
+ # 'monitor': 'val_loss',
144
+ # }
145
+ # return [self.optimizer], [self.scheduler]
146
+
147
+ def get_progress_bar_dict(self):
148
+ items = super().get_progress_bar_dict()
149
+ items.pop('v_num')
150
+ return items
151
+
152
+
153
+ class TrackImagesCallback(pl.callbacks.base.Callback):
154
+ def __init__(self, data_loader, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
155
+ super().__init__()
156
+ self.data_loader = data_loader
157
+
158
+ self.track_every_epoch = track_every_epoch
159
+
160
+ self.track_processing = track_processing
161
+ self.track_gradients = track_gradients
162
+ self.track_predictions = track_predictions
163
+ self.save_tensors = save_tensors
164
+
165
+ def callback_track_images(self, trainer, save_loc):
166
+ track_images(trainer.model,
167
+ self.data_loader,
168
+ track_processing=self.track_processing,
169
+ track_gradients=self.track_gradients,
170
+ track_predictions=self.track_predictions,
171
+ save_tensors=self.save_tensors,
172
+ save_loc=save_loc,
173
+ )
174
+
175
+ def on_fit_end(self, trainer, pl_module):
176
+ if not self.track_every_epoch:
177
+ save_loc = 'results'
178
+ self.callback_track_images(trainer, save_loc)
179
+
180
+ def on_train_epoch_end(self, trainer, pl_module, outputs):
181
+ if self.track_every_epoch:
182
+ save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
183
+ self.callback_track_images(trainer, save_loc)
184
+
185
+
186
+ from utils.debug import debug
187
+
188
+
189
+ # @debug
190
+ def log_tensor(batch, path, save_tensors=True, nrow=8):
191
+ if save_tensors:
192
+ torch.save(batch, path)
193
+ mlflow.log_artifact(path, os.path.dirname(path))
194
+
195
+ img_path = path.replace('.pt', '.png')
196
+ split = img_path.split('/')
197
+ img_path = '/'.join(split[:-1]) + '/img_' + split[-1] # insert 'img_'; make it easier to find in mlflow
198
+
199
+ grid = make_grid(batch, nrow=nrow).squeeze()
200
+ save_image(grid, img_path)
201
+ mlflow.log_artifact(img_path, os.path.dirname(path))
202
+
203
+
204
+ def track_images(model, data_loader, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
205
+
206
+ device = model.device
207
+ processor = model.processor
208
+ classifier = model.classifier
209
+
210
+ if not hasattr(processor, 'stages'): # 'static' or 'none' pipeline
211
+ return
212
+
213
+ os.makedirs(save_loc, exist_ok=True)
214
+
215
+ # TODO: implement track_predictions
216
+
217
+ # inputs_full = []
218
+ labels_full = []
219
+ logits_full = []
220
+ stages_full = defaultdict(list)
221
+ grads_full = defaultdict(list)
222
+
223
+ for inputs, labels in data_loader:
224
+
225
+ inputs, labels = inputs.to(device), labels.to(device)
226
+ inputs.requires_grad = True
227
+
228
+ processed_rgb = processor(inputs)
229
+
230
+ if track_gradients or track_predictions:
231
+ logits = classifier(processed_rgb)
232
+
233
+ # NOTE: should zero grads for good measure
234
+ loss = model.loss_fn(logits, labels)
235
+ loss.backward()
236
+
237
+ if track_predictions:
238
+ labels_full.append(labels.cpu().detach())
239
+ logits_full.append(logits.cpu().detach())
240
+ # inputs_full.append(inputs.cpu().detach())
241
+
242
+ for stage, batch in processor.stages.items():
243
+ stages_full[stage].append(batch.cpu().detach())
244
+ if track_gradients:
245
+ grads_full[stage].append(batch.grad.cpu().detach())
246
+
247
+ with torch.no_grad():
248
+
249
+ stages = stages_full
250
+ grads = grads_full
251
+
252
+ if track_processing:
253
+ for stage, batch in stages_full.items():
254
+ stages[stage] = torch.cat(batch)
255
+
256
+ if track_gradients:
257
+ for stage, batch in grads_full.items():
258
+ grads[stage] = torch.cat(batch)
259
+
260
+ for stage_nr, stage_name in enumerate(stages):
261
+ if track_processing:
262
+ batch = stages[stage_name]
263
+ log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
264
+ if track_gradients:
265
+ batch_grad = grads[stage_name]
266
+ batch_grad = batch_grad.abs()
267
+ batch_grad = (batch_grad - batch_grad.min()) / (batch_grad.max() - batch_grad.min())
268
+ log_tensor(batch_grad, os.path.join(
269
+ save_loc, f'gradients_{stage_nr}_{stage_name}.pt'), save_tensors)
270
+
271
+ # inputs = torch.cat(inputs_full)
272
+
273
+ if track_predictions: #and model.is_segmentation_task:
274
+ labels = torch.cat(labels_full)
275
+ logits = torch.cat(logits_full)
276
+ masks = labels.unsqueeze(1)
277
+ predictions = logits #torch.sigmoid(logits).unsqueeze(1)
278
+ #mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
279
+ #log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
280
+ log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
281
+ log_tensor(predictions, os.path.join(save_loc, f'preds.pt'), save_tensors)
perturbed-environment.yml ADDED
@@ -0,0 +1,363 @@
1
+ name: perturbed
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
6
+ - _libgcc_mutex=0.1=main
7
+ - alabaster=0.7.12=py37_0
8
+ - anaconda=2019.10=py37_0
9
+ - anaconda-client=1.7.2=py37_0
10
+ - anaconda-navigator=1.9.7=py37_0
11
+ - anaconda-project=0.8.3=py_0
12
+ - asn1crypto=1.0.1=py37_0
13
+ - astroid=2.3.1=py37_0
14
+ - astropy=3.2.2=py37h7b6447c_0
15
+ - atomicwrites=1.3.0=py37_1
16
+ - attrs=19.2.0=py_0
17
+ - babel=2.7.0=py_0
18
+ - backcall=0.1.0=py37_0
19
+ - backports=1.0=py_2
20
+ - backports.functools_lru_cache=1.5=py_2
21
+ - backports.os=0.1.1=py37_0
22
+ - backports.shutil_get_terminal_size=1.0.0=py37_2
23
+ - backports.tempfile=1.0=py_1
24
+ - backports.weakref=1.0.post1=py_1
25
+ - beautifulsoup4=4.8.0=py37_0
26
+ - bitarray=1.0.1=py37h7b6447c_0
27
+ - bkcharts=0.2=py37_0
28
+ - blas=1.0=mkl
29
+ - bleach=3.1.0=py37_0
30
+ - blosc=1.16.3=hd408876_0
31
+ - bokeh=1.3.4=py37_0
32
+ - boto=2.49.0=py37_0
33
+ - bottleneck=1.2.1=py37h035aef0_1
34
+ - bzip2=1.0.8=h7b6447c_0
35
+ - ca-certificates=2019.8.28=0
36
+ - cairo=1.14.12=h8948797_3
37
+ - certifi=2019.9.11=py37_0
38
+ - cffi=1.12.3=py37h2e261b9_0
39
+ - chardet=3.0.4=py37_1003
40
+ - click=7.0=py37_0
41
+ - cloudpickle=1.2.2=py_0
42
+ - clyent=1.2.2=py37_1
43
+ - colorama=0.4.1=py37_0
44
+ - conda-package-handling=1.6.0=py37h7b6447c_0
45
+ - conda-verify=3.4.2=py_1
46
+ - contextlib2=0.6.0=py_0
47
+ - cryptography=2.7=py37h1ba5d50_0
48
+ - curl=7.65.3=hbc83047_0
49
+ - cycler=0.10.0=py37_0
50
+ - cython=0.29.13=py37he6710b0_0
51
+ - cytoolz=0.10.0=py37h7b6447c_0
52
+ - dask=2.5.2=py_0
53
+ - dask-core=2.5.2=py_0
54
+ - dbus=1.13.6=h746ee38_0
55
+ - decorator=4.4.0=py37_1
56
+ - defusedxml=0.6.0=py_0
57
+ - distributed=2.5.2=py_0
58
+ - docutils=0.15.2=py37_0
59
+ - entrypoints=0.3=py37_0
60
+ - et_xmlfile=1.0.1=py37_0
61
+ - expat=2.2.6=he6710b0_0
62
+ - fastcache=1.1.0=py37h7b6447c_0
63
+ - filelock=3.0.12=py_0
64
+ - flask=1.1.1=py_0
65
+ - fontconfig=2.13.0=h9420a91_0
66
+ - freetype=2.9.1=h8a8886c_1
67
+ - fribidi=1.0.5=h7b6447c_0
68
+ - future=0.17.1=py37_0
69
+ - get_terminal_size=1.0.0=haa9412d_0
70
+ - gevent=1.4.0=py37h7b6447c_0
71
+ - glib=2.56.2=hd408876_0
72
+ - glob2=0.7=py_0
73
+ - gmp=6.1.2=h6c8ec71_1
74
+ - gmpy2=2.0.8=py37h10f8cd9_2
75
+ - graphite2=1.3.13=h23475e2_0
76
+ - greenlet=0.4.15=py37h7b6447c_0
77
+ - gst-plugins-base=1.14.0=hbbd80ab_1
78
+ - gstreamer=1.14.0=hb453b48_1
79
+ - h5py=2.9.0=py37h7918eee_0
80
+ - harfbuzz=1.8.8=hffaf4a1_0
81
+ - hdf5=1.10.4=hb1b8bf9_0
82
+ - heapdict=1.0.1=py_0
83
+ - html5lib=1.0.1=py37_0
84
+ - icu=58.2=h9c2bf20_1
85
+ - idna=2.8=py37_0
86
+ - imageio=2.6.0=py37_0
87
+ - imagesize=1.1.0=py37_0
88
+ - intel-openmp=2019.4=243
89
+ - ipykernel=5.1.2=py37h39e3cac_0
90
+ - ipython=7.8.0=py37h39e3cac_0
91
+ - ipython_genutils=0.2.0=py37_0
92
+ - ipywidgets=7.5.1=py_0
93
+ - isort=4.3.21=py37_0
94
+ - itsdangerous=1.1.0=py37_0
95
+ - jbig=2.1=hdba287a_0
96
+ - jdcal=1.4.1=py_0
97
+ - jedi=0.15.1=py37_0
98
+ - jeepney=0.4.1=py_0
99
+ - jinja2=2.10.3=py_0
100
+ - joblib=0.13.2=py37_0
101
+ - jpeg=9b=h024ee3a_2
102
+ - json5=0.8.5=py_0
103
+ - jsonschema=3.0.2=py37_0
104
+ - jupyter=1.0.0=py37_7
105
+ - jupyter_client=5.3.3=py37_1
106
+ - jupyter_console=6.0.0=py37_0
107
+ - jupyter_core=4.5.0=py_0
108
+ - jupyterlab=1.1.4=pyhf63ae98_0
109
+ - jupyterlab_server=1.0.6=py_0
110
+ - keyring=18.0.0=py37_0
111
+ - kiwisolver=1.1.0=py37he6710b0_0
112
+ - krb5=1.16.1=h173b8e3_7
113
+ - lazy-object-proxy=1.4.2=py37h7b6447c_0
114
+ - libarchive=3.3.3=h5d8350f_5
115
+ - libcurl=7.65.3=h20c2e04_0
116
+ - libedit=3.1.20181209=hc058e9b_0
117
+ - libffi=3.2.1=hd88cf55_4
118
+ - libgcc-ng=9.1.0=hdf63c60_0
119
+ - libgfortran-ng=7.3.0=hdf63c60_0
120
+ - liblief=0.9.0=h7725739_2
121
+ - libpng=1.6.37=hbc83047_0
122
+ - libsodium=1.0.16=h1bed415_0
123
+ - libssh2=1.8.2=h1ba5d50_0
124
+ - libstdcxx-ng=9.1.0=hdf63c60_0
125
+ - libtiff=4.0.10=h2733197_2
126
+ - libtool=2.4.6=h7b6447c_5
127
+ - libuuid=1.0.3=h1bed415_2
128
+ - libxcb=1.13=h1bed415_1
129
+ - libxml2=2.9.9=hea5a465_1
130
+ - libxslt=1.1.33=h7d1a2b0_0
131
+ - llvmlite=0.29.0=py37hd408876_0
132
+ - locket=0.2.0=py37_1
133
+ - lxml=4.4.1=py37hefd8a0e_0
134
+ - lz4-c=1.8.1.2=h14c3975_0
135
+ - lzo=2.10=h49e0be7_2
136
+ - markupsafe=1.1.1=py37h7b6447c_0
137
+ - matplotlib=3.1.1=py37h5429711_0
138
+ - mccabe=0.6.1=py37_1
139
+ - mistune=0.8.4=py37h7b6447c_0
140
+ - mkl=2019.4=243
141
+ - mkl-service=2.3.0=py37he904b0f_0
142
+ - mkl_fft=1.0.14=py37ha843d7b_0
143
+ - mkl_random=1.1.0=py37hd6b4f25_0
144
+ - mock=3.0.5=py37_0
145
+ - more-itertools=7.2.0=py37_0
146
+ - mpc=1.1.0=h10f8cd9_1
147
+ - mpfr=4.0.1=hdf1c602_3
148
+ - mpmath=1.1.0=py37_0
149
+ - msgpack-python=0.6.1=py37hfd86e86_1
150
+ - multipledispatch=0.6.0=py37_0
151
+ - navigator-updater=0.2.1=py37_0
152
+ - nbconvert=5.6.0=py37_1
153
+ - nbformat=4.4.0=py37_0
154
+ - ncurses=6.1=he6710b0_1
155
+ - networkx=2.3=py_0
156
+ - nltk=3.4.5=py37_0
157
+ - nose=1.3.7=py37_2
158
+ - notebook=6.0.1=py37_0
159
+ - numba=0.45.1=py37h962f231_0
160
+ - numexpr=2.7.0=py37h9e4a6bb_0
161
+ - numpy=1.17.2=py37haad9e8e_0
162
+ - numpy-base=1.17.2=py37hde5b4d6_0
163
+ - numpydoc=0.9.1=py_0
164
+ - olefile=0.46=py37_0
165
+ - openpyxl=3.0.0=py_0
166
+ - openssl=1.1.1d=h7b6447c_2
167
+ - packaging=19.2=py_0
168
+ - pandoc=2.2.3.2=0
169
+ - pandocfilters=1.4.2=py37_1
170
+ - pango=1.42.4=h049681c_0
171
+ - parso=0.5.1=py_0
172
+ - partd=1.0.0=py_0
173
+ - patchelf=0.9=he6710b0_3
174
+ - path.py=12.0.1=py_0
175
+ - pathlib2=2.3.5=py37_0
176
+ - patsy=0.5.1=py37_0
177
+ - pcre=8.43=he6710b0_0
178
+ - pep8=1.7.1=py37_0
179
+ - pexpect=4.7.0=py37_0
180
+ - pickleshare=0.7.5=py37_0
181
+ - pip=19.2.3=py37_0
182
+ - pixman=0.38.0=h7b6447c_0
183
+ - pkginfo=1.5.0.1=py37_0
184
+ - pluggy=0.13.0=py37_0
185
+ - ply=3.11=py37_0
186
+ - prometheus_client=0.7.1=py_0
187
+ - prompt_toolkit=2.0.10=py_0
188
+ - psutil=5.6.3=py37h7b6447c_0
189
+ - ptyprocess=0.6.0=py37_0
190
+ - py=1.8.0=py37_0
191
+ - py-lief=0.9.0=py37h7725739_2
192
+ - pycodestyle=2.5.0=py37_0
193
+ - pycosat=0.6.3=py37h14c3975_0
194
+ - pycparser=2.19=py37_0
195
+ - pycrypto=2.6.1=py37h14c3975_9
196
+ - pycurl=7.43.0.3=py37h1ba5d50_0
197
+ - pyflakes=2.1.1=py37_0
198
+ - pygments=2.4.2=py_0
199
+ - pylint=2.4.2=py37_0
200
+ - pyodbc=4.0.27=py37he6710b0_0
201
+ - pyopenssl=19.0.0=py37_0
202
+ - pyparsing=2.4.2=py_0
203
+ - pyqt=5.9.2=py37h05f1152_2
204
+ - pyrsistent=0.15.4=py37h7b6447c_0
205
+ - pysocks=1.7.1=py37_0
206
+ - pytables=3.5.2=py37h71ec239_1
207
+ - pytest=5.2.1=py37_0
208
+ - pytest-arraydiff=0.3=py37h39e3cac_0
209
+ - pytest-astropy=0.5.0=py37_0
210
+ - pytest-doctestplus=0.4.0=py_0
211
+ - pytest-openfiles=0.4.0=py_0
212
+ - pytest-remotedata=0.3.2=py37_0
213
+ - python=3.7.4=h265db76_1
214
+ - python-dateutil=2.8.0=py37_0
215
+ - python-libarchive-c=2.8=py37_13
216
+ - pytz=2019.3=py_0
217
+ - pyyaml=5.1.2=py37h7b6447c_0
218
+ - pyzmq=18.1.0=py37he6710b0_0
219
+ - qt=5.9.7=h5867ecd_1
220
+ - qtawesome=0.6.0=py_0
221
+ - qtconsole=4.5.5=py_0
222
+ - qtpy=1.9.0=py_0
223
+ - readline=7.0=h7b6447c_5
224
+ - requests=2.22.0=py37_0
225
+ - ripgrep=0.10.0=hc07d326_0
226
+ - rope=0.14.0=py_0
227
+ - ruamel_yaml=0.15.46=py37h14c3975_0
228
+ - scikit-learn=0.21.3=py37hd81dba3_0
229
+ - scipy=1.3.1=py37h7c811a0_0
230
+ - seaborn=0.9.0=py37_0
231
+ - secretstorage=3.1.1=py37_0
232
+ - send2trash=1.5.0=py37_0
233
+ - setuptools=41.4.0=py37_0
234
+ - simplegeneric=0.8.1=py37_2
235
+ - singledispatch=3.4.0.3=py37_0
236
+ - sip=4.19.8=py37hf484d3e_0
237
+ - six=1.12.0=py37_0
238
+ - snappy=1.1.7=hbae5bb6_3
239
+ - snowballstemmer=2.0.0=py_0
240
+ - sortedcollections=1.1.2=py37_0
241
+ - sortedcontainers=2.1.0=py37_0
242
+ - soupsieve=1.9.3=py37_0
243
+ - sphinx=2.2.0=py_0
244
+ - sphinxcontrib=1.0=py37_1
245
+ - sphinxcontrib-applehelp=1.0.1=py_0
246
+ - sphinxcontrib-devhelp=1.0.1=py_0
247
+ - sphinxcontrib-htmlhelp=1.0.2=py_0
248
+ - sphinxcontrib-jsmath=1.0.1=py_0
249
+ - sphinxcontrib-qthelp=1.0.2=py_0
250
+ - sphinxcontrib-serializinghtml=1.1.3=py_0
251
+ - sphinxcontrib-websupport=1.1.2=py_0
252
+ - spyder=3.3.6=py37_0
253
+ - spyder-kernels=0.5.2=py37_0
254
+ - sqlalchemy=1.3.9=py37h7b6447c_0
255
+ - sqlite=3.30.0=h7b6447c_0
256
+ - statsmodels=0.10.1=py37hdd07704_0
257
+ - sympy=1.4=py37_0
258
+ - tbb=2019.4=hfd86e86_0
259
+ - tblib=1.4.0=py_0
260
+ - terminado=0.8.2=py37_0
261
+ - testpath=0.4.2=py37_0
262
+ - tk=8.6.8=hbc83047_0
263
+ - toolz=0.10.0=py_0
264
+ - tornado=6.0.3=py37h7b6447c_0
265
+ - traitlets=4.3.3=py37_0
266
+ - unicodecsv=0.14.1=py37_0
267
+ - unixodbc=2.3.7=h14c3975_0
268
+ - wcwidth=0.1.7=py37_0
269
+ - webencodings=0.5.1=py37_1
270
+ - werkzeug=0.16.0=py_0
271
+ - wheel=0.33.6=py37_0
272
+ - widgetsnbextension=3.5.1=py37_0
273
+ - wrapt=1.11.2=py37h7b6447c_0
274
+ - wurlitzer=1.0.3=py37_0
275
+ - xlrd=1.2.0=py37_0
276
+ - xlsxwriter=1.2.1=py_0
277
+ - xlwt=1.3.0=py37_0
278
+ - xz=5.2.4=h14c3975_4
279
+ - yaml=0.1.7=had09818_2
280
+ - zeromq=4.3.1=he6710b0_3
281
+ - zict=1.0.0=py_0
282
+ - zipp=0.6.0=py_0
283
+ - zlib=1.2.11=h7b6447c_3
284
+ - zstd=1.3.7=h0b5b093_0
285
+ - pip:
286
+ - absl-py==0.12.0
287
+ - aiohttp==3.7.4.post0
288
+ - albumentations==0.5.2
289
+ - alembic==1.4.1
290
+ - arrow==0.17.0
291
+ - async-timeout==3.0.1
292
+ - b2sdk==1.4.0
293
+ - boto3==1.17.36
294
+ - botocore==1.20.36
295
+ - cachetools==4.2.1
296
+ - colour-demosaicing==0.1.6
297
+ - colour-science==0.3.16
298
+ - configparser==5.0.0
299
+ - databricks-cli==0.10.0
300
+ - docker==4.2.0
301
+ - docopt==0.6.2
302
+ - efficientnet-pytorch==0.6.3
303
+ - fsspec==0.8.7
304
+ - funcsigs==1.0.2
305
+ - gitdb==4.0.4
306
+ - gitpython==3.1.1
307
+ - google-auth==1.28.0
308
+ - google-auth-oauthlib==0.4.3
309
+ - gorilla==0.3.0
310
+ - grpcio==1.36.1
311
+ - gunicorn==20.0.4
312
+ - imgaug==0.4.0
313
+ - importlib-metadata==3.7.3
314
+ - jmespath==0.10.0
315
+ - logfury==0.1.2
316
+ - mako==1.1.2
317
+ - markdown==3.3.4
318
+ - mlflow==1.14.1
319
+ - multidict==5.1.0
320
+ - munch==2.5.0
321
+ - oauthlib==3.1.0
322
+ - opencv-python==4.5.1.48
323
+ - opencv-python-headless==4.5.1.48
324
+ - pandas==1.2.3
325
+ - pillow==8.1.2
326
+ - pipreqs==0.4.10
327
+ - plotly==4.14.3
328
+ - pretrainedmodels==0.7.4
329
+ - prettytable==2.1.0
330
+ - prometheus-flask-exporter==0.13.0
331
+ - protobuf==3.11.3
332
+ - pyasn1==0.4.8
333
+ - pyasn1-modules==0.2.8
334
+ - python-editor==1.0.4
335
+ - pytorch-lightning==1.2.5
336
+ - pywavelets==1.1.1
337
+ - querystring-parser==1.2.4
338
+ - rawpy==0.16.0
339
+ - requests-oauthlib==1.3.0
340
+ - retrying==1.3.3
341
+ - rsa==4.7.2
342
+ - s3transfer==0.3.6
343
+ - scikit-image==0.18.1
344
+ - segmentation-models-pytorch==0.1.3
345
+ - shapely==1.7.1
346
+ - simplejson==3.17.0
347
+ - smmap==3.0.2
348
+ - sqlparse==0.3.1
349
+ - tabulate==0.8.7
350
+ - tensorboard==2.4.1
351
+ - tensorboard-plugin-wit==1.8.0
352
+ - tifffile==2021.3.17
353
+ - timm==0.3.2
354
+ - torch==1.8.0
355
+ - torchmetrics==0.2.0
356
+ - torchvision==0.9.0
357
+ - tqdm==4.59.0
358
+ - typing-extensions==3.7.4.3
359
+ - urllib3==1.25.11
360
+ - websocket-client==0.57.0
361
+ - yarg==0.1.9
362
+ - yarl==1.6.3
363
+ prefix: /home/nobis/anaconda3/envs/perturbed
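To recreate this environment locally, the usual conda workflow should apply (a minimal sketch; the environment name comes from the name: field at the top of this file, and the prefix line above suggests it is perturbed):

    conda env create -f environment.yml
    conda activate perturbed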
processingpipeline/numpy_static_pipeline_show.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64edef77495ab24143430e7a5d880b6f211568371f37eab03e1b32fb2f5b8015
3
+ size 1906586
processingpipeline/pipeline.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Raw Image Pipeline
3
+ """
4
+ __author__ = "Marco Aversa"
5
+
6
+ import numpy as np
7
+
8
+ from rawpy import * # XXX: no * imports!
9
+ from scipy import ndimage
10
+ from scipy import fftpack
11
+ from scipy.signal import convolve2d
12
+
13
+ from skimage.filters import unsharp_mask
14
+ from skimage.color import rgb2yuv, yuv2rgb, rgb2hsv, hsv2rgb
15
+ from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman, denoise_nl_means, denoise_bilateral, denoise_wavelet, estimate_sigma
16
+
17
+ import matplotlib.pyplot as plt
18
+
19
+ from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
20
+ demosaicing_CFA_Bayer_Malvar2004,
21
+ demosaicing_CFA_Bayer_Menon2007)
22
+
23
+ import torch
24
+ import numpy as np
25
+
26
+ from utils.dataset import Subset
27
+ from torch.utils.data import DataLoader
28
+
29
+ from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
30
+ demosaicing_CFA_Bayer_Malvar2004,
31
+ demosaicing_CFA_Bayer_Menon2007)
32
+
33
+ import matplotlib.pyplot as plt
34
+
35
+
36
+ class RawProcessingPipeline(object):
37
+
38
+ """Applies the raw-processing pipeline from pipeline.py"""
39
+
40
+ def __init__(self, camera_parameters, debayer='bilinear', sharpening='unsharp_masking', denoising='gaussian_denoising'):
41
+ '''
42
+ Args:
43
+ camera_parameters (tuple): (black_level, white_balance, colour_matrix)
44
+ debayer (str): specifies the algorithm used as debayer; choose from {'bilinear','malvar2004','menon2007'}
45
+ sharpening (str): specifies the algorithm used for sharpening; choose from {'sharpening_filter','unsharp_masking'}
46
+ denoising (str): specifies the algorithm used for denoising; choose from choose from {'gaussian_denoising','median_denoising','fft_denoising'}
47
+ '''
48
+
49
+ self.camera_parameters = camera_parameters
50
+
51
+ self.debayer = debayer
52
+ self.sharpening = sharpening
53
+ self.denoising = denoising
54
+
55
+ def __call__(self, img):
56
+ """
57
+ Args:
58
+ img (ndarray of dtype float32): image of size (H,W)
59
+ return:
60
+ img (tensor of dtype float): image of size (3,H,W)
61
+ """
62
+ black_level, white_balance, colour_matrix = self.camera_parameters
63
+ img = processing(img, black_level, white_balance, colour_matrix,
64
+ debayer=self.debayer, sharpening=self.sharpening, denoising=self.denoising)
65
+ img = img.transpose(2, 0, 1)
66
+
67
+ return torch.Tensor(img)
68
+
69
+
70
+ def processing(img, black_level, white_balance, colour_matrix, debayer="bilinear", sharpening="unsharp_masking",
71
+ sharp_radius=1.0, sharp_amount=1.0, denoising="median_denoising", median_kernel_size=3,
72
+ gaussian_sigma=0.5, fft_fraction=0.3, weight_chambolle=0.01, weight_bregman=100,
73
+ sigma_bilateral=0.6, gamma=2.2, bits=16):
74
+ """Apply pipeline on a raw image
75
+
76
+ Args:
77
+ img (ndarray): raw mosaic image
78
+ debayer (str): debayer algorithm
79
+ white_balance (None, ndarray): white balance array (if None it will take the default camera white balance array)
80
+ colour_matrix (None, ndarray): colour matrix (if None it will take the default camera colour matrix) - Size: 3x3
81
+ gamma (float): exponent for the non linear gamma correction.
82
+
83
+ Returns:
84
+ img (ndarray): post-processed image
85
+
86
+ """
87
+
88
+ # Remove Black Level
89
+ img = remove_blacklv(img, black_level)
90
+
91
+ # Apply demosaicing - We don't have access to these 3 functions
92
+ if debayer == "bilinear":
93
+ img = demosaicing_CFA_Bayer_bilinear(img)
94
+ if debayer == "malvar2004":
95
+ img = demosaicing_CFA_Bayer_Malvar2004(img)
96
+ if debayer == "menon2007":
97
+ img = demosaicing_CFA_Bayer_Menon2007(img)
98
+
99
+ # White Balance Correction
100
+
101
+ # Sunny images white balance array -> 2<r<2.8, g=1.0, 1.3<b<1.6
102
+ # Tungsten images white balance array -> 1.3<r<1.7, g=1.0, 2.2<b<2.8
103
+ # Shade images white balance array -> 2.4<r<3.2, g=1.0, 1.1<b<1.3
104
+
105
+ img = wb_correction(img, white_balance)
106
+
107
+ # Colour Correction
108
+ img = colour_correction(img, colour_matrix)
109
+
110
+ # Sharpening
111
+ if sharpening == "sharpening_filter": # Fixed sharpening
112
+ img = sharpening_filter(img)
113
+ if sharpening == "unsharp_masking": # Higher is radius and amount, higher is the sharpening
114
+ img = unsharp_masking(img, radius=sharp_radius, amount=sharp_amount, multichannel=True)
115
+
116
+ # Denoising
117
+ if denoising == "median_denoising":
118
+ img = median_denoising(img, size=median_kernel_size)
119
+ if denoising == "gaussian_denoising":
120
+ img = gaussian_denoising(img, sigma=gaussian_sigma)
121
+ if denoising == "fft_denoising": # fft_fraction = [0.0001,0.5]
122
+ img = fft_denoising(img, keep_fraction=fft_fraction, row_cut=False, column_cut=True)
123
+
124
+ # We don't have access to these 3 functions
125
+ if denoising == "tv_chambolle": # lower is weight, less is the denoising
126
+ img = denoise_tv_chambolle(img, weight=weight_chambolle, eps=0.0002, n_iter_max=200, multichannel=True)
127
+ if denoising == "tv_bregman": # lower is weight, more is the denoising
128
+ img = denoise_tv_bregman(img, weight=weight_bregman, max_iter=100,
129
+ eps=0.001, isotropic=True, multichannel=True)
130
+ # if denoising == "wavelet":
131
+ # img = denoise_wavelet(img.copy(), sigma=None, wavelet='db1', mode='soft', wavelet_levels=None, multichannel=True,
132
+ # convert2ycbcr=False, method='BayesShrink', rescale_sigma=True)
133
+ if denoising == "bilateral": # higher is sigma_spatial, more is the denoising
134
+ img = denoise_bilateral(img, win_size=None, sigma_color=None, sigma_spatial=sigma_bilateral,
135
+ bins=10000, mode='constant', cval=0, multichannel=True)
136
+
137
+ # Gamma Correction
138
+ img = np.clip(img, 0, 1)
139
+ img = adjust_gamma(img, gamma=gamma)
140
+
141
+ return img
142
+
143
+
144
+ def get_camera_parameters(rawpyImg):
145
+ black_level = rawpyImg.black_level_per_channel
146
+ white_balance = rawpyImg.camera_whitebalance[:3]
147
+ colour_matrix = rawpyImg.color_matrix[:, :3].flatten().tolist()
148
+
149
+ return black_level, white_balance, colour_matrix
150
+
151
+
152
+ def remove_blacklv(rawImg, black_level):
153
+ rawImg[0::2, 0::2] -= black_level[0] # R
154
+ rawImg[0::2, 1::2] -= black_level[1] # G
155
+ rawImg[1::2, 0::2] -= black_level[2] # G
156
+ rawImg[1::2, 1::2] -= black_level[3] # B
157
+
158
+ return rawImg
159
+
160
+
161
+ def wb_correction(img, white_balance):
162
+ return img * white_balance
163
+
164
+
165
+ def colour_correction(img, colour_matrix):
166
+ colour_matrix = np.array(colour_matrix).reshape(3, 3)
167
+ return np.einsum('ijk,lk->ijl', img, colour_matrix)
168
+
169
+
170
+ def unsharp_masking(img, radius=1.0, amount=1.0,
171
+ multichannel=False, preserve_range=True):
172
+
173
+ img = rgb2yuv(img)
174
+ img[:, :, 0] = unsharp_mask(img[:, :, 0], radius=radius, amount=amount,
175
+ multichannel=multichannel, preserve_range=preserve_range)
176
+ img = yuv2rgb(img)
177
+ return img
178
+
179
+
180
+ def sharpening_filter(image, iterations=1, kernel=np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])):
181
+
182
+ # https://towardsdatascience.com/image-processing-with-python-blurring-and-sharpening-for-beginners-3bcebec0583a
183
+
184
+ img_yuv = rgb2yuv(image)
185
+
186
+ for i in range(iterations):
187
+ img_yuv[:, :, 0] = convolve2d(img_yuv[:, :, 0], kernel, 'same', boundary='fill', fillvalue=0)
188
+
189
+ final_image = yuv2rgb(img_yuv)
190
+
191
+ return final_image
192
+
193
+
194
+ def median_denoising(img, size=3):
195
+
196
+ img = rgb2yuv(img)
197
+ img[:, :, 0] = ndimage.median_filter(img[:, :, 0], size)
198
+ img = yuv2rgb(img)
199
+
200
+ return img
201
+
202
+
203
+ def gaussian_denoising(img, sigma=0.5):
204
+
205
+ img = rgb2yuv(img)
206
+ img[:, :, 0] = ndimage.gaussian_filter(img[:, :, 0], sigma)
207
+ img = yuv2rgb(img)
208
+
209
+ return img
210
+
211
+
212
+ def fft_denoising(img, keep_fraction=0.3, row_cut=False, column_cut=True):
213
+ """ keep_fraction = 0.5 --> same image as input
214
+ keep_fraction --> 0 --> remove all details """
215
+ # http://scipy-lectures.org/intro/scipy/auto_examples/solutions/plot_fft_image_denoise.html
216
+
217
+ im_fft = fftpack.fft2(img)
218
+
219
+ # Make im_fft2 a copy of the original transform. Numpy arrays have a copy
220
+ # method for this purpose.
221
+ im_fft2 = im_fft.copy()
222
+
223
+ # Set r and c to be the number of rows and columns of the array.
224
+ r, c, _ = im_fft2.shape
225
+
226
+ # Set to zero all rows with indices between r*keep_fraction and r*(1-keep_fraction):
227
+ if row_cut == True:
228
+ im_fft2[int(r * keep_fraction):int(r * (1 - keep_fraction))] = 0
229
+
230
+ # Similarly with the columns:
231
+ if column_cut == True:
232
+ im_fft2[:, int(c * keep_fraction):int(c * (1 - keep_fraction))] = 0
233
+
234
+ # Reconstruct the denoised image from the filtered spectrum, keep only the
235
+ # real part for display.
236
+ im_new = fftpack.ifft2(im_fft2).real
237
+
238
+ return im_new
239
+
240
+
241
+ def adjust_gamma(img, gamma=1.0):
242
+ invGamma = 1.0 / gamma
243
+ img = (img ** invGamma)
244
+ return img
245
+
246
+
247
+ def show_img(img, title="no_title", size=12, histo=True, bins=300, bits=16, x_range=-1):
248
+ """Plot image and its histogram
249
+
250
+ Args:
251
+ img (ndarray): image to plot
252
+ title (str): title of the plot
253
+ histo (bool): True - Plot per-channel histograms of the image. False - Plot the histogram as a continuous curve
254
+ bins (int): number of bins of the histogram
255
+ size (int): figure size
256
+ bits (int): number of bits per pixel in the ndarray
257
+ x_range (list): maximum x range of the histogram (if -1, the full x range is used)
258
+ """
259
+ shape = img.shape
260
+
261
+ fig = plt.figure(figsize=(size, size))
262
+
263
+ # show original image
264
+ fig.add_subplot(221)
265
+ if len(shape) > 2 and img.max() > 255:
266
+ img_to_show = (img.copy() * 255. / (2**bits - 1)).astype(int)
267
+ else:
268
+ img_to_show = img.copy().astype(int)
269
+ plt.imshow(img_to_show)
270
+ if title != "no_title":
271
+ plt.title(title)
272
+
273
+ fig.add_subplot(222)
274
+
275
+ if len(shape) > 2:
276
+ if histo == True:
277
+ plt.hist(img[:, :, 0].flatten(), bins=bins, label="Channel1", color="red", alpha=0.5)
278
+ plt.hist(img[:, :, 1].flatten(), bins=bins, label="Channel2", color="green", alpha=0.5)
279
+ plt.hist(img[:, :, 2].flatten(), bins=bins, label="Channel3", color="blue", alpha=0.5)
280
+ if x_range != -1:
281
+ plt.xlim([x_range[0], x_range[1]])
282
+ else:
283
+ h1, b1 = np.histogram(img[:, :, 0].flatten(), bins=bins)
284
+ h2, b2 = np.histogram(img[:, :, 1].flatten(), bins=bins)
285
+ h3, b3 = np.histogram(img[:, :, 2].flatten(), bins=bins)
286
+ plt.plot(b1[:-1], h1, label="Channel1", color="red", alpha=0.5)
287
+ plt.plot(b2[:-1], h2, label="Channel2", color="green", alpha=0.5)
288
+ plt.plot(b3[:-1], h3, label="Channel3", color="blue", alpha=0.5)
289
+
290
+ plt.legend()
291
+ else:
292
+ if histo == True:
293
+ plt.hist(img.flatten(), bins=bins)
294
+ if x_range != -1:
295
+ plt.xlim([x_range[0], x_range[1]])
296
+ else:
297
+ h, b = np.histogram(img.flatten(), bins=bins)
298
+ plt.plot(b[:-1], h)
299
+
300
+ plt.xlabel("Intensities")
301
+ plt.ylabel("Counts")
302
+
303
+ plt.show()
304
+
305
+
306
+ def get_statistics(dataset, train_indices, transform=None):
307
+ """Calculates the mean and the standard deviation of a given sub train set of dataset
308
+
309
+ Args:
310
+ dataset (Dataset): full dataset (e.g. DroneDataset) from which the training subset is drawn
311
+ train_indices (tensor): indices corresponding to a subset of the dataset
312
+ transform (Compose): list of transformations compatible with Compose to be applied before calculations
313
+ return:
314
+ mean (tensor of dtype float): size (C,1,1)
315
+ std (tensor of dtype float): size (C,1,1)
316
+ """
317
+
318
+ trainset = Subset(dataset, indices=train_indices, transform=transform)
319
+ dataloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False)
320
+ dataiter = iter(dataloader)
321
+
322
+ images, labels = dataiter.next()
323
+
324
+ if len(images.shape) == 3:
325
+ mean, std = torch.mean(images, axis=(0, 1, 2)), torch.std(images, axis=(0, 1, 2))
326
+ return mean, std
327
+ else:
328
+ mean, std = torch.mean(images, axis=(0, 2, 3))[:, None, None], torch.std(images, axis=(0, 2, 3))[:, None, None]
329
+ return mean, std
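A minimal sketch of how the static pipeline above is meant to be used as a dataset transform, mirroring the 'static' processing mode in train.py (the Normalize statistics are the Microscopy values used there):

    from torchvision.transforms import Compose, Normalize
    from processingpipeline.pipeline import RawProcessingPipeline
    from utils.dataset import get_dataset

    dataset = get_dataset('Microscopy')
    dataset.transform = Compose([
        RawProcessingPipeline(camera_parameters=dataset.camera_parameters,
                              debayer='bilinear',
                              sharpening='unsharp_masking',
                              denoising='gaussian_denoising'),
        Normalize(mean=[0.91, 0.84, 0.94], std=[0.08, 0.12, 0.05]),  # Microscopy statistics from train.py
    ])
    img, label = dataset[0]  # img: float tensor of shape (3, H, W)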
processingpipeline/torch_pipeline.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from numpy.lib.function_base import interp
3
+ import torch
4
+ import torch.nn as nn
5
+ if not os.path.exists('README.md'):
6
+ os.chdir('..')
7
+
8
+ from processingpipeline.pipeline import processing as default_processing
9
+ from utils.base import np2torch, torch2np
10
+
11
+ import segmentation_models_pytorch as smp
12
+
13
+ from utils.debug import debug
14
+
15
+ K_G = torch.Tensor([[0, 1, 0],
16
+ [1, 4, 1],
17
+ [0, 1, 0]]) / 4
18
+
19
+ K_RB = torch.Tensor([[1, 2, 1],
20
+ [2, 4, 2],
21
+ [1, 2, 1]]) / 4
22
+
23
+ M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
24
+ [-0.14714119, -0.28886916, 0.43601035],
25
+ [0.61497538, -0.51496512, -0.10001026]])
26
+ M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
27
+ [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
28
+ [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])
29
+
30
+ K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
31
+ [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
32
+ [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
33
+ [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
34
+ [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])
35
+ K_SHARP = torch.Tensor([[0, -1, 0],
36
+ [-1, 5, -1],
37
+ [0, -1, 0]])
38
+ DEFAULT_CAMERA_PARAMS = (
39
+ [0., 0., 0., 0.],
40
+ [1., 1., 1.],
41
+ [1., 0., 0., 0., 1., 0., 0., 0., 1.],
42
+ )
43
+
44
+
45
+ class RawToRGB(nn.Module):
46
+ def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
47
+ super().__init__()
48
+ self.stages = None
49
+ self.buffer = None
50
+ self.reduce_size = reduce_size
51
+ self.out_channels = out_channels
52
+ self.track_stages = track_stages
53
+ self.normalize_mosaic = normalize_mosaic
54
+
55
+ def forward(self, raw):
56
+ self.stages = {}
57
+ self.buffer = {}
58
+
59
+ rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
60
+ self.stages['demosaic'] = rgb
61
+ if self.normalize_mosaic:
62
+ rgb = self.normalize_mosaic(rgb)
63
+
64
+ if self.track_stages and raw.requires_grad:
65
+ for stage in self.stages.values():
66
+ stage.retain_grad()
67
+
68
+ self.buffer['processed_rgb'] = rgb
69
+
70
+ return rgb
71
+
72
+
73
+ class NNProcessing(nn.Module):
74
+ def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
75
+ super().__init__()
76
+ self.stages = None
77
+ self.buffer = None
78
+ self.track_stages = track_stages
79
+ self.model = smp.UnetPlusPlus(
80
+ encoder_name='resnet34',
81
+ encoder_depth=3,
82
+ decoder_channels=[256, 128, 64],
83
+ in_channels=3,
84
+ classes=3,
85
+ )
86
+ self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3)
87
+ self.normalize_mosaic = normalize_mosaic
88
+
89
+ def forward(self, raw):
90
+ self.stages = {}
91
+ self.buffer = {}
92
+ # self.stages['raw'] = raw
93
+ rgb = raw2rgb(raw)
94
+ if self.normalize_mosaic:
95
+ rgb = self.normalize_mosaic(rgb)
96
+ self.stages['demosaic'] = rgb
97
+ rgb = self.model(rgb)
98
+ if self.batch_norm is not None:
99
+ rgb = self.batch_norm(rgb)
100
+ self.stages['rgb'] = rgb
101
+
102
+ if self.track_stages and raw.requires_grad:
103
+ for stage in self.stages.values():
104
+ stage.retain_grad()
105
+
106
+ self.buffer['processed_rgb'] = rgb
107
+
108
+ return rgb
109
+
110
+
111
+ class ParametrizedProcessing(nn.Module):
112
+ def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True, noise_layer=False):
113
+ super().__init__()
114
+ self.stages = None
115
+ self.buffer = None
116
+ self.track_stages = track_stages
117
+
118
+ black_level, white_balance, colour_matrix = camera_parameters
119
+ self.register_buffer('black_level', torch.as_tensor(black_level))
120
+ self.register_buffer('colour_correction',
121
+ torch.as_tensor(white_balance).reshape(1, 3)
122
+ * torch.as_tensor(colour_matrix).reshape(3, 3))
123
+ self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
124
+ self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
125
+
126
+ self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
127
+
128
+ self.debayer = Debayer()
129
+
130
+ self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
131
+ self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()
132
+
133
+ self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
134
+ self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
135
+
136
+ self.batch_norm = nn.BatchNorm2d(3) if batch_norm_output else None
137
+
138
+ # if noise_layer:
139
+ # for param in self.parameters():
140
+ # param.requires_grad = False
141
+
142
+ self.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256))
143
+ ) if noise_layer else None # XXX: can this be 0?
144
+
145
+ def forward(self, raw):
146
+ assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
147
+
148
+ self.stages = {}
149
+ self.buffer = {}
150
+
151
+ # self.stages['raw'] = raw
152
+
153
+ rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
154
+ rgb = rgb.contiguous()
155
+ self.stages['demosaic'] = rgb
156
+
157
+ rgb = self.debayer(rgb)
158
+ # self.stages['debayer'] = rgb
159
+
160
+ rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
161
+ self.stages['color_correct'] = rgb
162
+
163
+ yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
164
+ yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])
165
+
166
+ if self.track_stages: # keep stage in computational graph for grad information
167
+ rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
168
+ self.stages['sharpening'] = rgb
169
+ yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
170
+
171
+ yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
172
+ rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
173
+ self.stages['gaussian'] = rgb
174
+
175
+ rgb = torch.clip(rgb, 1e-5, 1)
176
+ self.stages['clipped'] = rgb
177
+
178
+ rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
179
+ self.stages['gamma_correct'] = rgb
180
+
181
+ if self.additive_layer is not None:
182
+ # rgb = rgb + 0 * self.additive_layer
183
+ rgb = rgb + self.additive_layer
184
+ self.stages['noise'] = rgb
185
+
186
+ if self.batch_norm is not None:
187
+ rgb = self.batch_norm(rgb)
188
+
189
+ if self.track_stages and raw.requires_grad:
190
+ for stage in self.stages.values():
191
+ stage.retain_grad()
192
+
193
+ self.buffer['processed_rgb'] = rgb
194
+
195
+ return rgb
196
+
197
+
198
+ class Debayer(nn.Conv2d):
199
+ def __init__(self):
200
+ super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False) # default_pipeline uses 'replicate'
201
+ self.weight.data.fill_(0)
202
+ self.weight.data[0, 0] = K_RB.clone()
203
+ self.weight.data[1, 1] = K_G.clone()
204
+ self.weight.data[2, 2] = K_RB.clone()
205
+
206
+
207
+ def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
208
+ """transform raw image with 1 channel to rgb with 3 channels
209
+ Args:
210
+ raw (Tensor): raw Tensor of shape (B, H, W)
211
+ black_level (iterable, optional): RGGB black level values to subtract
212
+ reduce_size (bool, optional): if False, the output image will have the same height and width
213
+ as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
214
+ if True, the output dimensions are reduced by half (B, C, H//2, W//2),
215
+ the two green channels are averaged.
216
+ out_channels (int, optional): number of output channels. One of {3, 4}.
217
+ """
218
+ assert out_channels in [3, 4]
219
+ if black_level is None:
220
+ black_level = [0, 0, 0, 0]
221
+ Bch, H, W = raw.shape
222
+ R = raw[:, 0::2, 0::2] - black_level[0] # R
223
+ G1 = raw[:, 0::2, 1::2] - black_level[1] # G
224
+ G2 = raw[:, 1::2, 0::2] - black_level[2] # G
225
+ B = raw[:, 1::2, 1::2] - black_level[3] # B
226
+ if reduce_size:
227
+ rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
228
+ if out_channels == 3:
229
+ rgb[:, 0, :, :] = R
230
+ rgb[:, 1, :, :] = (G1 + G2) / 2
231
+ rgb[:, 2, :, :] = B
232
+ elif out_channels == 4:
233
+ rgb[:, 0, :, :] = R
234
+ rgb[:, 1, :, :] = G1
235
+ rgb[:, 2, :, :] = G2
236
+ rgb[:, 3, :, :] = B
237
+ else:
238
+ rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
239
+ if out_channels == 3:
240
+ rgb[:, 0, 0::2, 0::2] = R
241
+ rgb[:, 1, 0::2, 1::2] = G1
242
+ rgb[:, 1, 1::2, 0::2] = G2
243
+ rgb[:, 2, 1::2, 1::2] = B
244
+ elif out_channels == 4:
245
+ rgb[:, 0, 0::2, 0::2] = R
246
+ rgb[:, 1, 0::2, 1::2] = G1
247
+ rgb[:, 2, 1::2, 0::2] = G2
248
+ rgb[:, 3, 1::2, 1::2] = B
249
+ return rgb
250
+
251
+
252
+ # pipeline validation
253
+ if __name__ == "__main__":
254
+
255
+ import torch
256
+ import numpy as np
257
+
258
+ if not os.path.exists('README.md'):
259
+ os.chdir('..')
260
+
261
+ import matplotlib.pyplot as plt
262
+ from utils.dataset import get_dataset
263
+ from utils.base import np2torch, torch2np
264
+
265
+ from utils.debug import debug
266
+ from processingpipeline.pipeline import processing as default_processing
267
+
268
+ raw_dataset = get_dataset('DS')
269
+ loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
270
+ batch_raw, batch_mask = next(iter(loader))
271
+
272
+ # torch proc
273
+ camera_parameters = raw_dataset.camera_parameters
274
+ black_level = camera_parameters[0]
275
+
276
+ proc = ParametrizedProcessing(camera_parameters)
277
+
278
+ batch_rgb = proc(batch_raw)
279
+ rgb = batch_rgb[0]
280
+
281
+ # numpy proc
282
+ raw_img = batch_raw[0]
283
+ numpy_raw = torch2np(raw_img)
284
+
285
+ default_rgb = default_processing(numpy_raw, *camera_parameters,
286
+ sharpening='sharpening_filter', denoising='gaussian_denoising')
287
+
288
+ rgb_valid = np2torch(default_rgb)
289
+
290
+ print("pipeline norm difference:", (rgb - rgb_valid).norm().item())
291
+
292
+ rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
293
+ rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()
294
+
295
+ plt.figure(figsize=(16, 8))
296
+ plt.subplot(151)
297
+ plt.title('Raw')
298
+ plt.imshow(torch2np(raw_img))
299
+ plt.subplot(152)
300
+ plt.title('RGB Mosaic')
301
+ plt.imshow(torch2np(rgb_mosaic))
302
+ plt.subplot(153)
303
+ plt.title('RGB Reduced')
304
+ plt.imshow(torch2np(rgb_reduced))
305
+ plt.subplot(154)
306
+ plt.title('Torch Pipeline')
307
+ plt.imshow(torch2np(rgb))
308
+ plt.subplot(155)
309
+ plt.title('Default Pipeline')
310
+ plt.imshow(torch2np(rgb_valid))
311
+ plt.show()
312
+
313
+ # assert rgb.allclose(rgb_valid)
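A small shape sketch for raw2rgb and ParametrizedProcessing above, run on a random mosaic batch (the tensor sizes are illustrative only; DEFAULT_CAMERA_PARAMS is the neutral parameter set defined in this file):

    import torch
    from processingpipeline.torch_pipeline import raw2rgb, ParametrizedProcessing, DEFAULT_CAMERA_PARAMS

    raw = torch.rand(2, 256, 256)                  # dummy raw mosaic batch (B, H, W)

    rgb_reduced = raw2rgb(raw, reduce_size=True)   # (2, 3, 128, 128): RGGB sites packed, greens averaged
    rgb_mosaic = raw2rgb(raw, reduce_size=False)   # (2, 3, 256, 256): unsampled positions left at zero

    proc = ParametrizedProcessing(DEFAULT_CAMERA_PARAMS)
    rgb = proc(raw)                                # (2, 3, 256, 256), differentiable w.r.t. the pipeline parameters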
sanity_checks_and_statistics.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53f62c6ce9a6656a31c3e0ae1deded2e4f9818cd891381dbe1030dd5edc5f278
3
+ size 6103871
show_classification_results.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62c5dbc4bb22ecd26bc691c1f574b6bcf07b7cd48f62668506955df3513afe55
3
+ size 10556940
show_results.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ datasets='Microscopy Drone'
4
+ augmentations='weak strong none'
5
+
6
+ for augment in $augmentations
7
+ do
8
+ for data in $datasets
9
+ do
10
+
11
+ python show_results.py \
12
+ --dataset $data \
13
+ --augmentation $augment \
14
+
15
+ done
16
+ done
train.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import copy
4
+ import argparse
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ import mlflow.pytorch
10
+ from torch.utils.data import DataLoader
11
+ from torchvision.models import resnet18
12
+ import torchvision.transforms as T
13
+ from pytorch_lightning.metrics.functional import accuracy
14
+ import pytorch_lightning as pl
15
+ from pytorch_lightning.callbacks import ModelCheckpoint
16
+
17
+ from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
18
+ from utils.debug import debug
19
+ from utils.augmentation import get_augmentation
20
+ from utils.dataset import Subset, get_dataset, k_fold
21
+
22
+ from processingpipeline.pipeline import RawProcessingPipeline
23
+ from processingpipeline.torch_pipeline import raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
24
+
25
+ from models.classifier import log_tensor, resnet_model, LitModel, TrackImagesCallback
26
+
27
+ import segmentation_models_pytorch as smp
28
+
29
+ from utils.pytorch_ssim import SSIM
30
+
31
+ # args to set up task
32
+ parser = argparse.ArgumentParser(description="classification_task")
33
+ parser.add_argument("--tracking_uri", type=str,
34
+ default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
35
+ parser.add_argument("--processor_uri", type=str, default=None,
36
+ help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
37
+ parser.add_argument("--classifier_uri", type=str, default=None,
38
+ help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
39
+ parser.add_argument("--state_dict_uri", type=str,
40
+ default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training)')
41
+
42
+ parser.add_argument("--experiment_name", type=str,
43
+ default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
44
+ parser.add_argument("--run_name", type=str,
45
+ default='test run', help='Specify the name of your run')
46
+
47
+ parser.add_argument("--log_model", type=str2bool, default=True, help='Enables model logging')
48
+ parser.add_argument("--save_locally", action='store_true',
49
+ help='Model will be saved locally if this flag is set') # TODO: bypass mlflow
50
+
51
+ parser.add_argument("--track_processing", action='store_true',
52
+ help='Save images after each transformation of the pipeline for the test set')
53
+ parser.add_argument("--track_processing_gradients", action='store_true',
54
+ help='Save images of gradients after each transformation of the pipeline for the test set')
55
+ parser.add_argument("--track_save_tensors", action='store_true',
56
+ help='Save the torch tensors after each transformation of the pipeline for the test set')
57
+ parser.add_argument("--track_predictions", action='store_true',
58
+ help='Save images after each transformation of the pipeline for the test set + input gradient')
59
+ parser.add_argument("--track_n_images", default=5,
60
+ help='Track the n first elements of dataset. Only used for args.track_processing=True')
61
+ parser.add_argument("--track_every_epoch", action='store_true', help='Track images every epoch or once after training')
62
+
63
+ # args to create dataset
64
+ parser.add_argument("--seed", type=int, default=1, help='Global seed')
65
+ parser.add_argument("--dataset", type=str, default='Microscopy',
66
+ choices=["Drone", "DroneSegmentation", "Microscopy"], help='Select dataset')
67
+
68
+ parser.add_argument("--n_splits", type=int, default=1, help='Number of splits used for training')
69
+ parser.add_argument("--train_size", type=float, default=0.8, help='Fraction of training points in dataset')
70
+
71
+ # args for training
72
+ parser.add_argument("--lr", type=float, default=1e-5, help="learning rate used for training")
73
+ parser.add_argument("--epochs", type=int, default=3, help="numper of epochs")
74
+ parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
75
+ parser.add_argument("--augmentation", type=str, default='none',
76
+ choices=["none", "weak", "strong"], help="Applies augmentation to training")
77
+ parser.add_argument("--augmentation_on_valid_epoch", action='store_true',
78
+ help='Apply augmentation on validation epochs as well') # TODO: implement; augmentation should be disabled by default for 'val' and 'test'
79
+ parser.add_argument("--check_val_every_n_epoch", type=int, default=1)
80
+
81
+ # args to specify the processing
82
+ parser.add_argument("--processing_mode", type=str, default="parametrized",
83
+ choices=["parametrized", "static", "neural_network", "none"],
84
+ help="Which type of raw to rgb processing should be used")
85
+
86
+ # args to specify model
87
+ parser.add_argument("--classifier_network", type=str, default='ResNet18',
88
+ help='Type of pretrained network') # TODO: implement different choices
89
+ parser.add_argument("--classifier_pretrained", action='store_true',
90
+ help='Whether to use a pre-trained model or not')
91
+ parser.add_argument("--smp_encoder", type=str, default='resnet34', help='segmentation model encoder')
92
+
93
+ parser.add_argument("--freeze_processor", action='store_true', help="Freeze raw to rgb processing model weights")
94
+ parser.add_argument("--freeze_classifier", action='store_true', help="Freeze classification model weights")
95
+
96
+ # args to specify static pipeline transformations
97
+ parser.add_argument("--sp_debayer", type=str, default='bilinear',
98
+ choices=['bilinear', 'malvar2004', 'menon2007'], help="Specify algorithm used as debayer")
99
+ parser.add_argument("--sp_sharpening", type=str, default='sharpening_filter',
100
+ choices=['sharpening_filter', 'unsharp_masking'], help="Specify algorithm used for sharpening")
101
+ parser.add_argument("--sp_denoising", type=str, default='gaussian_denoising',
102
+ choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help="Specify algorithm used for denoising")
103
+
104
+ # args to choose training mode
105
+ parser.add_argument("--adv_training", action='store_true', help="Enable adversarial training")
106
+ parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
107
+ parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
108
+ help="Type of adversarial auxilliary regularization loss")
109
+
110
+ parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
111
+
112
+ parser.add_argument('--test_run', action='store_true')
113
+
114
+ if 'ipykernel_launcher' in sys.argv[0]:
115
+ args = parser.parse_args([
116
+ '--dataset=Microscopy',
117
+ '--epochs=100',
118
+ '--augmentation=strong',
119
+ '--lr=1e-5',
120
+ '--freeze_processor',
121
+ # '--track_processing',
122
+ # '--test_run',
123
+ # '--track_predictions',
124
+ # '--track_every_epoch',
125
+ # '--adv_training',
126
+ # '--adv_aux_weight=100',
127
+ # '--adv_aux_loss=l2',
128
+ # '--log_model=',
129
+ ])
130
+ else:
131
+ args = parser.parse_args()
132
+
133
+
134
+ def run_train(args):
135
+
136
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
137
+ training_mode = 'adversarial' if args.adv_training else 'default'
138
+
139
+ # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
140
+ mlflow.set_tracking_uri(args.tracking_uri)
141
+ mlflow.set_experiment(args.experiment_name)
142
+ os.environ["AWS_ACCESS_KEY_ID"] = #TODO: add your AWS access key if you want to write your results to our collaborative lab server
143
+ os.environ["AWS_SECRET_ACCESS_KEY"] = #TODO: add your AWS seceret access key if you want to write your results to our collaborative lab server
144
+
145
+ # dataset
146
+
147
+ dataset = get_dataset(args.dataset)
148
+
149
+ print(f'dataset: {type(dataset).__name__}[{len(dataset)}]')
150
+ print(f'task: {dataset.task}')
151
+ print(f'mode: {training_mode} training')
152
+ print(f'# cross-validation subsets: {args.n_splits}')
153
+ pl.seed_everything(args.seed)
154
+ idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)
155
+
156
+ with mlflow.start_run(run_name=args.run_name) as parent_run:
157
+
158
+ for k_iter, idxs in enumerate(idxs_kfold):
159
+
160
+ print(f"K_fold subset: {k_iter+1}/{args.n_splits}")
161
+
162
+ if args.processing_mode == 'static':
163
+ if args.dataset == "Drone" or args.dataset == "DroneSegmentation":
164
+ mean = torch.tensor([0.35, 0.36, 0.35])
165
+ std = torch.tensor([0.12, 0.11, 0.12])
166
+ elif args.dataset == "Microscopy":
167
+ mean = torch.tensor([0.91, 0.84, 0.94])
168
+ std = torch.tensor([0.08, 0.12, 0.05])
169
+
170
+ dataset.transform = T.Compose([RawProcessingPipeline(
171
+ camera_parameters=dataset.camera_parameters,
172
+ debayer=args.sp_debayer,
173
+ sharpening=args.sp_sharpening,
174
+ denoising=args.sp_denoising,
175
+ ), T.Normalize(mean, std)])
176
+ # XXX: Not clean
177
+
178
+ processor = nn.Identity()
179
+
180
+ if args.processor_uri is not None and args.processing_mode != 'none':
181
+ print('Fetching processor: ', end='')
182
+ model = fetch_from_mlflow(args.processor_uri, use_cache=args.cache_downloaded_models)
183
+ processor = model.processor
184
+ for param in processor.parameters():
185
+ param.requires_grad = True
186
+ model.processor = None
187
+ del model
188
+ else:
189
+ print(f'processing_mode: {args.processing_mode}')
190
+ normalize_mosaic = None # normalize after raw has been passed to raw2rgb
191
+ if args.dataset == "Microscopy":
192
+ mosaic_mean = [0.5663, 0.1401, 0.0731]
193
+ mosaic_std = [0.097, 0.0423, 0.008]
194
+ normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)
195
+
196
+ track_stages = args.track_processing or args.track_processing_gradients
197
+ if args.processing_mode == 'parametrized':
198
+ processor = ParametrizedProcessing(
199
+ camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
200
+ noise_layer=args.adv_training, # XXX: Remove?
201
+ )
202
+ elif args.processing_mode == 'neural_network':
203
+ processor = NNProcessing(track_stages=track_stages,
204
+ normalize_mosaic=normalize_mosaic, batch_norm_output=True)
205
+ elif args.processing_mode == 'none':
206
+ processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
207
+ normalize_mosaic=normalize_mosaic)
208
+
209
+ if args.classifier_uri: # fetch classifier
210
+ print('Fetching classifier: ', end='')
211
+ model = fetch_from_mlflow(args.classifier_uri, use_cache=args.cache_downloaded_models)
212
+ classifier = model.classifier
213
+ model.classifier = None
214
+ del model
215
+ else:
216
+ if dataset.task == 'classification':
217
+ classifier = resnet_model(
218
+ model=resnet18,
219
+ pretrained=args.classifier_pretrained,
220
+ in_channels=3,
221
+ fc_out_features=len(dataset.classes)
222
+ )
223
+ else:
224
+ # XXX: add other network choices to args.smp_network (FPN) and args.network
225
+ classifier = smp.UnetPlusPlus(
226
+ encoder_name=args.smp_encoder,
227
+ encoder_depth=5,
228
+ encoder_weights='imagenet',
229
+ in_channels=3,
230
+ classes=1,
231
+ activation=None,
232
+ )
233
+
234
+ if args.freeze_processor and len(list(iter(processor.parameters()))) == 0:
235
+ print('Note: freezing processor without parameters.')
236
+ assert not (args.freeze_processor and args.freeze_classifier), 'Likely no parameters to train.'
237
+
238
+ if dataset.task == 'classification':
239
+ loss = nn.CrossEntropyLoss()
240
+ metrics = [accuracy]
241
+ else:
242
+ # loss = utils.base.smp_get_loss(args.smp_loss) # XXX: add other losses to args.smp_loss
243
+ loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
244
+ metrics = [smp.utils.metrics.IoU()]
245
+
246
+ loss_aux = None
247
+
248
+ if args.adv_training:
249
+
250
+ assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
251
+ assert args.freeze_classifier, "Classifier should be frozen for adversarial training"
252
+ assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
253
+
254
+ processor_default = copy.deepcopy(processor)
255
+ processor_default.track_stages = False
256
+ processor_default.eval()
257
+ processor_default.to(DEVICE)
258
+ # debug(processor_default)
259
+
260
+ def l2_regularization(x, y):
261
+ return (x - y).norm()
262
+
263
+ if args.adv_aux_loss == 'l2':
264
+ regularization = l2_regularization
265
+ elif args.adv_aux_loss == 'ssim':
266
+ regularization = SSIM(window_size=11)
267
+ else:
268
+ raise NotImplementedError(args.adv_aux_loss)
269
+
270
+ class AuxLoss(nn.Module):
271
+ def __init__(self, loss_aux, weight=1):
272
+ super().__init__()
273
+ self.loss_aux = loss_aux
274
+ self.weight = weight
275
+
276
+ def forward(self, x):
277
+ x_reference = processor_default(x)
278
+ x_processed = processor.buffer['processed_rgb']
279
+ return self.weight * self.loss_aux(x_reference, x_processed)
280
+
281
+ class WeightedLoss(nn.Module):
282
+ def __init__(self, loss, weight=1):
283
+ super().__init__()
284
+ self.loss = loss
285
+ self.weight = weight
286
+
287
+ def forward(self, x, y):
288
+ return self.weight * self.loss(x, y)
289
+
290
+ def __repr__(self):
291
+ return f'{self.weight} * {get_name(self.loss)}'
292
+
293
+ loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=-1)
294
+ # loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
295
+ loss_aux = AuxLoss(
296
+ loss_aux=regularization,
297
+ weight=args.adv_aux_weight,
298
+ )
299
+
300
+ augmentation = get_augmentation(args.augmentation)
301
+
302
+ model = LitModel(
303
+ classifier=classifier,
304
+ processor=processor,
305
+ loss=loss,
306
+ loss_aux=loss_aux,
307
+ adv_training=args.adv_training,
308
+ metrics=metrics,
309
+ augmentation=augmentation,
310
+ is_segmentation_task=dataset.task == 'segmentation',
311
+ freeze_classifier=args.freeze_classifier,
312
+ freeze_processor=args.freeze_processor,
313
+ )
314
+
315
+ # get train_set_dict
316
+ if args.state_dict_uri:
317
+ state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
318
+ train_indices = state_dict['train_indices']
319
+ valid_indices = state_dict['valid_indices']
320
+ else:
321
+ train_indices = idxs[0]
322
+ valid_indices = idxs[1]
323
+ state_dict = vars(args).copy()
324
+
325
+ track_indices = list(range(args.track_n_images))
326
+
327
+ if dataset.task == 'classification':
328
+ state_dict['classes'] = dataset.classes
329
+ state_dict['device'] = DEVICE
330
+ state_dict['train_indices'] = train_indices
331
+ state_dict['valid_indices'] = valid_indices
332
+ state_dict['elements in train set'] = len(train_indices)
333
+ state_dict['elements in test set'] = len(valid_indices)
334
+
335
+ if args.test_run:
336
+ train_indices = train_indices[:args.batch_size]
337
+ valid_indices = valid_indices[:args.batch_size]
338
+
339
+ train_set = Subset(dataset, indices=train_indices)
340
+ valid_set = Subset(dataset, indices=valid_indices)
341
+ track_set = Subset(dataset, indices=track_indices)
342
+
343
+ train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=16, shuffle=True)
344
+ valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
345
+ track_loader = DataLoader(track_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
346
+
347
+ with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
348
+
349
+ #mlflow.pytorch.autolog(silent=True)
350
+
351
+ if k_iter == 0:
352
+ display_mlflow_run_info(child_run)
353
+
354
+ mlflow.pytorch.log_state_dict(state_dict, artifact_path=None)
355
+
356
+ hparams = {
357
+ 'dataset': args.dataset,
358
+ 'processing_mode': args.processing_mode,
359
+ 'training_mode': training_mode,
360
+ }
361
+ if training_mode == 'adversarial':
362
+ hparams['adv_aux_weight'] = args.adv_aux_weight
363
+ hparams['adv_aux_loss'] = args.adv_aux_loss
364
+
365
+ mlflow.log_params(hparams)
366
+
367
+ with open('results/state_dict.txt', 'w') as f:
368
+ f.write('python ' + ' '.join(sys.argv) + '\n')
369
+ f.write('\n'.join([f'{k}={v}' for k, v in state_dict.items()]))
370
+ mlflow.log_artifact('results/state_dict.txt', artifact_path=None)
371
+
372
+ mlf_logger = pl.loggers.MLFlowLogger(experiment_name=args.experiment_name,
373
+ tracking_uri=args.tracking_uri,)
374
+ mlf_logger._run_id = child_run.info.run_id
375
+
376
+ callbacks = []
377
+ if args.track_processing:
378
+ callbacks += [TrackImagesCallback(track_loader,
379
+ track_every_epoch=args.track_every_epoch,
380
+ track_processing=args.track_processing,
381
+ track_gradients=args.track_processing_gradients,
382
+ track_predictions=args.track_predictions,
383
+ save_tensors=args.track_save_tensors)]
384
+
385
+ #if True: #args.save_best:
386
+ # if dataset.task == 'classification':
387
+ #checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
388
+ # checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
389
+ # else:
390
+ # checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
391
+ #callbacks += [checkpoint_callback]
392
+
393
+ trainer = pl.Trainer(
394
+ gpus=1 if DEVICE == 'cuda' else 0,
395
+ min_epochs=args.epochs,
396
+ max_epochs=args.epochs,
397
+ logger=mlf_logger,
398
+ callbacks=callbacks,
399
+ check_val_every_n_epoch=args.check_val_every_n_epoch,
400
+ #checkpoint_callback=True,
401
+ )
402
+
403
+ if args.log_model:
404
+ mlflow.pytorch.autolog(log_every_n_epoch=10)
405
+ print(f'model_uri="{mlflow.get_artifact_uri()}/model"')
406
+
407
+ t = trainer.fit(
408
+ model,
409
+ train_dataloader=train_loader,
410
+ val_dataloaders=valid_loader,
411
+ )
412
+
413
+ # if args.adv_training:
414
+ # for (name, p1), p2 in zip(processor.named_parameters(), processor_default.cpu().parameters()):
415
+ # print(f"param '{name}' diff: {p2 - p1}, l2: {(p2-p1).norm().item()}")
416
+ return model
417
+
418
+
419
+ if __name__ == '__main__':
420
+ model = run_train(args)
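A minimal adversarial-training invocation that satisfies the assertions in run_train above (the classifier URI is a placeholder; the flag values mirror the commented examples in train.sh below):

    python train.py \
        --experiment_name adversarial \
        --run_name adv_example \
        --dataset Microscopy \
        --classifier_uri "s3://.../artifacts/model" \
        --processing_mode parametrized \
        --adv_training \
        --freeze_classifier \
        --adv_aux_loss l2 \
        --adv_aux_weight 0.1 \
        --lr 1e-3 \
        --epochs 7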
train.sh ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # # Parametrized Training
4
+ # 100 epochs, frozen_processor: http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com/#/experiments/49/runs/2803f44514e34a0f87d591520706e876
5
+ # model_uri="s3://mlflow-artifacts-601883093460/49/2803f44514e34a0f87d591520706e876/artifacts/model"
6
+
7
+ # used for training current model to 100% train and 80% val accuracy
8
+ # python train.py \
9
+ # --experiment_name parametrized \
10
+ # --classifier_uri "${model_uri}" \
11
+ # --run_name par_full_kurt \
12
+ # --dataset Microscopy \
13
+ # --lr 1e-5 \
14
+ # --epochs 50 \
15
+ # --freeze_classifier \
16
+
17
+ # --freeze_processor \
18
+
19
+ # # Adversarial Training
20
+
21
+ # python train.py \
22
+ # --experiment_name adversarial \
23
+ # --run_name adv_frozen_processor \
24
+ # --classifier_uri "${model_uri}" \
25
+ # --dataset Microscopy \
26
+ # --adv_training \
27
+ # --lr 1e-3 \
28
+ # --epochs 7 \
29
+ # --freeze_classifier \
30
+ # --track_processing \
31
+ # --track_every_epoch \
32
+ # --log_model=False \
33
+ # --adv_aux_weight=0.1 \
34
+ # --adv_aux_loss "l2" \
35
+
36
+ # --adv_aux_weight=2e-5 \
37
+ # --adv_aux_weight=2e-5 \
38
+ # --adv_aux_weight=1.9e-5 \
39
+
40
+ # Cross pipeline training (Segmentation/Classification)
41
+
42
+ # Static Pipeline Script
43
+
44
+ # datasets="Microscopy Drone DroneSegmentation"
45
+ datasets="DroneSegmentation"
46
+ augmentations="weak strong none"
47
+
48
+ demosaicings="bilinear malvar2004 menon2007"
49
+ sharpenings="sharpening_filter unsharp_masking"
50
+ denoisings="median_denoising gaussian_denoising"
51
+
52
+ for augment in $augmentations
53
+ do
54
+ for data in $datasets
55
+ do
56
+ for demosaicing in $demosaicings
57
+ do
58
+ for sharpening in $sharpenings
59
+ do
60
+ for denoising in $denoisings
61
+ do
62
+
63
+ python train.py \
64
+ --experiment_name ABtesting \
65
+ --run_name "$data"_"$demosaicing"_"$sharpening"_"$denoising"_"$augment" \
66
+ --dataset "$data" \
67
+ --batch_size 4 \
68
+ --lr 1e-5 \
69
+ --epochs 100 \
70
+ --sp_debayer "$demosaicing" \
71
+ --sp_sharpening "$sharpening" \
72
+ --sp_denoising "$denoising" \
73
+ --processing_mode "static" \
74
+ --augmentation "$augment" \
75
+ --n_splits 5 \
76
+
77
+ done
78
+ done
79
+ done
80
+ done
81
+ done
utils/Cperturb.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Code extracted from the paper:
3
+
4
+ @article{hendrycks2019robustness,
5
+ title={Benchmarking Neural Network Robustness to Common Corruptions and Perturbations},
6
+ author={Dan Hendrycks and Thomas Dietterich},
7
+ journal={Proceedings of the International Conference on Learning Representations},
8
+ year={2019}
9
+ }
10
+
11
+ The code is modified to fit with our model
12
+ '''
13
+
14
+ import os
15
+ from PIL import Image
16
+ import os.path
17
+ import time
18
+ import torch
19
+ import torchvision.datasets as dset
20
+ import torchvision.transforms as trn
21
+ import torch.utils.data as data
22
+ import numpy as np
23
+
24
+ from PIL import Image
25
+
26
+
27
+ # /////////////// Distortion Helpers ///////////////
28
+
29
+ import skimage as sk
30
+ from skimage.filters import gaussian
31
+ from io import BytesIO
32
+ from wand.image import Image as WandImage
33
+ from wand.api import library as wandlibrary
34
+ import wand.color as WandColor
35
+ import ctypes
36
+ from PIL import Image as PILImage
37
+ import cv2
38
+ from scipy.ndimage import zoom as scizoom
39
+ from scipy.ndimage.interpolation import map_coordinates
40
+ import warnings
41
+
42
+ warnings.simplefilter("ignore", UserWarning)
43
+
44
+
45
+ def disk(radius, alias_blur=0.1, dtype=np.float32):
46
+ if radius <= 8:
47
+ L = np.arange(-8, 8 + 1)
48
+ ksize = (3, 3)
49
+ else:
50
+ L = np.arange(-radius, radius + 1)
51
+ ksize = (5, 5)
52
+ X, Y = np.meshgrid(L, L)
53
+ aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
54
+ aliased_disk /= np.sum(aliased_disk)
55
+
56
+ # supersample disk to antialias
57
+ return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
58
+
59
+
60
+ # Tell Python about the C method
61
+ wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p, # wand
62
+ ctypes.c_double, # radius
63
+ ctypes.c_double, # sigma
64
+ ctypes.c_double) # angle
65
+
66
+
67
+ # Extend wand.image.Image class to include method signature
68
+ class MotionImage(WandImage):
69
+ def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
70
+ wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
71
+
72
+
73
+ # modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
74
+ def plasma_fractal(mapsize=32, wibbledecay=3):
75
+ """
76
+ Generate a heightmap using diamond-square algorithm.
77
+ Return square 2d array, side length 'mapsize', of floats normalized to the range 0-1.
78
+ 'mapsize' must be a power of two.
79
+ """
80
+ assert (mapsize & (mapsize - 1) == 0)
81
+ maparray = np.empty((mapsize, mapsize), dtype=np.float_)
82
+ maparray[0, 0] = 0
83
+ stepsize = mapsize
84
+ wibble = 100
85
+
86
+ def wibbledmean(array):
87
+ return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)
88
+
89
+ def fillsquares():
90
+ """For each square of points stepsize apart,
91
+ calculate middle value as mean of points + wibble"""
92
+ cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
93
+ squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
94
+ squareaccum += np.roll(squareaccum, shift=-1, axis=1)
95
+ maparray[stepsize // 2:mapsize:stepsize,
96
+ stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
97
+
98
+ def filldiamonds():
99
+ """For each diamond of points stepsize apart,
100
+ calculate middle value as mean of points + wibble"""
101
+ mapsize = maparray.shape[0]
102
+ drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
103
+ ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
104
+ ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
105
+ lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
106
+ ltsum = ldrsum + lulsum
107
+ maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
108
+ tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
109
+ tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
110
+ ttsum = tdrsum + tulsum
111
+ maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
112
+
113
+ while stepsize >= 2:
114
+ fillsquares()
115
+ filldiamonds()
116
+ stepsize //= 2
117
+ wibble /= wibbledecay
118
+
119
+ maparray -= maparray.min()
120
+ return maparray / maparray.max()
121
+
122
+
123
+ def clipped_zoom(img, zoom_factor):
124
+ h = img.shape[0]
125
+ # ceil crop height(= crop width)
126
+ ch = int(np.ceil(h / zoom_factor))
127
+
128
+ top = (h - ch) // 2
129
+ img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
130
+ # trim off any extra pixels
131
+ trim_top = (img.shape[0] - h) // 2
132
+
133
+ return img[trim_top:trim_top + h, trim_top:trim_top + h]
134
+
135
+
136
+ # /////////////// End Distortion Helpers ///////////////
137
+
138
+
139
+ # /////////////// Distortions ///////////////
140
+
141
+ class Distortions:
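+ """Hendrycks-style corruptions applied to a CHW torch tensor with values in [0, 1].
+ severity selects one of five strengths; transform names one of the methods below.
+ Note: several transforms (e.g. glass_blur, motion_blur, fog, snow, pixelate, elastic_transform)
+ keep the 32x32 / PIL-image assumptions of the original corruption code and are excluded in the demo below."""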
142
+ def __init__(self, severity=1, transform='identity'):
143
+ self.severity = severity
144
+ self.transform = transform
145
+
146
+ def __call__(self, img):
147
+ assert torch.is_tensor(img), 'Input data need to be a torch.tensor'
148
+ assert len(img.shape) == 3, 'Input image should be RGB'
149
+ img = self.torch2np(img)
150
+ t = getattr(self, self.transform)
151
+ img = t(img, self.severity)
152
+ return self.np2torch(img).float()
153
+
154
+ def np2torch(self, x):
155
+ return torch.tensor(x).permute(2, 0, 1)
156
+
157
+ def torch2np(self, x):
158
+ return np.array(x.permute(1, 2, 0))
159
+
160
+ def identity(self,x, severity=1):
161
+ return x
162
+
163
+ def gaussian_noise(self, x, severity=1):
164
+ c = [0.04, 0.06, .08, .09, .10][severity - 1]
165
+ return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1)
166
+
167
+
168
+ def shot_noise(self, x, severity=1):
169
+ c = [500, 250, 100, 75, 50][severity - 1]
170
+ return np.clip(np.random.poisson(x * c) / c, 0, 1)
171
+
172
+
173
+ def impulse_noise(self, x, severity=1):
174
+ c = [.01, .02, .03, .05, .07][severity - 1]
175
+
176
+ x = sk.util.random_noise(x, mode='s&p', amount=c)
177
+ return np.clip(x, 0, 1)
178
+
179
+
180
+ def speckle_noise(self, x, severity=1):
181
+ c = [.06, .1, .12, .16, .2][severity - 1]
182
+ return np.clip(x + x * np.random.normal(size=x.shape, scale=c), 0, 1)
183
+
184
+
185
+ def gaussian_blur(self, x, severity=1):
186
+ c = [.4, .6, 0.7, .8, 1][severity - 1]
187
+
188
+ x = gaussian(x, sigma=c, multichannel=True)
189
+ return np.clip(x, 0, 1)
190
+
191
+
192
+ def glass_blur(self, x, severity=1):
193
+ # sigma, max_delta, iterations
194
+ c = [(0.05,1,1), (0.25,1,1), (0.4,1,1), (0.25,1,2), (0.4,1,2)][severity - 1]
195
+
196
+ x = gaussian(x, sigma=c[0], multichannel=True)
197
+
198
+ # locally shuffle pixels
199
+ for i in range(c[2]):
200
+ for h in range(32 - c[1], c[1], -1):
201
+ for w in range(32 - c[1], c[1], -1):
202
+ dx, dy = np.random.randint(-c[1], c[1], size=(2,))
203
+ h_prime, w_prime = h + dy, w + dx
204
+ # swap
205
+ x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime], x[h, w]
206
+
207
+ return np.clip(gaussian(x, sigma=c[0], multichannel=True), 0, 1)
208
+
209
+
210
+ def defocus_blur(self, x, severity=1):
211
+ c = [(0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (1, 0.2), (1.5, 0.1)][severity - 1]
212
+ kernel = disk(radius=c[0], alias_blur=c[1])
213
+
214
+ channels = []
215
+ for d in range(3):
216
+ channels.append(cv2.filter2D(x[:, :, d], -1, kernel))
217
+ channels = np.array(channels).transpose((1, 2, 0)) # 3x32x32 -> 32x32x3
218
+
219
+ return np.clip(channels, 0, 1)
220
+
221
+
222
+ def motion_blur(self, x, severity=1):
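+ # NOTE: unlike the numpy-based transforms above, this expects a PIL Image (it calls x.save), so it cannot be used directly through __call__ without converting first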
223
+ c = [(6,1), (6,1.5), (6,2), (8,2), (9,2.5)][severity - 1]
224
+
225
+ output = BytesIO()
226
+ x.save(output, format='PNG')
227
+ x = MotionImage(blob=output.getvalue())
228
+
229
+ x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
230
+
231
+ x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8),
232
+ cv2.IMREAD_UNCHANGED)
233
+
234
+ if x.shape != (32, 32):
235
+ return np.clip(x[..., [2, 1, 0]], 0, 1) # BGR to RGB
236
+ else: # greyscale to RGB
237
+ return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 1)
238
+
239
+
240
+ def zoom_blur(self, x, severity=1):
241
+ c = [np.arange(1, 1.06, 0.01), np.arange(1, 1.11, 0.01), np.arange(1, 1.16, 0.01),
242
+ np.arange(1, 1.21, 0.01), np.arange(1, 1.26, 0.01)][severity - 1]
243
+ out = np.zeros_like(x)
244
+ for zoom_factor in c:
245
+ out += clipped_zoom(x, zoom_factor)
246
+
247
+ x = (x + out) / (len(c) + 1)
248
+ return np.clip(x, 0, 1)
249
+
250
+
251
+ def fog(self, x, severity=1):
252
+ c = [(.2,3), (.5,3), (0.75,2.5), (1,2), (1.5,1.75)][severity - 1]
253
+ max_val = x.max()
254
+ x += c[0] * plasma_fractal(wibbledecay=c[1])[:32, :32][..., np.newaxis]
255
+ return np.clip(x * max_val / (max_val + c[0]), 0, 1)
256
+
257
+
258
+ def frost(self, x, severity=1):
259
+ c = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1]
260
+ idx = np.random.randint(5)
261
+ filename = ['./frost1.png', './frost2.png', './frost3.png', './frost4.jpg', './frost5.jpg', './frost6.jpg'][idx]
262
+ frost = cv2.imread(filename)
263
+ frost = cv2.resize(frost, (0, 0), fx=0.2, fy=0.2)
264
+ # randomly crop and convert to rgb
265
+ x_start, y_start = np.random.randint(0, frost.shape[0] - 32), np.random.randint(0, frost.shape[1] - 32)
266
+ frost = frost[x_start:x_start + 32, y_start:y_start + 32][..., [2, 1, 0]]
267
+
268
+ return np.clip(c[0] * np.array(x) + c[1] * frost, 0, 1)
269
+
270
+
271
+ def snow(self, x, severity=1):
272
+ c = [(0.1,0.2,1,0.6,8,3,0.95),
273
+ (0.1,0.2,1,0.5,10,4,0.9),
274
+ (0.15,0.3,1.75,0.55,10,4,0.9),
275
+ (0.25,0.3,2.25,0.6,12,6,0.85),
276
+ (0.3,0.3,1.25,0.65,14,12,0.8)][severity - 1]
277
+
278
+ snow_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
279
+
280
+ snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
281
+ snow_layer[snow_layer < c[3]] = 0
282
+
283
+ snow_layer = PILImage.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
284
+ output = BytesIO()
285
+ snow_layer.save(output, format='PNG')
286
+ snow_layer = MotionImage(blob=output.getvalue())
287
+
288
+ snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))
289
+
290
+ snow_layer = cv2.imdecode(np.fromstring(snow_layer.make_blob(), np.uint8),
291
+ cv2.IMREAD_UNCHANGED) / (2**16-1)
292
+ snow_layer = snow_layer[..., np.newaxis]
293
+
294
+ x = c[6] * x + (1 - c[6]) * np.maximum(x, cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(32, 32, 1) * 1.5 + 0.5)
295
+ return np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1)
296
+
297
+
298
+ def spatter(self, x, severity=1):
299
+ c = [(0.62,0.1,0.7,0.7,0.5,0),
300
+ (0.65,0.1,0.8,0.7,0.5,0),
301
+ (0.65,0.3,1,0.69,0.5,0),
302
+ (0.65,0.1,0.7,0.69,0.6,1),
303
+ (0.65,0.1,0.5,0.68,0.6,1)][severity - 1]
304
+
305
+ liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
306
+
307
+ liquid_layer = gaussian(liquid_layer, sigma=c[2])
308
+ liquid_layer[liquid_layer < c[3]] = 0
309
+ if c[5] == 0:
310
+ liquid_layer = (liquid_layer * (2**16-1)).astype(np.uint8)
311
+ dist = (2**16-1) - cv2.Canny(liquid_layer, 50, 150)
312
+ dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
313
+ _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
314
+ dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
315
+ dist = cv2.equalizeHist(dist)
316
+ # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
317
+ # ker -= np.mean(ker)
318
+ ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
319
+ dist = cv2.filter2D(dist, cv2.CV_8U, ker)
320
+ dist = cv2.blur(dist, (3, 3)).astype(np.float32)
321
+
322
+ m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
323
+ m /= np.max(m, axis=(0, 1))
324
+ m *= c[4]
325
+
326
+ # water is pale turquoise
327
+ color = np.concatenate((175 / 255. * np.ones_like(m[..., :1]),
328
+ 238 / 255. * np.ones_like(m[..., :1]),
329
+ 238 / 255. * np.ones_like(m[..., :1])), axis=2)
330
+
331
+ color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
332
+ x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
333
+
334
+ return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * (2**16-1)
335
+ else:
336
+ m = np.where(liquid_layer > c[3], 1, 0)
337
+ m = gaussian(m.astype(np.float32), sigma=c[4])
338
+ m[m < 0.8] = 0
339
+ # m = np.abs(m) ** (1/c[4])
340
+
341
+ # mud brown
342
+ color = np.concatenate((63 / 255. * np.ones_like(x[..., :1]),
343
+ 42 / 255. * np.ones_like(x[..., :1]),
344
+ 20 / 255. * np.ones_like(x[..., :1])), axis=2)
345
+
346
+ color *= m[..., np.newaxis]
347
+ x *= (1 - m[..., np.newaxis])
348
+
349
+ return np.clip(x + color, 0, 1)
350
+
351
+
352
+ def contrast(self, x, severity=1):
353
+ c = [.75, .5, .4, .3, 0.15][severity - 1]
354
+ means = np.mean(x, axis=(0, 1), keepdims=True)
355
+ return np.clip((x - means) * c + means, 0, 1)
356
+
357
+
358
+ def brightness(self, x, severity=1):
359
+ c = [.05, .1, .15, .2, .3][severity - 1]
360
+
361
+ x = sk.color.rgb2hsv(x)
362
+ x[:, :, 2] = np.clip(x[:, :, 2] + c, 0, 1)
363
+ x = sk.color.hsv2rgb(x)
364
+
365
+ return np.clip(x, 0, 1)
366
+
367
+
368
+ def saturate(self, x, severity=1):
369
+ c = [(0.3, 0), (0.1, 0), (1.5, 0), (2, 0.1), (2.5, 0.2)][severity - 1]
370
+
371
+ x = sk.color.rgb2hsv(x)
372
+ x[:, :, 1] = np.clip(x[:, :, 1] * c[0] + c[1], 0, 1)
373
+ x = sk.color.hsv2rgb(x)
374
+
375
+ return np.clip(x, 0, 1)
376
+
377
+
378
+ def jpeg_compression(self, x, severity=1):
379
+ c = [80, 65, 58, 50, 40][severity - 1]
380
+
381
+ output = BytesIO()
382
+ x.save(output, 'JPEG', quality=c)
383
+ x = PILImage.open(output)
384
+
385
+ return x
386
+
387
+
388
+ def pixelate(self, x, severity=1):
389
+ c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
390
+
391
+ x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
392
+ x = x.resize((32, 32), PILImage.BOX)
393
+
394
+ return x
395
+
396
+
397
+ # mod of https://gist.github.com/erniejunior/601cdf56d2b424757de5
398
+ def elastic_transform(self, image, severity=1):
399
+ IMSIZE = 32
400
+ c = [(IMSIZE*0, IMSIZE*0, IMSIZE*0.08),
401
+ (IMSIZE*0.05, IMSIZE*0.2, IMSIZE*0.07),
402
+ (IMSIZE*0.08, IMSIZE*0.06, IMSIZE*0.06),
403
+ (IMSIZE*0.1, IMSIZE*0.04, IMSIZE*0.05),
404
+ (IMSIZE*0.1, IMSIZE*0.03, IMSIZE*0.03)][severity - 1]
405
+
406
+ shape = image.shape
407
+ shape_size = shape[:2]
408
+
409
+ # random affine
410
+ center_square = np.float32(shape_size) // 2
411
+ square_size = min(shape_size) // 3
412
+ pts1 = np.float32([center_square + square_size,
413
+ [center_square[0] + square_size, center_square[1] - square_size],
414
+ center_square - square_size])
415
+ pts2 = pts1 + np.random.uniform(-c[2], c[2], size=pts1.shape).astype(np.float32)
416
+ M = cv2.getAffineTransform(pts1, pts2)
417
+ image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)
418
+
419
+ dx = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
420
+ c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
421
+ dy = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
422
+ c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
423
+ dx, dy = dx[..., np.newaxis], dy[..., np.newaxis]
424
+
425
+ x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
426
+ indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
427
+ return np.clip(map_coordinates(image, indices, order=1, mode='reflect').reshape(shape), 0, 1)
428
+
429
+ if __name__ == '__main__':
430
+ import os
431
+
432
+ import numpy as np
433
+ import matplotlib.pyplot as plt
434
+ import tifffile as tiff
435
+ import torch
436
+
437
+ if not os.path.exists('README.md'): # set pwd to repo root; os.system('cd ..') would not change this process's directory
+ os.chdir('..')
438
+
439
+ img = tiff.imread('/home/marco/perturbed-minds/perturbed-minds/data/microscopy/images/rgb_scale100/Ma190c_lame1_zone1_composite_Mcropped_1.tiff')
440
+ img = np.array(img)/(2**16-1)
441
+ img = torch.tensor(img).permute(2,0,1)
442
+
443
+ def identity(x, sev):
444
+ return x
445
+
446
+ if not os.path.exists('results/Cimages'):
447
+ os.makedirs('results/Cimages')
448
+
449
+ transformations = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
450
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
451
+
452
+ # glass_blur, defocus_blur, motion_blur, fog, frost, snow, spatter, jpeg_compression, pixelate,
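+ # (the transforms above are left out here: most of them assume PIL inputs, fixed 32x32 sizes, or the external frost*.png assets)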
453
+
454
+ plt.figure()
455
+ plt.imshow(img.permute(1,2,0))
456
+ plt.title('identity')
457
+ plt.savefig('results/Cimages/1_identity.png')  # save before show, otherwise the figure may already be closed
458
+ plt.show()
459
+
460
+
461
+ for i,t in enumerate(transformations):
462
+
463
+ fig = plt.figure(figsize=(25,5))
464
+ columns = 5
465
+ rows = 1
466
+
467
+ for sev in range(1,6):
468
+ dist = Distortions(severity=sev, transform=t)
469
+ fig.add_subplot(rows, columns, sev)
470
+ plt.imshow(dist(img).permute(1,2,0))
471
+ plt.title(f'{t} {sev}')
472
+ plt.xticks([], [])
473
+ plt.yticks([], [])
474
+ plt.savefig(f'results/Cimages/{i+2}_{t}.png')  # save before show
475
+ plt.show()
utils/augmentation.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torchvision.transforms as T
6
+
7
+
8
+ class RandomRotate90(): # Note: not the same as T.RandomRotation(90)
9
+ def __call__(self, x):
10
+ x = x.rot90(random.randint(0, 3), dims=(-1, -2))
11
+ return x
12
+
13
+ def __repr__(self):
14
+ return self.__class__.__name__
15
+
16
+
17
+ class AddGaussianNoise():
18
+ def __init__(self, std=0.01):
19
+ self.std = std
20
+
21
+ def __call__(self, x):
22
+ # noise = torch.randn_like(x) * self.std
23
+ # out = x + noise
24
+ # debug(x)
25
+ # debug(noise)
26
+ # debug(out)
27
+ return x + torch.randn_like(x) * self.std
28
+
29
+ def __repr__(self):
30
+ return self.__class__.__name__ + f'(std={self.std})'
31
+
32
+
33
+ def set_global_seed(seed):
34
+ torch.random.manual_seed(seed)
35
+ np.random.seed(seed % (2**32 - 1))
36
+ random.seed(seed)
37
+
38
+
39
+ class ComposeState(T.Compose):
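+ """Compose variant that can replay its random state, so an image and its mask receive identical random transforms.
+ Entries may be (transform, apply_for_mask) tuples; transforms flagged False are skipped when mask_transform=True.
+ Typical usage: call with retain_state=True on the image, then with mask_transform=True on the mask."""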
40
+ def __init__(self, transforms):
41
+ self.transforms = []
42
+ self.mask_transforms = []
43
+
44
+ for t in transforms:
45
+ apply_for_mask = True
46
+ if isinstance(t, tuple):
47
+ t, apply_for_mask = t
48
+ self.transforms.append(t)
49
+ if apply_for_mask:
50
+ self.mask_transforms.append(t)
51
+
52
+ self.seed = None
53
+
54
+ # @debug
55
+ def __call__(self, x, retain_state=False, mask_transform=False):
56
+ if self.seed is not None: # retain previous state
57
+ set_global_seed(self.seed)
58
+ if retain_state: # save state for next call
59
+ self.seed = self.seed or torch.seed()
60
+ set_global_seed(self.seed)
61
+ else:
62
+ self.seed = None # reset / ignore state
63
+
64
+ transforms = self.transforms if not mask_transform else self.mask_transforms
65
+ for t in transforms:
66
+ x = t(x)
67
+ return x
68
+
69
+
70
+ augmentation_weak = ComposeState([
71
+ T.RandomHorizontalFlip(),
72
+ T.RandomVerticalFlip(),
73
+ RandomRotate90(),
74
+ ])
75
+
76
+
77
+ augmentation_strong = ComposeState([
78
+ T.RandomHorizontalFlip(p=0.5),
79
+ T.RandomVerticalFlip(p=0.5),
80
+ T.RandomApply([T.RandomRotation(90)], p=0.5),
81
+ # (transform, apply_for_mask): the entries flagged False below are not applied to masks
82
+ (T.RandomApply([AddGaussianNoise(std=0.0005)], p=0.5), False),
83
+ (T.RandomAdjustSharpness(0.5, p=0.5), False),
84
+ ])
85
+
86
+
87
+ def get_augmentation(type):
88
+ if type == 'none':
89
+ return None
90
+ if type == 'weak':
91
+ return augmentation_weak
92
+ if type == 'strong':
93
+ return augmentation_strong
94
+
95
+
96
+ if __name__ == '__main__':
97
+ import os
98
+ if not os.path.exists('README.md'):
99
+ os.chdir('..')
100
+
101
+ # from utils.debug import debug
102
+ from utils.dataset import get_dataset
103
+ import matplotlib.pyplot as plt
104
+
105
+ dataset = get_dataset('DS') # drone segmentation
106
+ img, mask = dataset[10]
107
+ mask = (mask + 0.2) / 1.2
108
+
109
+ plt.figure(figsize=(14, 8))
110
+ plt.subplot(121)
111
+ plt.imshow(img)
112
+ plt.subplot(122)
113
+ plt.imshow(mask)
114
+ plt.suptitle('no augmentation')
115
+ plt.show()
116
+
117
+ from utils.base import np2torch, torch2np
118
+ img, mask = np2torch(img), np2torch(mask)
119
+
120
+ # from utils.augmentation import get_augmentation
121
+ augmentation = get_augmentation('strong')
122
+
123
+ set_global_seed(1)
124
+
125
+ for i in range(1, 4):
126
+ plt.figure(figsize=(14, 8))
127
+ plt.subplot(121)
128
+ plt.imshow(torch2np(augmentation(img.unsqueeze(0), retain_state=True)).squeeze())
129
+ plt.subplot(122)
130
+ plt.imshow(torch2np(augmentation(mask.unsqueeze(0), mask_transform=True)).squeeze())
131
+ plt.suptitle(f'augmentation test {i}')
132
+ plt.show()
utils/base.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for other scripts
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+
8
+ import random
9
+
10
+ import torch
+ import torch.nn as nn  # used by smp_get_loss below
11
+ import mlflow
12
+ from mlflow.tracking import MlflowClient
13
+ import numpy as np
+ import pandas as pd  # used to read the cached runs table in get_mlflow_model_by_name
+ import segmentation_models_pytorch as smp  # used by smp_get_loss below
14
+
15
+ from IPython.display import display, Markdown
16
+
17
+ from b2sdk.v1 import *
18
+
19
+ import argparse
20
+
21
+
22
+ class SmartFormatter(argparse.HelpFormatter):
23
+
24
+ def _split_lines(self, text, width):
25
+ if text.startswith('R|'):
26
+ return text[2:].splitlines()
27
+ # this is the RawTextHelpFormatter._split_lines
28
+ return argparse.HelpFormatter._split_lines(self, text, width)
29
+
30
+
31
+ def str2bool(string):
32
+ return string == 'True'
33
+
34
+
35
+ def np2torch(nparray):
36
+ """Convert numpy array to torch tensor
37
+ For array with more than 3 channels, it is better to use an input array in the format BxHxWxC
38
+
39
+ Args:
40
+ numpy array (ndarray) BxHxWxC
41
+ Returns:
42
+ torch tensor (tensor) BxCxHxW"""
43
+
44
+ tensor = torch.Tensor(nparray)
45
+
46
+ if tensor.ndim == 2:
47
+ return tensor
48
+ if tensor.ndim == 3:
49
+ height, width, channels = tensor.shape
50
+ if channels <= 3: # Single image with more channels (HxWxC)
51
+ return tensor.permute(2, 0, 1)
52
+
53
+ if tensor.ndim == 4: # More images with more channels (BxHxWxC)
54
+ return tensor.permute(0, 3, 1, 2)
55
+
56
+ return tensor
57
+
58
+
59
+ def torch2np(torchtensor):
60
+ """Convert torch tensor to numpy array
61
+ For tensor with more than 3 channels or batch, it is better to use an input tensor in the format BxCxHxW
62
+
63
+ Args:
64
+ torch tensor (tensor) BxCxHxW
65
+ Returns:
66
+ numpy array (ndarray) BxHxWxC"""
67
+
68
+ ndarray = torchtensor.detach().cpu().numpy().astype(np.float32)
69
+
70
+ if ndarray.ndim == 3: # Single image with more channels (CxHxW)
71
+ channels, height, width = ndarray.shape
72
+ if channels <= 3:
73
+ return ndarray.transpose(1, 2, 0)
74
+
75
+ if ndarray.ndim == 4: # More images with more channels (BxCxHxW)
76
+ return ndarray.transpose(0, 2, 3, 1)
77
+
78
+ return ndarray
79
+
80
+
81
+ def set_random_seed(seed):
82
+ np.random.seed(seed) # cpu vars
83
+ torch.manual_seed(seed) # cpu vars
84
+ random.seed(seed) # Python
85
+ if torch.cuda.is_available():
86
+ torch.cuda.manual_seed(seed)
87
+ torch.cuda.manual_seed_all(seed) # gpu vars
88
+ torch.backends.cudnn.deterministic = True # needed
89
+ torch.backends.cudnn.benchmark = False
90
+
91
+
92
+ def normalize(img):
93
+ """Normalize images
94
+
95
+ Args:
96
+ imgs (ndarray): image to normalize --> size: (Height,Width,Channels)
97
+ Returns:
98
+ normalized (ndarray): normalized image
99
+ mu (ndarray): mean
100
+ sigma (ndarray): standard deviation
101
+ """
102
+
103
+ img = img.astype(float)
104
+
105
+ if len(img.shape) == 2:
106
+ img = img[:, :, np.newaxis]
107
+
108
+ height, width, channels = img.shape
109
+
110
+ mu, sigma = np.empty(channels), np.empty(channels)
111
+
112
+ for ch in range(channels):
113
+ temp_mu = img[:, :, ch].mean()
114
+ temp_sigma = img[:, :, ch].std()
115
+
116
+ img[:, :, ch] = (img[:, :, ch] - temp_mu) / (temp_sigma + 1e-4)
117
+
118
+ mu[ch] = temp_mu
119
+ sigma[ch] = temp_sigma
120
+
121
+ return img, mu, sigma
122
+
123
+
124
+ def b2_list_files(folder=''):
125
+ bucket = get_b2_bucket()
126
+ for file_info, _ in bucket.ls(folder, show_versions=False):
127
+ print(file_info.file_name)
128
+
129
+
130
+ def get_b2_bucket():
131
+ bucket_name = 'perturbed-minds'
132
+ application_key_id = '003d6b042de536a0000000004'
133
+ application_key = 'K003E5Cr+BAYlvSHfg2ynLtvS5aNq78'
134
+ info = InMemoryAccountInfo()
135
+ b2_api = B2Api(info)
136
+ b2_api.authorize_account('production', application_key_id, application_key)
137
+ bucket = b2_api.get_bucket_by_name(bucket_name)
138
+ return bucket
139
+
140
+
141
+ def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=True):
142
+ """Downloads a folder from the b2 bucket and optionally cleans
143
+ up files that are no longer on the server
144
+
145
+ Args:
146
+ b2_dir (str): path to folder on the b2 server
147
+ local_dir (str): path to folder on the local machine
148
+ force_download (bool, optional): force the download, if set to `False`,
149
+ files with matching names on the local machine will be skipped
150
+ mirror_folder (bool, optional): if set to `True`, files that are found in
151
+ the local directory, but are not on the server will be deleted
152
+ """
153
+ bucket = get_b2_bucket()
154
+
155
+ if not os.path.exists(local_dir):
156
+ os.makedirs(local_dir)
157
+ elif not force_download:
158
+ return
159
+
160
+ download_files = [file_info.file_name.split(b2_dir + '/')[-1]
161
+ for file_info, _ in bucket.ls(b2_dir, show_versions=False)]
162
+
163
+ for file_name in download_files:
164
+ if file_name.endswith('/.bzEmpty'): # subdirectory, download recursively
165
+ subdir = file_name.replace('/.bzEmpty', '')
166
+ if len(subdir) > 0:
167
+ b2_subdir = os.path.join(b2_dir, subdir)
168
+ local_subdir = os.path.join(local_dir, subdir)
169
+ if b2_subdir != b2_dir:
170
+ b2_download_folder(b2_subdir, local_subdir, force_download=force_download,
171
+ mirror_folder=mirror_folder)
172
+ else: # file
173
+ b2_file = os.path.join(b2_dir, file_name)
174
+ local_file = os.path.join(local_dir, file_name)
175
+ if not os.path.exists(local_file) or force_download:
176
+ print(f"downloading b2://{b2_file} -> {local_file}")
177
+ bucket.download_file_by_name(b2_file, DownloadDestLocalFile(local_file))
178
+
179
+ if mirror_folder: # remove all files that are not on the b2 server anymore
180
+ for i, file in enumerate(download_files):
181
+ if file.endswith('/.bzEmpty'): # subdirectory, download recursively
182
+ download_files[i] = file.replace('/.bzEmpty', '')
183
+ for file_name in os.listdir(local_dir):
184
+ if file_name not in download_files:
185
+ local_file = os.path.join(local_dir, file_name)
186
+ print(f"deleting {local_file}")
187
+ if os.path.isdir(local_file):
188
+ shutil.rmtree(local_file)
189
+ else:
190
+ os.remove(local_file)
191
+
192
+
193
+ def get_name(obj):
194
+ return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
195
+
196
+
197
+ def get_mlflow_model_by_name(experiment_name, run_name,
198
+ tracking_uri = "http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
199
+ download_model = True):
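+ """Look up an MLflow run by experiment name and run name and return its logged
+ state_dict and model. Run lists and downloaded models are cached under cache/."""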
200
+
201
+ # 0. mlflow basics
202
+ mlflow.set_tracking_uri(tracking_uri)
203
+ os.environ["AWS_ACCESS_KEY_ID"] = #TODO: add your AWS access key if you want to write your results to our collaborative lab server
204
+ os.environ["AWS_SECRET_ACCESS_KEY"] = #TODO: add your AWS seceret access key if you want to write your results to our collaborative lab server
205
+
206
+ # 1. use get_experiment_by_name to get the experiment object
207
+ experiment = mlflow.get_experiment_by_name(experiment_name)
208
+
209
+ # # 2. use search_runs with experiment_id for string search query
210
+ if os.path.isfile('cache/runs_names.pkl'):
211
+ runs = pd.read_pickle('cache/runs_names.pkl')
212
+ if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
213
+ runs = fetch_runs_list_mlflow(experiment) #returns a pandas data frame where each row is a run (if several exist under that name)
214
+ else:
215
+ runs = fetch_runs_list_mlflow(experiment) #returns a pandas data frame where each row is a run (if several exist under that name)
216
+
217
+ # 3. get the selected run between all runs inside the selected experiment
218
+ run = runs.loc[runs['tags.mlflow.runName'] == run_name]
219
+
220
+ # 4. check if there is only a run with that name
221
+ assert len(run) == 1, "Expected exactly one run with this name"
222
+ index_run = run.index[0]
223
+ artifact_uri = run.loc[index_run, 'artifact_uri']
224
+
225
+ # 5. load state_dict of your run
226
+ state_dict = mlflow.pytorch.load_state_dict(artifact_uri)
227
+
228
+ # 6. load model of your run
229
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
230
+ # model = mlflow.pytorch.load_model(os.path.join(
231
+ # artifact_uri, "model"), map_location=torch.device(DEVICE))
232
+ model = fetch_from_mlflow(os.path.join(
233
+ artifact_uri, "model"), use_cache=True, download_model=download_model)
234
+
235
+ return state_dict, model
236
+
237
+ def data_loader_mean_and_std(data_loader, transform=None):
238
+ means = []
239
+ stds = []
240
+ for x, y in data_loader:
241
+ if transform is not None:
242
+ x = transform(x)
243
+ means.append(x.mean(dim=(0, 2, 3)).unsqueeze(0))
244
+ stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
245
+ return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
246
+
247
+ def fetch_runs_list_mlflow(experiment):
248
+ runs = mlflow.search_runs(experiment.experiment_id)
249
+ os.makedirs('cache', exist_ok=True)
+ runs.to_pickle('cache/runs_names.pkl')  # cache the runs table locally as a .pkl
250
+ return runs
251
+
252
+ def fetch_from_mlflow(uri, use_cache=True, download_model=True):
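+ # load a model from mlflow, keeping a local copy under cache/ so repeated calls avoid re-downloading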
253
+ cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
254
+ if use_cache and os.path.exists(cache_loc):
255
+ print(f'loading cached model from {cache_loc} ...')
256
+ return torch.load(cache_loc)
257
+ else:
258
+ print(f'fetching model from {uri} ...')
259
+ model = mlflow.pytorch.load_model(uri)
260
+ os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
261
+ if download_model:
262
+ torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)
263
+ return model
264
+
265
+
266
+ def display_mlflow_run_info(run):
267
+ uri = mlflow.get_tracking_uri()
268
+ experiment_id = run.info.experiment_id
269
+ experiment_name = mlflow.get_experiment(experiment_id).name
270
+ run_id = run.info.run_id
271
+ run_name = run.data.tags['mlflow.runName']
272
+ experiment_url = f'{uri}/#/experiments/{experiment_id}'
273
+ run_url = f'{experiment_url}/runs/{run_id}'
274
+
275
+ print(f'view results at {run_url}')
276
+ display(Markdown(
277
+ f"[<a href='{experiment_url}'>experiment {experiment_id} '{experiment_name}'</a>]"
278
+ f" > "
279
+ f"[<a href='{run_url}'>run '{run_name}' {run_id}</a>]"
280
+ ))
281
+ print('')
282
+
283
+
284
+ def get_train_test_indices_drone(df, frac, seed=None):
285
+ """ Split indices of a DataFrame with binary and balanced labels into balanced subindices
286
+
287
+ Args:
288
+ df (pd.DataFrame): {0,1}-labeled data
289
+ frac (float): fraction of indicies in first subset
290
+ seed (int): random seed used for np.random and as random_state for the DataFrame sampling
291
+ Returns:
292
+ train_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
293
+ test_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
294
+ """
295
+
296
+ split_idx = int(len(df) * frac / 2)
297
+ df_with = df[df['label'] == 1]
298
+ df_without = df[df['label'] == 0]
299
+
300
+ np.random.seed(seed)
301
+ df_with_train = df_with.sample(n=split_idx, random_state=seed)
302
+ df_with_test = df_with.drop(df_with_train.index)
303
+
304
+ df_without_train = df_without.sample(n=split_idx, random_state=seed)
305
+ df_without_test = df_without.drop(df_without_train.index)
306
+
307
+ train_indices = list(df_without_train.index) + list(df_with_train.index)
308
+ test_indices = list(df_without_test.index) + list(df_with_test.index)
309
+
310
+ """"
311
+ print('fraction of 1-label in train set: {}'.format(len(df_with_train)/(len(df_with_train) + len(df_without_train))))
312
+ print('fraction of 1-label in test set: {}'.format(len(df_with_test)/(len(df_with_test) + len(df_with_test))))
313
+ """
314
+
315
+ return train_indices, test_indices
316
+
317
+
318
+ def smp_get_loss(loss, dice_weight=1.0, bce_weight=1.0):
319
+ if loss == "Dice":
320
+ return smp.losses.DiceLoss(mode='binary', from_logits=True)
321
+ if loss == "BCE":
322
+ return nn.BCELoss()
323
+ elif loss == "BCEWithLogits":
324
+ return nn.BCEWithLogitsLoss()
325
+ elif loss == "DicyBCE":
326
+ from pytorch_toolbelt import losses as ptbl
327
+ return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
328
+ nn.BCELoss(),
329
+ first_weight=dice_weight,
330
+ second_weight=bce_weight)
utils/dataset.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import rawpy
4
+ import random
5
+ from PIL import Image
6
+ import tifffile as tiff
7
+ import zipfile
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
13
+ from sklearn.model_selection import StratifiedShuffleSplit
14
+
15
+ if not os.path.exists('README.md'): # set pwd to root
16
+ os.chdir('..')
17
+
18
+ from utils.splitting import split_img
19
+ from utils.base import np2torch, torch2np, b2_download_folder
20
+
21
+ IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
22
+
23
+
24
+ def get_dataset(name, I_ratio=1.0):
25
+ # DroneDataset
26
+ if name in ('DC', 'Drone', 'DroneClassification', 'DroneDatasetClassificationTiled'):
27
+ return DroneDatasetClassificationTiled(I_ratio=I_ratio)
28
+ if name in ('DS', 'DroneSegmentation', 'DroneDatasetSegmentationTiled'):
29
+ return DroneDatasetSegmentationTiled(I_ratio=I_ratio)
30
+
31
+ # MicroscopyDataset
32
+ if name in ('M', 'Microscopy', 'MicroscopyDataset'):
33
+ return MicroscopyDataset(I_ratio=I_ratio)
34
+
35
+ # for testing
36
+ if name in ('DSF', 'DroneDatasetSegmentationFull'):
37
+ return DroneDatasetSegmentationFull(I_ratio=I_ratio)
38
+ if name in ('MRGB', 'MicroscopyRGB', 'MicroscopyDatasetRGB'):
39
+ return MicroscopyDatasetRGB(I_ratio=I_ratio)
40
+
41
+ raise ValueError(name)
42
+
43
+
44
+ def load_image(path):
45
+ file_type = path.split('.')[-1].lower()
46
+ if file_type == 'dng':
47
+ img = rawpy.imread(path).raw_image_visible
48
+ elif file_type == 'tiff' or file_type == 'tif':
49
+ img = np.array(tiff.imread(path), dtype=np.float32)
50
+ else:
51
+ img = np.array(Image.open(path), dtype=np.float32)
52
+ return img
53
+
54
+
55
+ def list_images_in_dir(path):
56
+ image_list = [os.path.join(path, img_name)
57
+ for img_name in sorted(os.listdir(path))
58
+ if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
59
+ return image_list
60
+
61
+
62
+ class ImageFolderDataset(Dataset):
63
+ """Creates a dataset of images in img_dir and corresponding masks in mask_dir.
64
+ Corresponding mask files need to contain the filename of the image.
65
+ Files are expected to be of the same filetype.
66
+
67
+ Args:
68
+ img_dir (str): path to image folder
69
+ labels (list): class label for each image
70
+ transform (callable, optional): transformation to apply to image and mask
71
+ bits (int, optional): normalize image by dividing by 2^bits - 1
72
+ """
73
+
74
+ task = 'classification'
75
+
76
+ def __init__(self, img_dir, labels, transform=None, bits=1):
77
+
78
+ self.img_dir = img_dir
79
+ self.labels = labels
80
+
81
+ self.images = list_images_in_dir(img_dir)
82
+
83
+ assert len(self.images) == len(self.labels)
84
+
85
+ self.transform = transform
86
+ self.bits = bits
87
+
88
+ def __repr__(self):
89
+ rep = f"{type(self).__name__}: ImageFolderDataset[{len(self.images)}]"
90
+ for n, (img, label) in enumerate(zip(self.images, self.labels)):
91
+ rep += f'\nimage: {img}\tlabel: {label}'
92
+ if n > 10:
93
+ rep += '\n...'
94
+ break
95
+ return rep
96
+
97
+ def __len__(self):
98
+ return len(self.images)
99
+
100
+ def __getitem__(self, idx):
101
+
102
+ label = self.labels[idx]
103
+
104
+ img = load_image(self.images[idx])
105
+ img = img / (2**self.bits - 1)
106
+ if self.transform is not None:
107
+ img = self.transform(img)
108
+
109
+ if len(img.shape) == 2:
110
+ assert img.shape == (256, 256), f"Invalid size for {self.images[idx]}"
111
+ else:
112
+ assert img.shape == (3, 256, 256), f"Invalid size for {self.images[idx]}"
113
+
114
+ return img, label
115
+
116
+
117
+ class ImageFolderDatasetSegmentation(Dataset):
118
+ """Creates a dataset of images in `img_dir` and corresponding masks in `mask_dir`.
119
+ Corresponding mask files need to contain the filename of the image.
120
+ Files are expected to be of the same filetype.
121
+
122
+ Args:
123
+ img_dir (str): path to image folder
124
+ mask_dir (str): path to mask folder
125
+ transform (callable, optional): transformation to apply to image and mask
126
+ bits (int, optional): normalize image by dividing by 2^bits - 1
127
+ """
128
+
129
+ task = 'segmentation'
130
+
131
+ def __init__(self, img_dir, mask_dir, transform=None, bits=1):
132
+
133
+ self.img_dir = img_dir
134
+ self.mask_dir = mask_dir
135
+
136
+ self.images = list_images_in_dir(img_dir)
137
+ self.masks = list_images_in_dir(mask_dir)
138
+
139
+ check_image_folder_consistency(self.images, self.masks)
140
+
141
+ self.transform = transform
142
+ self.bits = bits
143
+
144
+ def __repr__(self):
145
+ rep = f"{type(self).__name__}: ImageFolderDatasetSegmentation[{len(self.images)}]"
146
+ for n, (img, mask) in enumerate(zip(self.images, self.masks)):
147
+ rep += f'\nimage: {img}\tmask: {mask}'
148
+ if n > 10:
149
+ rep += '\n...'
150
+ break
151
+ return rep
152
+
153
+ def __len__(self):
154
+ return len(self.images)
155
+
156
+ def __getitem__(self, idx):
157
+
158
+ img = load_image(self.images[idx])
159
+ mask = load_image(self.masks[idx])
160
+
161
+ img = img / (2**self.bits - 1)
162
+ mask = (mask > 0).astype(np.float32)
163
+
164
+ if self.transform is not None:
165
+ img = self.transform(img)
166
+
167
+ return img, mask
168
+
169
+ class MultiIntensity(Dataset):
170
+ """Wrap datasets with different intesities
171
+
172
+ Args:
173
+ datasets (list): list of datasets to wrap
174
+ """
175
+ def __init__(self, datasets):
176
+ self.dataset = datasets[0]
+ self.transform = None  # __getitem__ expects this attribute; assign a transform after construction if needed
177
+
178
+ for d in range(1,len(datasets)):
179
+ self.dataset.images = self.dataset.images+datasets[d].images
180
+ self.dataset.labels = self.dataset.labels+datasets[d].labels
181
+
182
+ def __len__(self):
183
+ return len(self.dataset)
184
+
185
+ def __repr__(self):
186
+ return f"Subset [{len(self.dataset)}] of " + repr(self.dataset)
187
+
188
+ def __getitem__(self, idx):
189
+ x, y = self.dataset[idx]
190
+ if self.transform is not None:
191
+ x = self.transform(x)
192
+ return x, y
193
+
194
+ class Subset(Dataset):
195
+ """Define a subset of a dataset by only selecting given indices.
196
+
197
+ Args:
198
+ dataset (Dataset): full dataset
199
+ indices (list): subset indices
200
+ """
201
+
202
+ def __init__(self, dataset, indices=None, transform=None):
203
+ self.dataset = dataset
204
+ self.indices = indices if indices is not None else range(len(dataset))
205
+ self.transform = transform
206
+
207
+ def __len__(self):
208
+ return len(self.indices)
209
+
210
+ def __repr__(self):
211
+ return f"Subset [{len(self)}] of " + repr(self.dataset)
212
+
213
+ def __getitem__(self, idx):
214
+ x, y = self.dataset[self.indices[idx]]
215
+ if self.transform is not None:
216
+ x = self.transform(x)
217
+ return x, y
218
+
219
+
220
+ class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
221
+ """Dataset consisting of full-sized numpy images and masks. Images are normalized to range [0, 1].
222
+ """
223
+
224
+ black_level = [0.0625, 0.0626, 0.0625, 0.0626]
225
+ white_balance = [2.86653646, 1., 1.73079425]
226
+ colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
227
+ 1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
228
+ camera_parameters = black_level, white_balance, colour_matrix
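+ # the three values above are the fixed camera calibration (black level, white balance, colour correction matrix), presumably consumed by the raw processing pipeline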
229
+
230
+ def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
231
+
232
+ assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
233
+
234
+ img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
235
+ mask_dir = 'data/drone/masks_full'
236
+
237
+ download_drone_dataset(force_download) # XXX: zip files and add checksum? date?
238
+
239
+ super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=bits)
240
+
241
+
242
+ class DroneDatasetSegmentationTiled(ImageFolderDatasetSegmentation):
243
+ """Dataset consisting of tiled numpy images and masks. Images are in range [0, 1]
244
+ Args:
245
+ tile_size (int, optional): size of the tiled images. Defaults to 256.
246
+ """
247
+
248
+ camera_parameters = DroneDatasetSegmentationFull.camera_parameters
249
+
250
+ def __init__(self, I_ratio=1.0, transform=None):
251
+
252
+ tile_size = 256
253
+
254
+ img_dir = f'data/drone/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}'
255
+ mask_dir = f'data/drone/masks_tiles_{tile_size}'
256
+
257
+ if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
258
+ dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
259
+ print("tiling dataset..")
260
+ create_tiles_dataset(dataset_full, img_dir, mask_dir, tile_size=tile_size)
261
+
262
+ super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=16)
263
+
264
+
265
+ class DroneDatasetClassificationTiled(ImageFolderDataset):
266
+
267
+ camera_parameters = DroneDatasetSegmentationFull.camera_parameters
268
+
269
+ def __init__(self, I_ratio=1.0, transform=None):
270
+
271
+ random_state = 72
272
+ tile_size = 256
273
+ thr = 0.01
274
+
275
+ img_dir = f'data/drone/classification/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}_thr_{thr}'
276
+ mask_dir = f'data/drone/classification/masks_tiles_{tile_size}_thr_{thr}'
277
+ df_path = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
278
+
279
+ if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
280
+ dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
281
+ print("tiling dataset..")
282
+ create_tiles_dataset_binary(dataset_full, img_dir, mask_dir, random_state, thr, tile_size=tile_size)
283
+
284
+ self.classes = ['car', 'no car']
285
+ self.df = pd.read_csv(df_path)
286
+ labels = self.df['label'].to_list()
287
+
288
+ super().__init__(img_dir=img_dir, labels=labels, transform=transform, bits=16)
289
+
290
+ images, class_labels = read_label_csv(self.df)
291
+ self.images = [os.path.join(self.img_dir, image) for image in images]
292
+ self.labels = class_labels
293
+
294
+
295
+ class MicroscopyDataset(ImageFolderDataset):
296
+ """MicroscopyDataset raw images
297
+
298
+ Args:
299
+ I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
300
301
+ transform (callable, optional): transformation to apply to image and mask
302
+ bits (int, optional): normalize image by dividing by 2^bits - 1
303
+ """
304
+
305
+ black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
306
+ white_balance = [-0.6567, 1.9673, 3.5304]
307
+ colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]
308
+
309
+ camera_parameters = black_level, white_balance, colour_matrix
310
+
311
+ dataset_mean = [0.91, 0.84, 0.94]
312
+ dataset_std = [0.08, 0.12, 0.05]
313
+
314
+ def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
315
+
316
+ assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
317
+
318
+ download_microscopy_dataset(force_download=force_download)
319
+
320
+ self.img_dir = f'data/microscopy/images/raw_scale{int(I_ratio*100):03d}'
321
+ self.transform = transform
322
+ self.bits = bits
323
+
324
+ self.label_file = 'data/microscopy/labels/Ma190c_annotations.dat'
325
+
326
+ self.valid_classes = ['BAS', 'EBO', 'EOS', 'KSC', 'LYA', 'LYT', 'MMZ', 'MOB',
327
+ 'MON', 'MYB', 'MYO', 'NGB', 'NGS', 'PMB', 'PMO', 'UNC']
328
+
329
+ self.invalid_files = ['Ma190c_lame3_zone13_composite_Mcropped_2.tiff', ]
330
+
331
+ images, class_labels = read_label_file(self.label_file)
332
+
333
+ # filter classes with low appearance
334
+ self.valid_classes = [class_label for class_label in self.valid_classes
335
+ if class_labels.count(class_label) > 4]
336
+
337
+ # remove invalid classes and invalid files from (images, class_labels)
338
+ images, class_labels = list(zip(*[
339
+ (image, class_label)
340
+ for image, class_label in zip(images, class_labels)
341
+ if class_label in self.valid_classes and image not in self.invalid_files
342
+ ]))
343
+
344
+ self.classes = list(sorted({*class_labels}))
345
+
346
+ # store full path
347
+ self.images = [os.path.join(self.img_dir, image) for image in images]
348
+
349
+ # reindex labels
350
+ self.labels = [self.classes.index(class_label) for class_label in class_labels]
351
+
352
+
353
+ class MicroscopyDatasetRGB(MicroscopyDataset):
354
+ """MicroscopyDataset RGB images
355
+
356
+ Args:
357
+ I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
358
359
+ transform (callable, optional): transformation to apply to image and mask
360
+ bits (int, optional): normalize image by dividing by 2^bits - 1
361
+ """
362
+ camera_parameters = None
363
+
364
+ dataset_mean = None
365
+ dataset_std = None
366
+
367
+ def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
368
+ super().__init__(I_ratio=I_ratio, transform=transform, bits=bits, force_download=force_download)
369
+ self.images = [image.replace('raw', 'rgb') for image in self.images] # XXX: hack
370
+
371
+
372
+ def read_label_file(label_file_path):
373
+
374
+ images = []
375
+ class_labels = []
376
+
377
+ with open(label_file_path, "rb") as data:
378
+ for line in data:
379
+ file_name, class_label = line.decode("utf-8").split()
380
+ image = file_name + '.tiff'
381
+ images.append(image)
382
+ class_labels.append(class_label)
383
+
384
+ return images, class_labels
385
+
386
+
387
+ def read_label_csv(df):
388
+
389
+ images = []
390
+ class_labels = []
391
+
392
+ for file_name, label in zip(df['file name'], df['label']):
393
+ image = file_name + '.tif'
394
+ images.append(image)
395
+ class_labels.append(int(label))
396
+ return images, class_labels
397
+
398
+
399
+ def download_drone_dataset(force_download):
400
+ b2_download_folder('drone/images', 'data/drone/images_full', force_download=force_download)
401
+ b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=force_download)
402
+ unzip_drone_images()
403
+
404
+
405
+ def download_microscopy_dataset(force_download):
406
+ b2_download_folder('Data histopathology/WhiteCellsImages',
407
+ 'data/microscopy/images', force_download=force_download)
408
+ b2_download_folder('Data histopathology/WhiteCellsLabels',
409
+ 'data/microscopy/labels', force_download=force_download)
410
+ unzip_microscopy_images()
411
+
412
+
413
+ def unzip_microscopy_images():
414
+
415
+ if os.path.isfile('data/microscopy/labels/.bzEmpty'):
416
+ os.remove('data/microscopy/labels/.bzEmpty')
417
+
418
+ for file in os.listdir('data/microscopy/images'):
419
+ if file.endswith(".zip"):
420
+ zip = zipfile.ZipFile(os.path.join('data/microscopy/images', file))
421
+ zip.extractall('data/microscopy/images')
422
+ os.remove(os.path.join('data/microscopy/images', file))
423
+
424
+ def unzip_drone_images():
425
+
426
+ if os.path.isfile('data/drone/masks_full/.bzEmpty'):
427
+ os.remove('data/drone/masks_full/.bzEmpty')
428
+
429
+ for file in os.listdir('data/drone/images_full'):
430
+ if file.endswith(".zip"):
431
+ zip = zipfile.ZipFile(os.path.join('data/drone/images_full', file))
432
+ zip.extractall('data/drone/images_full')
433
+ os.remove(os.path.join('data/drone/images_full', file))
434
+
435
+
436
+ def create_tiles_dataset(dataset, img_dir, mask_dir, tile_size=256):
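+ # split every full-resolution image/mask into non-overlapping tile_size x tile_size crops and keep only tiles whose mask contains the class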
437
+ for folder in [img_dir, mask_dir]:
438
+ if not os.path.exists(folder):
439
+ os.makedirs(folder)
440
+ for n, (img, mask) in enumerate(dataset):
441
+ tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
442
+ tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
443
+ tiled_img, tiled_mask = class_detection(tiled_img, tiled_mask) # Remove images without cars in it
444
+ for i, (sub_img, sub_mask) in enumerate(zip(tiled_img, tiled_mask)):
445
+ tile_id = f"{n:02d}_{i:05d}"
446
+ Image.fromarray(sub_img).save(os.path.join(img_dir, tile_id + '.tif'))
447
+ Image.fromarray(sub_mask > 0).save(os.path.join(mask_dir, tile_id + '.png'))
448
+
449
+
450
+ def create_tiles_dataset_binary(dataset, img_dir, mask_dir, random_state, thr, tile_size=256):
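+ # same tiling as above, but writes a class-balanced set of tiles (label 0 = contains the class, 1 = does not) plus a csv with file names and labels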
451
+
452
+ for folder in [img_dir, mask_dir]:
453
+ if not os.path.exists(folder):
454
+ os.makedirs(folder)
455
+
456
+ ids = []
457
+ labels = []
458
+
459
+ for n, (img, mask) in enumerate(dataset):
460
+ tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
461
+ tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
462
+
463
+ X_with, X_without, Y_with, Y_without = binary_class_detection(
464
+ tiled_img, tiled_mask, random_state, thr) # creates balanced arrays with class and without class
465
+
466
+ for i, (sub_X_with, sub_Y_with) in enumerate(zip(X_with, Y_with)):
467
+ tile_id = f"{n:02d}_{i:05d}"
468
+ ids.append(tile_id)
469
+ labels.append(0)
470
+ Image.fromarray(sub_X_with).save(os.path.join(img_dir, tile_id + '.tif'))
471
+ Image.fromarray(sub_Y_with > 0).save(os.path.join(mask_dir, tile_id + '.png'))
472
+ for j, (sub_X_without, sub_Y_without) in enumerate(zip(X_without, Y_without)):
473
+ tile_id = f"{n:02d}_{i+1+j:05d}"
474
+ ids.append(tile_id)
475
+ labels.append(1)
476
+ Image.fromarray(sub_X_without).save(os.path.join(img_dir, tile_id + '.tif'))
477
+ Image.fromarray(sub_Y_without > 0).save(os.path.join(mask_dir, tile_id + '.png'))
478
+ # Image.fromarray(sub_mask).save(os.path.join(mask_dir, tile_id + '.png'))
479
+
480
+ df = pd.DataFrame({'file name': ids, 'label': labels})
481
+
482
+ df_loc = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
483
+ df.to_csv(df_loc)
484
+
485
+ return
486
+
487
+
488
+ def class_detection(X, Y):
489
+ """Split dataset in images which has the class in the target
490
+
491
+ Args:
492
+ X (ndarray): input image
493
+ Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class)
494
+ Returns:
495
+ X_with_class (ndarray): input regions with the selected class
496
+ Y_with_class (ndarray): target regions with the selected class
497
499
+ """
500
+
501
+ with_class = []
502
+ without_class = []
503
+ for i, img in enumerate(Y):
504
+ if img.mean() == 0:
505
+ without_class.append(i)
506
+ else:
507
+ with_class.append(i)
508
+
509
+ X_with_class = np.delete(X, without_class, 0)
510
+ Y_with_class = np.delete(Y, without_class, 0)
511
+
512
+ return X_with_class, Y_with_class
513
+
514
+
515
+ def binary_class_detection(X, Y, random_seed, thr):
516
+ """Splits subimages in subimages with the selected class and without the selected class by calculating the mean of the submasks; subimages with 0 < submask.mean()<=thr are disregared
517
+
518
+
519
+
520
+ Args:
521
+ X (ndarray): input image
522
+ Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class)
523
+ thr (flaot): sub images are not considered if 0 < sub_target.mean() <= thr
524
+ balanced (bool): number of returned sub images is equal for both classes if true
525
+ random_seed (None or int): selection of sub images in class with more elements according to random_seed if balanced
526
+ Returns:
527
+ X_with_class (ndarray): input regions with the selected class
528
+ Y_with_class (ndarray): target regions with the selected class
529
+ X_without_class (ndarray): input regions without the selected class
530
+ Y_without_class (ndarray): target regions without the selected class
531
+ """
532
+
533
+ with_class = []
534
+ without_class = []
535
+ no_class = []
536
+
537
+ for i, img in enumerate(Y):
538
+ m = img.mean()
539
+ if m == 0:
540
+ without_class.append(i)
541
+ else:
542
+ if m > thr:
543
+ with_class.append(i)
544
+ else:
545
+ no_class.append(i)
546
+
547
+ N = len(with_class)
548
+ M = len(without_class)
549
+ random.seed(random_seed)
550
+ if N <= M:
551
+ random.shuffle(without_class)
552
+ with_class.extend(without_class[:M - N])
553
+ else:
554
+ random.shuffle(with_class)
555
+ without_class.extend(with_class[:N - M])
556
+
557
+ X_with_class = np.delete(X, without_class + no_class, 0)
558
+ X_without_class = np.delete(X, with_class + no_class, 0)
559
+ Y_with_class = np.delete(Y, without_class + no_class, 0)
560
+ Y_without_class = np.delete(Y, with_class + no_class, 0)
561
+
562
+ return X_with_class, X_without_class, Y_with_class, Y_without_class
563
+
564
+
565
+ def make_dataloader(dataset, batch_size, shuffle=True):
566
+
567
+ X, Y = dataset
568
+
569
+ X, Y = np2torch(X), np2torch(Y)
570
+
571
+ dataset = TensorDataset(X, Y)
572
+ dataset = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
573
+
574
+ return dataset
575
+
576
+
577
+ def check_image_folder_consistency(images, masks):
578
+ file_type_images = images[0].split('.')[-1].lower()
579
+ file_type_masks = masks[0].split('.')[-1].lower()
580
+ assert len(images) == len(masks), "images / masks length mismatch"
581
+ for img_file, mask_file in zip(images, masks):
582
+ img_name = img_file.split('/')[-1].split('.')[0]
583
+ assert img_name in mask_file, f"image {img_file} corresponds to {mask_file}?"
584
+ assert img_file.split('.')[-1].lower() == file_type_images, \
585
+ f"image file {img_file} file type mismatch. Shoule be: {file_type_images}"
586
+ assert mask_file.split('.')[-1].lower() == file_type_masks, \
587
+ f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
588
+
589
+
590
+ def k_fold(dataset, n_splits: int, seed: int, train_size: float):
591
+ """Split dataset in subsets for cross-validation
592
+
593
+ Args:
594
+ dataset (class): dataset to split
595
+ n_splits (int): number of re-shuffling & splitting iterations.
596
+ seed (int): seed for k_fold splitting
597
+ train_size (float): should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split.
598
+ Returns:
599
+ idxs (list): indices for splitting the dataset. The list contains n_splits pairs of train/test indices.
600
+ """
601
+ if hasattr(dataset, 'labels'):
602
+ x = dataset.images
603
+ y = dataset.labels
604
+ elif hasattr(dataset, 'masks'):
605
+ x = dataset.images
606
+ y = dataset.masks
607
+
608
+ idxs = []
609
+
610
+ if dataset.task == 'classification':
611
+ sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
612
+
613
+ for idxs_train, idxs_test in sss.split(x, y):
614
+ idxs.append((idxs_train.tolist(), idxs_test.tolist()))
615
+
616
+ elif dataset.task == 'segmentation':
617
+ np.random.seed(seed)  # make the random splits reproducible, as the docstring promises
+ for n in range(n_splits):
618
+ split_idx = int(len(dataset) * train_size)
619
+ indices = np.random.permutation(len(dataset))
620
+ idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
621
+
622
+ return idxs
utils/debug.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import inspect
4
+ from functools import reduce, wraps
5
+ from collections.abc import Iterable
6
+ from IPython import embed
7
+
8
+ try:
9
+ get_ipython() # pylint: disable=undefined-variable
10
+ interactive_notebook = True
11
+ except:
12
+ interactive_notebook = False
13
+
14
+ _NONE = "__UNSET_VARIABLE__"
15
+
16
+
17
+ def debug_init():
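+ # configuration flags are stored as attributes on the debug() function itself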
18
+ debug.disable = False
19
+ debug.silent = False
20
+ debug.verbose = 2
21
+ debug.expand_ignore = ["DataLoader", "Dataset", "Subset"]
22
+ debug.max_expand = 10
23
+ debug.show_tensor = False
24
+ debug.raise_exception = True
25
+ debug.full_stack = True
26
+ debug.restore_defaults_on_exception = not interactive_notebook
27
+ debug._indent = 0
28
+ debug._stack = ""
29
+
30
+ debug.embed = embed
31
+ debug.show = debug_show
32
+ debug.pause = debug_pause
33
+
34
+
35
+ def debug_pause():
36
+ input("Press Enter to continue...")
37
+
38
+
39
+ def debug(*args, assert_true=False):
40
+ """Decorator for debugging functions and tensors.
41
+ Will throw an exception as soon as a nan is encountered.
42
+ If used on iterables, these will be expanded and also searched for nans.
43
+ Usage:
44
+ debug(x)
45
+ Or:
46
+ @debug
47
+ def function():
48
+ ...
49
+ If used as a function wrapper, all arguments will be searched and printed.
50
+ """
51
+
52
+ single_arg = len(args) == 1
53
+
54
+ if debug.disable:
55
+ return args[0] if single_arg else None
56
+
57
+ try:
58
+ call_line = ''.join(inspect.stack()[1][4]).strip()
59
+ except:
60
+ call_line = '...'
61
+ used_as_wrapper = 'def ' == call_line[:4]
62
+ expect_return_arg = single_arg and 'debug' in call_line and call_line.split('debug')[0].strip() != ''
63
+ is_func = single_arg and hasattr(args[0], '__call__')
64
+
65
+ if is_func and (used_as_wrapper or expect_return_arg):
66
+ func = args[0]
67
+ sig_parameters = inspect.signature(func).parameters
68
+ sig_argnames = [p.name for p in sig_parameters.values()]
69
+ sig_defaults = {
70
+ k: v.default
71
+ for k, v in sig_parameters.items()
72
+ if v.default is not inspect.Parameter.empty
73
+ }
74
+
75
+ @wraps(func)
76
+ def _func(*args, **kwargs):
77
+ if debug.disable:
78
+ return func(*args, **kwargs)
79
+
80
+ if debug._indent == 0:
81
+ debug._stack = ""
82
+ stack_before = debug._stack
83
+ indent = ' ' * 4 * debug._indent
84
+ debug._indent += 1
85
+
86
+ args_kw = dict(zip(sig_argnames, args))
87
+ defaults = {k: v for k, v in sig_defaults.items()
88
+ if k not in kwargs
89
+ if k not in args_kw}
90
+ all_args = {**args_kw, **kwargs, **defaults}
91
+
92
+ func_name = None
93
+ if hasattr(func, '__name__'):
94
+ func_name = func.__name__
95
+ elif hasattr(func, '__class__'):
96
+ func_name = func.__class__.__name__
97
+
98
+ if func_name is None:
99
+ func_name = '... ' + call_line + '...'
100
+ else:
101
+ func_name = '@' + func_name + '()'
102
+
103
+ _debug_log('', indent=indent)
104
+ _debug_log(func_name, indent=indent)
105
+
106
+ debug._last_call = func
107
+ debug._last_args = all_args
108
+ debug._last_args_sig = sig_argnames
109
+
110
+ for argtype, params in [("args", args_kw.items()),
111
+ ("kwargs", kwargs.items()),
112
+ ("defaults", defaults.items())]:
113
+ if params:
114
+ _debug_log(f"{argtype}:", indent=indent + ' ' * 6)
115
+ for argname, arg in params:
116
+ if argname == 'self':
117
+ # _debug_log(f"- self: ...", indent=indent + ' ' * 8)
118
+ pass
119
+ else:
120
+ _debug_log(f"- {argname}: ", arg, indent + ' ' * 8, assert_true)
121
+ try:
122
+ out = func(*args, **kwargs)
123
+ except:
124
+ _debug_crash_save()
125
+ debug._stack = ""
126
+ debug._indent = 0
127
+ raise
128
+ debug.out = out
129
+ _debug_log("returned: ", out, indent, assert_true)
130
+ _debug_log('', indent=indent)
131
+ debug._indent -= 1
132
+ if not debug.full_stack:
133
+ debug._stack = stack_before
134
+ return out
135
+ return _func
136
+ else:
137
+ if debug._indent == 0:
138
+ debug._stack = ""
139
+ argname = ')'.join('('.join(call_line.split('(')[1:]).split(')')[:-1])
140
+ if assert_true:
141
+ argname = ','.join(argname.split(',')[:-1])
142
+ _debug_log(f"assert{{{argname}}} ", args[0], ' ' * 4 * debug._indent, assert_true)
143
+ else:
144
+ for arg in args:
145
+ _debug_log(f"{{{argname}}} = ", arg, ' ' * 4 * debug._indent, assert_true)
146
+ if expect_return_arg:
147
+ return args[0]
148
+ return
149
+
150
+
151
+ def is_iterable(x):
152
+ return (isinstance(x, Iterable) or hasattr(x, '__getitem__')) and not isinstance(x, str)
153
+
154
+
155
+ def ndarray_repr(t, assert_all=False):
156
+ exception_encountered = False
157
+ info = []
158
+ shape = tuple(t.shape)
159
+ single_entry = shape == () or shape == (1,)
160
+ if single_entry:
161
+ info.append(f"[{t.item():.4f}]")
162
+ else:
163
+ info.append(f"({', '.join(map(repr, shape))})")
164
+ invalid_sum = (~np.isfinite(t)).sum().item()
165
+ if invalid_sum:
166
+ info.append(
167
+ f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
168
+ exception_encountered = True
169
+ if debug.verbose > 1:
170
+ if not invalid_sum and not single_entry:
171
+ info.append(f"|x|={np.linalg.norm(t):.1f}")
172
+ if t.size:
173
+ info.append(f"x in [{t.min():.1f}, {t.max():.1f}]")
174
+ if debug.verbose and t.dtype != np.float64:  # np.float was removed in NumPy 1.24
175
+ info.append(f"dtype={str(t.dtype)}".replace("'", ''))
176
+ if assert_all:
177
+ assert_val = t.all()
178
+ if not assert_val:
179
+ exception_encountered = True
180
+ if assert_all and not exception_encountered:
181
+ output = "passed"
182
+ else:
183
+ if assert_all and not assert_val:
184
+ output = f"ndarray({info[0]})"
185
+ else:
186
+ output = f"ndarray({', '.join(info)})"
187
+ if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
188
+ if debug.restore_defaults_on_exception:
189
+ debug.raise_exception = False
190
+ debug.silent = False
191
+ debug.x = t
192
+ msg = output
193
+ debug._stack += output
194
+ if debug._stack and '\n' in debug._stack:
195
+ msg += '\nSTACK: ' + debug._stack
196
+ if assert_all:
197
+ assert assert_val, "Assert did not pass on " + msg
198
+ raise Exception("Invalid entries encountered in " + msg)
199
+ return output
200
+
201
+
202
+ def tensor_repr(t, assert_all=False):
203
+ exception_encountered = False
204
+ info = []
205
+ shape = tuple(t.shape)
206
+ single_entry = shape == () or shape == (1,)
207
+ if single_entry:
208
+ info.append(f"[{t.item():.3f}]")
209
+ else:
210
+ info.append(f"({', '.join(map(repr, shape))})")
211
+ invalid_sum = (~torch.isfinite(t)).sum().item()
212
+ if invalid_sum:
213
+ info.append(
214
+ f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
215
+ exception_encountered = True
216
+ if debug.verbose and t.requires_grad:
217
+ info.append('req_grad')
218
+ if debug.verbose > 2:
219
+ if t.is_leaf:
220
+ info.append('leaf')
221
+ if hasattr(t, 'retains_grad') and t.retains_grad:
222
+ info.append('retains_grad')
223
+ has_grad = (t.is_leaf or hasattr(t, 'retains_grad') and t.retains_grad) and t.grad is not None
224
+ if has_grad:
225
+ grad_invalid_sum = (~torch.isfinite(t.grad)).sum().item()
226
+ if grad_invalid_sum:
227
+ info.append(
228
+ f"GRAD {grad_invalid_sum} INVALID ENTR{'Y' if grad_invalid_sum == 1 else 'IES'}")
229
+ exception_encountered = True
230
+ if debug.verbose > 1:
231
+ if not invalid_sum and not single_entry:
232
+ info.append(f"|x|={t.float().norm():.1f}")
233
+ if t.numel():
234
+ info.append(f"x in [{t.min():.2f}, {t.max():.2f}]")
235
+ if has_grad and not grad_invalid_sum:
236
+ if single_entry:
237
+ info.append(f"grad={t.grad.float().item():.3f}")
238
+ else:
239
+ info.append(f"|grad|={t.grad.float().norm():.1f}")
240
+ if debug.verbose and t.dtype != torch.float:
241
+ info.append(f"dtype={str(t.dtype).split('.')[-1]}")
242
+ if debug.verbose and t.device.type != 'cpu':
243
+ info.append(f"device={t.device.type}")
244
+ if assert_all:
245
+ assert_val = t.all()
246
+ if not assert_val:
247
+ exception_encountered = True
248
+ if assert_all and not exception_encountered:
249
+ output = "passed"
250
+ else:
251
+ if assert_all and not assert_val:
252
+ output = f"tensor({info[0]})"
253
+ else:
254
+ output = f"tensor({', '.join(info)})"
255
+ if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
256
+ if debug.restore_defaults_on_exception:
257
+ debug.raise_exception = False
258
+ debug.silent = False
259
+ debug.x = t
260
+ msg = output
261
+ debug._stack += output
262
+ if debug._stack and '\n' in debug._stack:
263
+ msg += '\nSTACK: ' + debug._stack
264
+ if assert_all:
265
+ assert assert_val, "Assert did not pass on " + msg
266
+ raise Exception("Invalid entries encountered in " + msg)
267
+ return output
268
+
269
+
270
+ def _debug_crash_save():
271
+ if debug._indent:
272
+ debug.args = debug._last_args
273
+ debug.func = debug._last_call
274
+
275
+ @wraps(debug.func)
276
+ def _recall(*args, **kwargs):
277
+ call_args = {**debug.args, **kwargs, **dict(zip(debug._last_args_sig, args))}
278
+ return debug(debug.func)(**call_args)
279
+
280
+ def print_stack(stack=debug._stack):
281
+ print('\nSTACK: ' + stack)
282
+ debug.stack = print_stack
283
+
284
+ debug.recall = _recall
285
+ debug._indent = 0
286
+
287
+
288
+ def _debug_log(output, var=_NONE, indent='', assert_true=False, expand=True):
289
+ debug._stack += indent + output
290
+ if not debug.silent:
291
+ print(indent + output, end='')
292
+ if var is not _NONE:
293
+ type_str = type(var).__name__.lower()
294
+ if var is None:
295
+ _debug_log('None')
296
+ elif isinstance(var, str):
297
+ _debug_log(f"'{var}'")
298
+ elif type_str == 'ndarray':
299
+ _debug_log(ndarray_repr(var, assert_true))
300
+ if debug.show_tensor:
301
+ _debug_show_print(var, indent=indent + 4 * ' ')
302
+ # elif type_str in ('tensor', 'parameter'):
303
+ elif type_str == 'tensor':
304
+ _debug_log(tensor_repr(var, assert_true))
305
+ if debug.show_tensor:
306
+ _debug_show_print(var, indent=indent + 4 * ' ')
307
+ elif hasattr(var, 'named_parameters'):
308
+ _debug_log(type_str)
309
+ params = list(var.named_parameters())
310
+ _debug_log(f"{type_str}[{len(params)}] {{")
311
+ for k, v in params:
312
+ _debug_log(f"'{k}': ", v, indent + 6 * ' ')
313
+ _debug_log(indent + 4 * ' ' + '}')
314
+ elif is_iterable(var):
315
+ expand = debug.expand_ignore != '*' and expand
316
+ if expand:
317
+ if isinstance(debug.expand_ignore, str):
318
+ if type_str == str(debug.expand_ignore).lower():
319
+ expand = False
320
+ elif is_iterable(debug.expand_ignore):
321
+ for ignore in debug.expand_ignore:
322
+ if type_str == ignore.lower():
323
+ expand = False
324
+ if hasattr(var, '__len__'):
325
+ length = len(var)
326
+ else:
327
+ var = list(var)
328
+ length = len(var)
329
+ if expand and length > 0:
330
+ _debug_log(f"{type_str}[{length}] {{")
331
+ if isinstance(var, dict):
332
+ for k, v in var.items():
333
+ _debug_log(f"'{k}': ", v, indent + 6 * ' ', assert_true)
334
+ else:
335
+ i = 0
336
+ for k, i in zip(var, range(debug.max_expand)):
337
+ _debug_log('- ', k, indent + 6 * ' ', assert_true)
338
+ if i < length - 1:
339
+ _debug_log('- ' + ' ' * 6 + '...', indent=indent + 6 * ' ')
340
+ _debug_log(indent + 4 * ' ' + '}')
341
+ else:
342
+ _debug_log(f"{type_str}[{length}]")
343
+ else:
344
+ _debug_log(str(var))
345
+ else:
346
+ debug._stack += '\n'
347
+ if not debug.silent:
348
+ print()
349
+
350
+
351
+ def debug_show(x):
352
+ assert is_iterable(x)
353
+ debug(x)
354
+ _debug_show_print(x, indent=' ' * 4 * debug._indent)
355
+
356
+
357
+ def _debug_show_print(x, indent=''):
358
+ is_tensor = type(x).__name__ in ('Tensor', 'ndarray')
359
+ if is_tensor:
360
+ x = x.flatten()
361
+ if type(x).__name__ == 'Tensor' and x.dim() == 0:
362
+ return
363
+ n_samples = min(10, len(x))
364
+ di = len(x) // n_samples
365
+ var = list(x[i * di] for i in range(n_samples))
366
+ if is_tensor or type(var[0]) == float:
367
+ var = [round(float(v), 4) for v in var]
368
+ _debug_log('--> ', str(var), indent, expand=False)
369
+
370
+
371
+ debug_init()
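A minimal usage sketch of the debug helper above (assuming the file is importable as utils.debug; the tensor and function names below are made up for illustration):

# Sketch only: assumes this module lives at utils/debug.py.
import torch
from utils.debug import debug

x = torch.randn(4, 3)
debug(x)                       # prints a compact tensor summary and raises on NaN/Inf entries

@debug                         # as a decorator: logs arguments and the returned value
def normalize(t, eps=1e-8):
    return t / (t.norm() + eps)

y = normalize(x)
debug.disable = True           # switch all checks off globally, e.g. for long training runs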
utils/mutual_entropy.py ADDED
@@ -0,0 +1,193 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ from scipy.signal import convolve2d
5
+
6
+ def mse(x,y):
7
+ return ((x-y)**2).mean()
8
+
9
+ def gaussian_noise_entropies(t1, bins=20):
10
+ all_MI= []
11
+ all_mse = []
12
+ for sigma in np.linspace(0,100,201):
13
+ t2 = np.random.normal(t1.copy(), scale=sigma, size = t1.shape)
14
+ hist_2d, x_edges, y_edges = np.histogram2d(
15
+ t1.ravel(),
16
+ t2.ravel(),
17
+ bins=bins)
18
+ all_mse.append(mse(t1,t2))
19
+ MI = mutual_information(hist_2d)
20
+ all_MI.append(MI)
21
+
22
+ return np.array((all_MI)), np.array((all_mse))
23
+
24
+ def shifts_entropies(t1, bins=20):
25
+ all_MI=[]
26
+ all_mse=[]
27
+ for N in np.linspace(1,50,50):
28
+ N = int(N)
29
+ temp_t2 = t1[:-N].copy()
30
+ temp_t1 = t1[N:].copy()
31
+ hist_2d, x_edges, y_edges = np.histogram2d(
32
+ t1.ravel(),
33
+ t2.ravel(),
34
+ bins=bins)
35
+ MI = mutual_information(hist_2d)
36
+
37
+ all_mse.append(mse(temp_t1,temp_t2))
38
+ all_MI.append(MI)
39
+
40
+ return np.array((all_MI)), np.array((all_mse))
41
+
42
+ def mutual_information(hgram):
43
+ """ Mutual information for joint histogram
44
+ """
45
+ # Convert bins counts to probability values
46
+ pxy = hgram / float(np.sum(hgram))
47
+ px = np.sum(pxy, axis=1) # marginal for x over y
48
+ py = np.sum(pxy, axis=0) # marginal for y over x
49
+ px_py = px[:, None] * py[None, :] # Broadcast to multiply marginals
50
+ # Now we can do the calculation using the pxy, px_py 2D arrays
51
+ nzs = pxy > 0 # Only non-zero pxy values contribute to the sum
52
+ return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))
53
+
54
+ def entropy(image, bins=20):
55
+ image = image.ravel()
56
+ hist, bin_edges = np.histogram(image, bins = bins)
57
+ hist = hist/hist.sum()
58
+ entropy_term = np.where(hist != 0, hist*np.log(hist), 0)
59
+ entropy = - np.sum(entropy_term)
60
+
61
+ return entropy
62
+
63
+ # Gray Image
64
+ # t1 = np.array(Image.open("img.png"))[:,:,0].astype(float)
65
+
66
+ # Colour Image
67
+ t1 = np.array(Image.open("img.png").resize((255,255)))
68
+
69
+ perturb = "gauss"
70
+ show_image = True
71
+ bins=20
72
+
73
+ print(perturb)
74
+
75
+ # Identity
76
+ if perturb == "identity":
77
+ t2 = t1
78
+ title = "Identity"
79
+ image1 = "Clean"
80
+ image2 = "Clean"
81
+
82
+ # Poisson Noise on t2
83
+ if perturb == "poisson":
84
+ t2 = np.random.poisson(t1)
85
+ title = "Poisson Noise"
86
+ image1 = "Clean"
87
+ image2 = "Noisy"
88
+
89
+ # Gaussian Noise on t2
90
+ if perturb == "gauss":
91
+ print(np.shape(t1))
92
+ sigma = 50.0
93
+ t2 = np.random.normal(t1.copy(), scale=sigma, size = t1.shape)
94
+ if "grad" in locals():
95
+ title = f"Gaussian Noise, grad= True, sigma = {sigma:.2f}"
96
+ else:
97
+ title = f"Gaussian Noise, sigma = {sigma:.2f}"
98
+ image1 = "Clean"
99
+ image2 = "Noisy"
100
+
101
+ if perturb == "box":
102
+ sigma = 50.0
103
+ mean = np.mean(t1)
104
+ print(np.shape(t1))
105
+ t2 = t1.copy()
106
+ t2[30:220,50:120,:] = mean
107
+ title = "Box with mean pixels"
108
+ image1 = "Clean"
109
+ image2 = "Noisy"
110
+
111
+
112
+ # Shift t2 on y axis
113
+ if perturb == "shift":
114
+ N=30
115
+ t2 = t1[:-N]
116
+ t1 = t1[N:]
117
+ title = "y shift"
118
+ image1 = "Clean"
119
+ image2 = "Shifted"
120
+
121
+ t2 = np.clip(t2,0,255).astype(int)
122
+
123
+ print("Correlation Coefficient: ",np.corrcoef(t1.ravel(), t2.ravel())[0, 1])
124
+
125
+ # 2D Histogram
126
+ hist_2d, x_edges, y_edges = np.histogram2d(
127
+ t1.ravel(),
128
+ t2.ravel(),
129
+ bins=bins)
130
+
131
+ MI = mutual_information(hist_2d)
132
+
133
+ print("Mutual Information", MI)
134
+ print("Mean squared error:", mse(t1,t2))
135
+
136
+ if show_image == True:
137
+ plt.figure()
138
+ plt.imshow(np.hstack((t2, t1)))
139
+ plt.title(title)
140
+
141
+ plt.figure()
142
+
143
+ plt.plot(t1.ravel(), t2.ravel(), '.')
144
+ plt.xlabel(image1)
145
+ plt.ylabel(image2)
146
+ plt.title('I1 vs I2')
147
+
148
+ plt.figure()
149
+ plt.imshow((hist_2d.T)/hist_2d.max(), origin='lower')
150
+ plt.xlabel(image1)
151
+ plt.ylabel(image2)
152
+ plt.xticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(x_edges.min(),x_edges.max(),10).astype(int))
153
+ plt.yticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(y_edges.min(),y_edges.max(),10).astype(int))
154
+ plt.title('p(x,y)')
155
+ plt.colorbar()
156
+
157
+ # Show log histogram, avoiding divide by 0
158
+ plt.figure(figsize=(4,4))
159
+ hist_2d_log = np.zeros(hist_2d.shape)
160
+ non_zeros = hist_2d != 0
161
+ hist_2d_log[non_zeros] = np.log(hist_2d[non_zeros])
162
+ plt.imshow((hist_2d_log.T)/hist_2d_log.max(), origin='lower')
163
+ plt.xlabel(image1)
164
+ plt.ylabel(image2)
165
+ plt.xticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(x_edges.min(),x_edges.max(),10).astype(int))
166
+ plt.yticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(y_edges.min(),y_edges.max(),10).astype(int))
167
+ plt.title('log(p(x,y))')
168
+ plt.colorbar()
169
+ plt.show()
170
+
171
+ if perturb == "shift":
172
+ mi_array, mse_array = shifts_entropies(t1, bins=bins)
173
+ plt.figure()
174
+ plt.plot(np.linspace(1,50,50), mi_array)
175
+ plt.xlabel("y shift")
176
+ plt.ylabel("Mutual Information")
177
+ plt.figure()
178
+ plt.plot(np.linspace(1,50,50), mse_array)
179
+ plt.xlabel("y shift")
180
+ plt.ylabel("Mean Squared Error")
181
+ plt.show()
182
+
183
+ if perturb == "gauss":
184
+ mi_array, mse_array = gaussian_noise_entropies(t1, bins= bins)
185
+ plt.figure()
186
+ plt.plot(np.linspace(0,100,201), mi_array)
187
+ plt.xlabel("sigma")
188
+ plt.ylabel("Mutual Information")
189
+ plt.figure()
190
+ plt.plot(np.linspace(0,100,201), mse_array)
191
+ plt.xlabel("sigma")
192
+ plt.ylabel("Mean Squared Error")
193
+ plt.show()
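Because utils/mutual_entropy.py runs as a script against a hard-coded img.png, here is a self-contained sketch of the same joint-histogram mutual-information estimate; the synthetic arrays, noise level, and bin count are illustrative only:

# Standalone sketch of the MI-from-joint-histogram computation used above.
import numpy as np

rng = np.random.default_rng(0)
clean = rng.uniform(0, 255, size=(128, 128))
noisy = clean + rng.normal(scale=25.0, size=clean.shape)

hist_2d, _, _ = np.histogram2d(clean.ravel(), noisy.ravel(), bins=20)
pxy = hist_2d / hist_2d.sum()        # joint probability p(x, y)
px = pxy.sum(axis=1)                 # marginal p(x)
py = pxy.sum(axis=0)                 # marginal p(y)
nz = pxy > 0                         # only non-zero bins contribute to the sum
mi = np.sum(pxy[nz] * np.log(pxy[nz] / (px[:, None] * py[None, :])[nz]))
print(f"estimated mutual information: {mi:.3f} nats")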
utils/pytorch_ssim.py ADDED
@@ -0,0 +1,75 @@
1
+ """https://github.com/Po-Hsun-Su/pytorch-ssim"""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ import numpy as np
7
+ from math import exp
8
+
9
+ def gaussian(window_size, sigma):
10
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
11
+ return gauss/gauss.sum()
12
+
13
+ def create_window(window_size, channel):
14
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
16
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
17
+ return window
18
+
19
+ def _ssim(img1, img2, window, window_size, channel, size_average = True):
20
+ mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
21
+ mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
22
+
23
+ mu1_sq = mu1.pow(2)
24
+ mu2_sq = mu2.pow(2)
25
+ mu1_mu2 = mu1*mu2
26
+
27
+ sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
28
+ sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
29
+ sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
30
+
31
+ C1 = 0.01**2
32
+ C2 = 0.03**2
33
+
34
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
35
+
36
+ if size_average:
37
+ return ssim_map.mean()
38
+ else:
39
+ return ssim_map.mean(1).mean(1).mean(1)
40
+
41
+ class SSIM(torch.nn.Module):
42
+ def __init__(self, window_size = 11, size_average = True):
43
+ super(SSIM, self).__init__()
44
+ self.window_size = window_size
45
+ self.size_average = size_average
46
+ self.channel = 1
47
+ self.window = create_window(window_size, self.channel)
48
+
49
+ def forward(self, img1, img2):
50
+ (_, channel, _, _) = img1.size()
51
+
52
+ if channel == self.channel and self.window.data.type() == img1.data.type():
53
+ window = self.window
54
+ else:
55
+ window = create_window(self.window_size, channel)
56
+
57
+ if img1.is_cuda:
58
+ window = window.cuda(img1.get_device())
59
+ window = window.type_as(img1)
60
+
61
+ self.window = window
62
+ self.channel = channel
63
+
64
+
65
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
66
+
67
+ def ssim(img1, img2, window_size = 11, size_average = True):
68
+ (_, channel, _, _) = img1.size()
69
+ window = create_window(window_size, channel)
70
+
71
+ if img1.is_cuda:
72
+ window = window.cuda(img1.get_device())
73
+ window = window.type_as(img1)
74
+
75
+ return _ssim(img1, img2, window, window_size, channel, size_average)
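A small usage sketch for the SSIM utilities (assuming the file is importable as utils.pytorch_ssim; inputs are NCHW float tensors, ideally scaled to [0, 1] since C1 and C2 are chosen for a data range of 1):

# Sketch: compare two image batches with the functional and the module interface.
import torch
from utils.pytorch_ssim import ssim, SSIM

img1 = torch.rand(2, 3, 64, 64)                       # NCHW, values in [0, 1]
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

print(ssim(img1, img2).item())                        # functional interface, default window_size=11

criterion = SSIM(window_size=11, size_average=True)   # module interface, caches the Gaussian window
print(criterion(img1, img2).item())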
utils/show_dataset.ipynb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd09ffb969b9a0a5b414b892614b5b9e48fa32721ea9a14e9e0951160e8f92e4
3
+ size 2115545
utils/splitting.py ADDED
@@ -0,0 +1,137 @@
1
+ """
2
+ Split images in blocks and vice versa
3
+ """
4
+
5
+ import random
6
+ import numpy as np
7
+
8
+ import torch
9
+
10
+ from skimage.util.shape import view_as_windows
11
+
12
+
13
+ def split_img(imgs, ROIs = (3,3) , step= (1,1)):
14
+ """Split the imgs in regions of size ROIs.
15
+
16
+ Args:
17
+ imgs (ndarray): images which you want to split
18
+ ROIs (tuple): size of the sub-regions (ROIs = regions of interest)
19
+ step (tuple): stride from one sub-region to the next (along the x and y axes)
20
+
21
+ Returns:
22
+ ndarray: the split sub-images.
23
+ The size is (x_num_subROIs*y_num_subROIs, **) where:
24
+ x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
25
+ y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
26
+
27
+ Example:
28
+ >>> from utils.splitting import split_img
29
+ >>> imgs_split = split_img(imgs, ROIs=(5,5), step=(2,3))
30
+ """
31
+
32
+ if len(ROIs) > 2:
33
+ return print("ROIs is a 2 element list")
34
+
35
+ if len(step) > 2:
36
+ return print("step is a 2 element list")
37
+
38
+ if not isinstance(imgs, np.ndarray):
39
+ return print("imgs should be a ndarray")
40
+
41
+ if len(imgs.shape) == 2: # Single image with one channel (HxW)
42
+ splitted = view_as_windows(imgs, (ROIs[0],ROIs[1]), (step[0], step[1]))
43
+ return splitted.reshape((-1, ROIs[0], ROIs[1]))
44
+
45
+ if len(imgs.shape) == 3:
46
+ _, _, channels = imgs.shape
47
+ if channels <= 3: # Single image more channels (HxWxC)
48
+ splitted = view_as_windows(imgs, (ROIs[0],ROIs[1], channels), (step[0], step[1], channels))
49
+ return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
50
+ else: # More images with 1 channel
51
+ splitted = view_as_windows(imgs, (1, ROIs[0],ROIs[1]), (1, step[0], step[1]))
52
+ return splitted.reshape((-1, ROIs[0], ROIs[1]))
53
+
54
+ if len(imgs.shape) == 4: # More images with more channels(BxHxWxC)
55
+ _, _, _, channels = imgs.shape
56
+ splitted = view_as_windows(imgs, (1, ROIs[0],ROIs[1], channels), (1, step[0], step[1], channels))
57
+ return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
58
+
59
+ def join_blocks(splitted, final_shape):
60
+ """Join blocks to reobtain a splitted image
61
+
62
+ Args:
63
+ splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
64
+ final_shape (tuple) = size of the final image reconstructed (Height, Width)
65
+ Return:
66
+ tensor: image restored from blocks. size = (Channels, Height, Width)
67
+
68
+ """
69
+ n_blocks, channels, ROI_height, ROI_width = splitted.shape
70
+
71
+ rows = final_shape[0] // ROI_height
72
+ columns = final_shape[1] // ROI_width
73
+
74
+ final_img = torch.empty(rows, channels, ROI_height, ROI_width*columns)
75
+ for r in range(rows):
76
+ stackblocks = splitted[r*columns]
77
+ for c in range(1, columns):
78
+ stackblocks = torch.cat((stackblocks, splitted[r*columns+c]), axis=2)
79
+ final_img[r] = stackblocks
80
+
81
+ joined_img = final_img[0]
82
+
83
+ for i in np.arange(1, len(final_img)):
84
+ joined_img = torch.cat((joined_img,final_img[i]), axis = 1)
85
+
86
+ return joined_img
87
+
88
+ def random_ROI(X, Y, ROIs = (512,512)):
89
+ """ Return a random region for each input/target pair images of the dataset
90
+ Args:
91
+ Y (ndarray): target of your dataset --> size: (BxHxWxC)
92
+ X (ndarray): input of your dataset --> size: (BxHxWxC)
93
+ ROIs (tuple): size of the random region (ROIs = regions of interest)
94
+
95
+ Returns:
96
+ For each input/target image pair of the dataset, return the corresponding random ROI:
97
+ Y_cut (ndarray): target of your dataset --> size: (Batch, ROIs[0], ROIs[1], Channels)
98
+ X_cut (ndarray): input of your dataset --> size: (Batch, ROIs[0], ROIs[1], Channels)
99
+
100
+ Example:
101
+ >>> from utils.splitting import random_ROI
102
+ >>> X,Y = random_ROI(X,Y, ROIs = (10,10))
103
+ """
104
+
105
+ batch, height, width, channels = X.shape  # inputs are channels-last (BxHxWxC)
106
+
107
+ X_cut=np.empty((batch, ROIs[0], ROIs[1], channels))
108
+ Y_cut=np.empty((batch, ROIs[0], ROIs[1], channels))
109
+
110
+ for i in np.arange(len(X)):
111
+ x_size=int(random.random()*(height-(ROIs[0]+1)))
112
+ y_size=int(random.random()*(width-(ROIs[1]+1)))
113
+ X_cut[i]=X[i, x_size:x_size+ROIs[0],y_size:y_size+ROIs[1], :]
114
+ Y_cut[i]=Y[i, x_size:x_size+ROIs[0],y_size:y_size+ROIs[1], :]
115
+ return X_cut, Y_cut
116
+
117
+ def one2many_random_ROI(X, Y, datasize=1000, ROIs = (512,512)):
118
+ """ Return a dataset of N subimages obtained from random regions of the same image
119
+ Args:
120
+ Y (ndarray): target of your dataset --> size: (1,H,W,C)
121
+ X (ndarray): input of your dataset --> size: (1,H,W,C)
122
+ datasize = number of random ROIs to generate
123
+ ROIs (tuple): size of random region (ROIs=region of interests)
124
+
125
+ Returns:
126
+ Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
127
+ X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
128
+ """
129
+
130
+ batch, height, width, channels = X.shape  # inputs are channels-last (1xHxWxC)
131
+
132
+ X_cut=np.empty((datasize, ROIs[0], ROIs[1], channels))
133
+ Y_cut=np.empty((datasize, ROIs[0], ROIs[1], channels))
134
+
135
+ for i in np.arange(datasize):
136
+ x_roi, y_roi = random_ROI(X, Y, ROIs)  # returns batch arrays of shape (1, ROIs[0], ROIs[1], C)
+ X_cut[i], Y_cut[i] = x_roi[0], y_roi[0]
137
+ return X_cut, Y_cut
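A round-trip sketch tying split_img and join_blocks together (assuming the module is importable as utils.splitting; the image size and ROI choices are illustrative, and step must equal ROIs for a lossless reconstruction):

# Sketch: split an HxWxC image into non-overlapping tiles and rebuild it.
import numpy as np
import torch
from utils.splitting import split_img, join_blocks

img = np.arange(6 * 8 * 3, dtype=np.float32).reshape(6, 8, 3)   # HxWxC test image

blocks = split_img(img, ROIs=(2, 2), step=(2, 2))               # (12, 2, 2, 3), channels-last tiles
blocks_t = torch.from_numpy(blocks).permute(0, 3, 1, 2)         # join_blocks expects (N, C, H, W)

restored = join_blocks(blocks_t, (6, 8))                        # back to (C, 6, 8)
print(torch.allclose(restored, torch.from_numpy(img).permute(2, 0, 1)))   # True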