Luis Oala
committed on
Commit
·
d9c7582
Parent(s):
fix aws access
- .gitattributes +1 -0
- ABtesting.py +806 -0
- README.md +26 -0
- figure1.sh +7 -0
- figure2.sh +4 -0
- figures.py +92 -0
- models/classifier.py +281 -0
- perturbed-environment.yml +363 -0
- processingpipeline/numpy_static_pipeline_show.ipynb +3 -0
- processingpipeline/pipeline.py +329 -0
- processingpipeline/torch_pipeline.py +313 -0
- sanity_checks_and_statistics.ipynb +3 -0
- show_classification_results.ipynb +3 -0
- show_results.sh +16 -0
- train.py +420 -0
- train.sh +81 -0
- utils/Cperturb.py +475 -0
- utils/augmentation.py +132 -0
- utils/base.py +330 -0
- utils/dataset.py +622 -0
- utils/debug.py +371 -0
- utils/mutual_entropy.py +193 -0
- utils/pytorch_ssim.py +75 -0
- utils/show_dataset.ipynb +3 -0
- utils/splitting.py +137 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
*.ipynb filter=lfs diff=lfs merge=lfs -text
ABtesting.py
ADDED
@@ -0,0 +1,806 @@
import os
import argparse
import json

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
import torch.nn.functional as F

from utils.dataset import get_dataset, Subset
from utils.base import get_mlflow_model_by_name, SmartFormatter
from processingpipeline.pipeline import RawProcessingPipeline

from utils.Cperturb import Distortions

import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt

parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)

# Select experiment
parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
                    help='R|Choose operation to compute. \n'
                    'A) Lens2Logit image generation: \n '
                    'ABMakeTable: Compute cross-validation metrics results \n '
                    'ABShowTable: Plot cross-validation results on a table \n '
                    'ABShowImages: Choose a training and testing image to compare different pipelines \n '
                    'ABShowAllImages: Plot all possible pipelines \n'
                    'B) Hendrycks Perturbations, C-type dataset: \n '
                    'CMakeTable: For each pipeline, compute cross-validation metrics for different perturbations \n '
                    'CShowTable: Plot metrics for different pipelines and perturbations \n '
                    'CShowImages: Plot an image with a selected pipeline and perturbation \n '
                    'CShowAllImages: Plot all possible perturbations for a fixed pipeline')

parser.add_argument("--dataset_name", type=str, default='Microscopy', choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
parser.add_argument("--augmentation", type=str, default='weak', choices=['none', 'weak', 'strong'], help='Choose augmentation')
parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splits used in the training')
parser.add_argument("--download_model", default=False, action='store_true', help='Download models into the cache')

# Select pipelines
parser.add_argument("--dm_train", type=str, default='bilinear', choices=('bilinear', 'malvar2004', 'menon2007'), help='Choose demosaicing for training processing model')
parser.add_argument("--s_train", type=str, default='sharpening_filter', choices=('sharpening_filter', 'unsharp_masking'), help='Choose sharpening for training processing model')
parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices=('gaussian_denoising', 'median_denoising'), help='Choose denoising for training processing model')
parser.add_argument("--dm_test", type=str, default='bilinear', choices=('bilinear', 'malvar2004', 'menon2007'), help='Choose demosaicing for testing processing model')
parser.add_argument("--s_test", type=str, default='sharpening_filter', choices=('sharpening_filter', 'unsharp_masking'), help='Choose sharpening for testing processing model')
parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices=('gaussian_denoising', 'median_denoising'), help='Choose denoising for testing processing model')

# Select Ctest parameters
parser.add_argument("--transform", type=str, default='identity', choices=('identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
                    'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
parser.add_argument("--severity", type=int, default=1, choices=(1, 2, 3, 4, 5), help='Choose severity for Ctesting')

args = parser.parse_args()

class metrics:
    def __init__(self, confusion_matrix):
        self.cm = confusion_matrix
        self.N_classes = len(confusion_matrix)

    def accuracy(self):
        Tp = torch.diagonal(self.cm, 0).sum()
        N_elements = torch.sum(self.cm)
        return Tp / N_elements

    def precision(self):
        # rows are true labels, columns are predictions, so Tp+Fp is a column sum
        Tp_Fp = torch.sum(self.cm, 0)
        Tp_Fp[Tp_Fp == 0] = 1
        return torch.diagonal(self.cm, 0) / Tp_Fp

    def recall(self):
        # Tp+Fn is a row sum (all samples whose true class is the row class)
        Tp_Fn = torch.sum(self.cm, 1)
        Tp_Fn[Tp_Fn == 0] = 1
        return torch.diagonal(self.cm, 0) / Tp_Fn

    def f1_score(self):
        prod = (self.precision() * self.recall())
        total = (self.precision() + self.recall())
        total[total == 0.] = 1.
        return 2 * (prod / total)

    @staticmethod
    def over_N_runs(ms, N_runs):
        m, m2 = 0, 0

        for i in ms:
            m += i
        mu = m / N_runs

        for i in ms:
            m2 += (i - mu)**2

        sigma = torch.sqrt(m2 / (N_runs - 1))

        return mu.tolist(), sigma.tolist()

class ABtesting:
    def __init__(self,
                 dataset_name: str,
                 augmentation: str,
                 dm_train: str,
                 s_train: str,
                 dn_train: str,
                 dm_test: str,
                 s_test: str,
                 dn_test: str,
                 N_runs: int,
                 severity=1,
                 transform='identity',
                 download_model=False):
        self.experiment_name = 'ABtesting'
        self.dataset_name = dataset_name
        self.augmentation = augmentation
        self.dm_train = dm_train
        self.s_train = s_train
        self.dn_train = dn_train
        self.dm_test = dm_test
        self.s_test = s_test
        self.dn_test = dn_test
        self.N_runs = N_runs
        self.severity = severity
        self.transform = transform
        self.download_model = download_model

    def static_pip_val(self, debayer=None, sharpening=None, denoising=None, severity=None, transform=None, plot_mode=False):

        if debayer is None:
            debayer = self.dm_test
        if sharpening is None:
            sharpening = self.s_test
        if denoising is None:
            denoising = self.dn_test
        if severity is None:
            severity = self.severity
        if transform is None:
            transform = self.transform

        dataset = get_dataset(self.dataset_name)

        if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
            mean = torch.tensor([0.35, 0.36, 0.35])
            std = torch.tensor([0.12, 0.11, 0.12])
        elif self.dataset_name == "Microscopy":
            mean = torch.tensor([0.91, 0.84, 0.94])
            std = torch.tensor([0.08, 0.12, 0.05])

        if not plot_mode:
            dataset.transform = Compose([RawProcessingPipeline(
                camera_parameters=dataset.camera_parameters,
                debayer=debayer,
                sharpening=sharpening,
                denoising=denoising,
            ), Distortions(severity=severity, transform=transform),
                Normalize(mean, std)])
        else:
            dataset.transform = Compose([RawProcessingPipeline(
                camera_parameters=dataset.camera_parameters,
                debayer=debayer,
                sharpening=sharpening,
                denoising=denoising,
            ), Distortions(severity=severity, transform=transform)])

        return dataset
    def ABclassification(self):

        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

        parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"

        print(f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
        print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')

        accuracies, precisions, recalls, f1_scores = [], [], [], []

        os.system('rm -r /tmp/py*')

        for N_run in range(self.N_runs):

            print(f"Evaluating Run {N_run}")

            run_name = parent_run_name + '_' + str(N_run)

            state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
                                                         download_model=self.download_model)

            dataset = self.static_pip_val()
            valid_set = Subset(dataset, indices=state_dict['valid_indices'])
            valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)

            model.eval()

            len_classes = len(dataset.classes)
            confusion_matrix = torch.zeros((len_classes, len_classes))

            for img, label in valid_loader:

                prediction = model(img.to(DEVICE)).detach().cpu()
                prediction = torch.argmax(prediction, dim=1)
                confusion_matrix[label, prediction] += 1  # real values on rows, declared values on columns

            m = metrics(confusion_matrix)

            accuracies.append(m.accuracy())
            precisions.append(m.precision())
            recalls.append(m.recall())
            f1_scores.append(m.f1_score())

            os.system('rm -r /tmp/t*')

        accuracy = metrics.over_N_runs(accuracies, self.N_runs)
        precision = metrics.over_N_runs(precisions, self.N_runs)
        recall = metrics.over_N_runs(recalls, self.N_runs)
        f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
        return dataset.classes, accuracy, precision, recall, f1_score

    def ABsegmentation(self):

        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

        parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"

        print(f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
        print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')

        IoUs = []

        os.system('rm -r /tmp/py*')

        for N_run in range(self.N_runs):

            print(f"Evaluating Run {N_run}")

            run_name = parent_run_name + '_' + str(N_run)

            state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
                                                         download_model=self.download_model)

            dataset = self.static_pip_val()

            valid_set = Subset(dataset, indices=state_dict['valid_indices'])
            valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)

            model.eval()

            IoU = 0

            for img, label in valid_loader:

                prediction = model(img.to(DEVICE)).detach().cpu()
                prediction = F.logsigmoid(prediction).exp().squeeze()
                IoU += smp.utils.metrics.IoU()(prediction, label)

            IoU = IoU / len(valid_loader)
            IoUs.append(IoU.item())

            os.system('rm -r /tmp/t*')

        IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
        return IoU
    def ABShowImages(self):

        path = 'results/ABtesting/imgs/'
        if not os.path.exists(path):
            os.makedirs(path)

        path = os.path.join(path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')

        if not os.path.exists(path):
            os.makedirs(path)

        run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}" + '_' + str(0)

        state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)

        model.augmentation = None

        for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
                  [self.dm_test, self.s_test, self.dn_test, 'test_img']):

            debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]

            dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
            valid_set = Subset(dataset, indices=state_dict['valid_indices'])

            img, _ = next(iter(valid_set))

            plt.figure()
            plt.imshow(img.permute(1, 2, 0))
            if img_type == 'train_img':
                plt.title('Train Image')
                plt.savefig(os.path.join(path, 'img_train.png'))
                imgA = img
            else:
                plt.title('Test Image')
                plt.savefig(os.path.join(path, 'img_test.png'))

            for c, color in enumerate(['Red', 'Green', 'Blue']):
                diff = torch.abs(imgA - img)
                plt.figure()
                # plt.imshow(diff.permute(1,2,0))
                plt.imshow(diff[c, 50:200, 50:200], cmap=f'{color}s')
                plt.title(f'|Train Image - Test Image| - {color}')
                plt.colorbar()
                plt.savefig(os.path.join(path, f'diff_{color}.png'))
                plt.figure()
                diff[diff == 0.] = 1e-5
                # plt.imshow(torch.log(diff.permute(1,2,0)))
                plt.imshow(torch.log(diff)[c])
                plt.title(f'log(|Train Image - Test Image|) - {color}')
                plt.colorbar()
                plt.savefig(os.path.join(path, f'logdiff_{color}.png'))

            if self.dataset_name == 'DroneSegmentation':
                plt.figure()
                plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
                if img_type == 'train_img':
                    plt.savefig(os.path.join(path, 'mask_train.png'))
                else:
                    plt.savefig(os.path.join(path, 'mask_test.png'))
    def ABShowAllImages(self):
        if not os.path.exists('results/ABtesting'):
            os.makedirs('results/ABtesting')

        demosaicings = ['bilinear', 'malvar2004', 'menon2007']
        sharpenings = ['sharpening_filter', 'unsharp_masking']
        denoisings = ['median_denoising', 'gaussian_denoising']

        fig = plt.figure()
        columns = 4
        rows = 3

        i = 1

        for dm in demosaicings:
            for s in sharpenings:
                for dn in denoisings:

                    dataset = self.static_pip_val(debayer=dm, sharpening=s,
                                                  denoising=dn, plot_mode=True)

                    img, _ = dataset[0]

                    fig.add_subplot(rows, columns, i)
                    plt.imshow(img.permute(1, 2, 0))
                    plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
                    plt.xticks([])
                    plt.yticks([])
                    plt.tight_layout()

                    i += 1

        plt.savefig('results/ABtesting/ABpipelines.png')
        plt.show()

    def CShowImages(self):

        path = 'results/Ctesting/imgs/'
        if not os.path.exists(path):
            os.makedirs(path)

        run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}" + '_' + str(0)

        state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)

        model.augmentation = None

        dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test, self.severity, self.transform, plot_mode=True)
        valid_set = Subset(dataset, indices=state_dict['valid_indices'])

        img, _ = next(iter(valid_set))

        plt.figure()
        plt.imshow(img.permute(1, 2, 0))
        plt.savefig(os.path.join(path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}'))

    def CShowAllImages(self):
        if not os.path.exists('results/Cimages'):
            os.makedirs('results/Cimages')

        transforms = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
                      'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']

        for i, t in enumerate(transforms):

            fig = plt.figure(figsize=(10, 6))
            columns = 5
            rows = 1

            for sev in range(1, 6):

                dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)

                img, _ = dataset[0]

                fig.add_subplot(rows, columns, sev)
                plt.imshow(img.permute(1, 2, 0))
                plt.title(f'Severity: {sev}')
                plt.xticks([])
                plt.yticks([])
                plt.tight_layout()

            if '_' in t:
                t = t.replace('_', ' ')
            t = t[0].upper() + t[1:]

            fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
            plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
            plt.show()

def ABMakeTable(dataset_name: str, augmentation: str,
                N_runs: int, download_model: bool):

    demosaicings = ['bilinear', 'malvar2004', 'menon2007']
    sharpenings = ['sharpening_filter', 'unsharp_masking']
    denoisings = ['median_denoising', 'gaussian_denoising']

    path = 'results/ABtesting/tables'
    if not os.path.exists(path):
        os.makedirs(path)

    runs = {}
    i = 0

    for dm_train in demosaicings:
        for s_train in sharpenings:
            for dn_train in denoisings:
                for dm_test in demosaicings:
                    for s_test in sharpenings:
                        for dn_test in denoisings:
                            train_pip = [dm_train, s_train, dn_train]
                            test_pip = [dm_test, s_test, dn_test]
                            runs[f'run{i}'] = {
                                'dataset': dataset_name,
                                'augmentation': augmentation,
                                'train_pip': train_pip,
                                'test_pip': test_pip,
                                'N_runs': N_runs
                            }
                            ABclass = ABtesting(
                                dataset_name=dataset_name,
                                augmentation=augmentation,
                                dm_train=dm_train,
                                s_train=s_train,
                                dn_train=dn_train,
                                dm_test=dm_test,
                                s_test=s_test,
                                dn_test=dn_test,
                                N_runs=N_runs,
                                download_model=download_model
                            )

                            if dataset_name == 'DroneSegmentation':
                                IoU = ABclass.ABsegmentation()
                                runs[f'run{i}']['IoU'] = IoU
                            else:
                                classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
                                runs[f'run{i}']['classes'] = classes
                                runs[f'run{i}']['accuracy'] = accuracy
                                runs[f'run{i}']['precision'] = precision
                                runs[f'run{i}']['recall'] = recall
                                runs[f'run{i}']['f1_score'] = f1_score

                            with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
                                json.dump(runs, outfile)

                            i += 1

def ABShowTable(dataset_name: str, augmentation: str):

    path = 'results/ABtesting/tables'
    assert os.path.exists(path), 'No tables to plot'

    json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')

    with open(json_file, 'r') as run_file:
        runs = json.load(run_file)

    metrics = torch.zeros((2, 12, 12))
    classes = []

    i, j = 0, 0

    for r in range(len(runs)):

        run = runs['run' + str(r)]
        if dataset_name == 'DroneSegmentation':
            acc = run['IoU']
        else:
            acc = run['accuracy']
        if len(classes) < 12:
            class_list = run['test_pip']
            class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
            classes.append(class_name)
        mu, sigma = round(acc[0], 4), round(acc[1], 4)

        metrics[0, j, i] = mu
        metrics[1, j, i] = sigma

        i += 1

        if i == 12:
            i = 0
            j += 1

    differences = torch.zeros_like(metrics)

    diag_mu = torch.diagonal(metrics[0], 0)
    diag_sigma = torch.diagonal(metrics[1], 0)

    for r in range(len(metrics[0])):
        differences[0, r] = diag_mu[r] - metrics[0, r]
        differences[1, r] = torch.sqrt(metrics[1, r]**2 + diag_sigma[r]**2)

    # Plot with scatter

    for i, img in enumerate([metrics, differences]):

        x, y = torch.arange(12), torch.arange(12)
        x, y = torch.meshgrid(x, y)

        if i == 0:
            vmin = max(0.65, round(img[0].min().item(), 2))
            vmax = round(img[0].max().item(), 2)
            step = 0.02
        elif i == 1:
            vmin = round(img[0].min().item(), 2)
            if augmentation == 'none':
                vmax = min(0.15, round(img[0].max().item(), 2))
            if augmentation == 'weak':
                vmax = min(0.08, round(img[0].max().item(), 2))
            if augmentation == 'strong':
                vmax = min(0.05, round(img[0].max().item(), 2))
            step = 0.01

        vmin = int(vmin / step) * step
        vmax = int(vmax / step) * step

        fig = plt.figure(figsize=(10, 6.2))
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        marker_size = 350
        plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0., vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
        ticks = torch.arange(0., img[1].max(), 0.03).tolist()
        ticks = [round(tick, 2) for tick in ticks]
        cba = plt.colorbar(pad=0.06)
        cba.set_ticks(ticks)
        cba.ax.set_yticklabels(ticks)
        # cmap = plt.cm.get_cmap('tab20c').reversed()
        cmap = plt.cm.get_cmap('Reds')
        plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin, vmax=vmax, cmap=cmap, s=marker_size, marker='s')
        ticks = torch.arange(vmin, vmax, step).tolist()
        ticks = [round(tick, 2) for tick in ticks]
        if ticks[-1] != vmax:
            ticks.append(vmax)
        cbb = plt.colorbar(pad=0.06)
        cbb.set_ticks(ticks)
        if i == 0:
            ticks[0] = f'<{str(ticks[0])}'
        elif i == 1:
            ticks[-1] = f'>{str(ticks[-1])}'
        cbb.ax.set_yticklabels(ticks)
        for x in range(12):
            for y in range(12):
                txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
                if str(txt) == '-0.0':
                    txt = '0.00'
                elif str(txt) == '0.0':
                    txt = '0.00'
                elif len(str(txt)) == 3:
                    txt = str(txt) + '0'
                else:
                    txt = str(txt)

                plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')

        ax.set_xticks(torch.linspace(0, 11, 12))
        ax.set_xticklabels(classes)
        ax.set_yticks(torch.linspace(0, 11, 12))
        classes.reverse()
        ax.set_yticklabels(classes)
        classes.reverse()
        plt.xticks(rotation=45)
        plt.yticks(rotation=45)
        cba.set_label('Standard Deviation')
        plt.xlabel("Test pipelines")
        plt.ylabel("Train pipelines")
        plt.title(f'Dataset: {dataset_name}, Augmentation: {augmentation}')
        if i == 0:
            if dataset_name == 'DroneSegmentation':
                cbb.set_label('IoU')
                plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
            else:
                cbb.set_label('Accuracy')
                plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
        elif i == 1:
            if dataset_name == 'DroneSegmentation':
                cbb.set_label('IoU_d - IoU')
            else:
                cbb.set_label('Accuracy_d - Accuracy')
            plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_differences.png"))

def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):

    path = 'results/Ctesting/tables'
    if not os.path.exists(path):
        os.makedirs(path)

    demosaicings = ['bilinear', 'malvar2004', 'menon2007']
    sharpenings = ['sharpening_filter', 'unsharp_masking']
    denoisings = ['median_denoising', 'gaussian_denoising']

    transformations = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
                       'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']

    runs = {}
    i = 0

    for dm in demosaicings:
        for s in sharpenings:
            for dn in denoisings:
                for t in transformations:
                    pip = [dm, s, dn]
                    runs[f'run{i}'] = {
                        'dataset': dataset_name,
                        'augmentation': augmentation,
                        'pipeline': pip,
                        'N_runs': N_runs,
                        'transform': t,
                        'severity': severity,
                    }
                    ABclass = ABtesting(
                        dataset_name=dataset_name,
                        augmentation=augmentation,
                        dm_train=dm,
                        s_train=s,
                        dn_train=dn,
                        dm_test=dm,
                        s_test=s,
                        dn_test=dn,
                        severity=severity,
                        transform=t,
                        N_runs=N_runs,
                        download_model=download_model
                    )

                    if dataset_name == 'DroneSegmentation':
                        IoU = ABclass.ABsegmentation()
                        runs[f'run{i}']['IoU'] = IoU
                    else:
                        classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
                        runs[f'run{i}']['classes'] = classes
                        runs[f'run{i}']['accuracy'] = accuracy
                        runs[f'run{i}']['precision'] = precision
                        runs[f'run{i}']['recall'] = recall
                        runs[f'run{i}']['f1_score'] = f1_score

                    with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
                        json.dump(runs, outfile)

                    i += 1

def CShowTable(dataset_name, augmentation):

    path = 'results/Ctesting/tables'
    assert os.path.exists(path), 'No tables to plot'

    json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.json')

    transforms = ['identity', 'gauss_noise', 'shot', 'impulse', 'speckle',
                  'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']

    pip = []

    demosaicings = ['bilinear', 'malvar2004', 'menon2007']
    sharpenings = ['sharpening_filter', 'unsharp_masking']
    denoisings = ['median_denoising', 'gaussian_denoising']

    for dm in demosaicings:
        for s in sharpenings:
            for dn in denoisings:
                pip.append(f'{dm[:2]},{s[0]},{dn[:2]}')

    with open(json_file, 'r') as run_file:
        runs = json.load(run_file)

    metrics = torch.zeros((2, len(pip), len(transforms)))

    i, j = 0, 0

    for r in range(len(runs)):

        run = runs['run' + str(r)]
        if dataset_name == 'DroneSegmentation':
            acc = run['IoU']
        else:
            acc = run['accuracy']
        mu, sigma = round(acc[0], 4), round(acc[1], 4)

        metrics[0, j, i] = mu
        metrics[1, j, i] = sigma

        i += 1

        if i == len(transforms):
            i = 0
            j += 1

    # Plot with scatter

    img = metrics

    vmin = 0.
    vmax = 1.
    step = 0.1  # tick spacing for the metric colorbar (was undefined in the original)

    x, y = torch.arange(len(pip)), torch.arange(len(transforms))
    x, y = torch.meshgrid(x, y)

    fig = plt.figure(figsize=(10, 6.2))
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    marker_size = 350
    plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0., vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
    ticks = torch.arange(0., img[1].max(), 0.03).tolist()
    ticks = [round(tick, 2) for tick in ticks]
    cba = plt.colorbar(pad=0.06)
    cba.set_ticks(ticks)
    cba.ax.set_yticklabels(ticks)
    # cmap = plt.cm.get_cmap('tab20c').reversed()
    cmap = plt.cm.get_cmap('Reds')
    plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin, vmax=vmax, cmap=cmap, s=marker_size, marker='s')
    ticks = torch.arange(vmin, vmax, step).tolist()
    ticks = [round(tick, 2) for tick in ticks]
    if ticks[-1] != vmax:
        ticks.append(vmax)
    cbb = plt.colorbar(pad=0.06)
    cbb.set_ticks(ticks)
    cbb.ax.set_yticklabels(ticks)
    for x in range(len(transforms)):
        for y in range(len(pip)):
            txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
            if str(txt) == '-0.0':
                txt = '0.00'
            elif str(txt) == '0.0':
                txt = '0.00'
            elif len(str(txt)) == 3:
                txt = str(txt) + '0'
            else:
                txt = str(txt)

            plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')

    ax.set_xticks(torch.linspace(0, len(transforms) - 1, len(transforms)))
    ax.set_xticklabels(transforms)
    ax.set_yticks(torch.linspace(0, len(pip) - 1, len(pip)))
    pip.reverse()
    ax.set_yticklabels(pip)
    pip.reverse()
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    cba.set_label('Standard Deviation')
    plt.xlabel("Distortions")
    plt.ylabel("Pipelines")
    if dataset_name == 'DroneSegmentation':
        cbb.set_label('IoU')
        plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
    else:
        cbb.set_label('Accuracy')
        plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))

if __name__ == '__main__':

    if args.mode == 'ABMakeTable':
        ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
    elif args.mode == 'ABShowTable':
        ABShowTable(args.dataset_name, args.augmentation)
    elif args.mode == 'ABShowImages':
        ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
                            args.s_train, args.dn_train, args.dm_test, args.s_test,
                            args.dn_test, args.N_runs, download_model=args.download_model)
        ABclass.ABShowImages()
    elif args.mode == 'ABShowAllImages':
        ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
                            args.s_train, args.dn_train, args.dm_test, args.s_test,
                            args.dn_test, args.N_runs, download_model=args.download_model)
        ABclass.ABShowAllImages()
    elif args.mode == 'CMakeTable':
        CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
    elif args.mode == 'CShowTable':  # TODO: test it
        CShowTable(args.dataset_name, args.augmentation)
    elif args.mode == 'CShowImages':
        ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
                            args.s_train, args.dn_train, args.dm_test, args.s_test,
                            args.dn_test, args.N_runs, args.severity, args.transform,
                            download_model=args.download_model)
        ABclass.CShowImages()
    elif args.mode == 'CShowAllImages':
        ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
                            args.s_train, args.dn_train, args.dm_test, args.s_test,
                            args.dn_test, args.N_runs, args.severity, args.transform,
                            download_model=args.download_model)
        ABclass.CShowAllImages()
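A quick sanity check of the metrics class above, on a made-up 2x2 confusion matrix (toy numbers; rows are true labels and columns are predictions, matching how ABclassification fills the matrix, and assuming the class definition is in scope):

    import torch

    cm = torch.tensor([[8., 2.],
                       [1., 9.]])  # toy confusion matrix for two classes
    m = metrics(cm)
    print(m.accuracy())   # (8 + 9) / 20 = 0.85
    print(m.precision())  # diagonal / column sums -> [8/9, 9/11]
    print(m.recall())     # diagonal / row sums -> [0.80, 0.90]
    print(m.f1_score())   # 2 * precision * recall / (precision + recall), elementwise

    # over_N_runs aggregates a metric across cross-validation runs and
    # returns (mean, sample standard deviation):
    mu, sigma = metrics.over_N_runs([m.accuracy() for _ in range(3)], 3)

From the command line the same machinery is driven through the --mode flag, e.g. `python ABtesting.py --mode ABMakeTable --dataset_name Microscopy --augmentation weak --N_runs 5` to compute the cross-validation table, followed by `python ABtesting.py --mode ABShowTable --dataset_name Microscopy --augmentation weak` to plot it.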
README.md
ADDED
@@ -0,0 +1,26 @@
# Perturbed Minds

## Conda environment and dependencies

To make running this code easier, you can install the latest conda environment for this project, stored in `perturbed-environment.yml`.

### Install environment from `perturbed-environment.yml`

To install the latest conda environment, run

`conda env create -f perturbed-environment.yml`

### Install the newest version of segmentation_models_pytorch

The PyPI version is not up to date with the GitHub version and lacks features:

`python -m pip install git+https://github.com/qubvel/segmentation_models.pytorch`

### Update `perturbed-environment.yml`

If you add code that requires new packages, run the following inside your perturbed-minds conda environment:

`conda env export > perturbed-environment.yml`

## Walk-through
Link to the repository structure we put down in Miro: https://miro.com/app/board/o9J_lQdgyf8=/
figure1.sh
ADDED
@@ -0,0 +1,7 @@
python figures.py \
    --experiment_name track-test \
    --run_name track-all \
    --representation gradients \
    --step gamma_correct \
    --gif_name gradient \
    --output gif
figure2.sh
ADDED
@@ -0,0 +1,4 @@
python figures.py \
    --experiment_name track-test \
    --run_name track-all \
    --output train_vs_val_loss
figures.py
ADDED
@@ -0,0 +1,92 @@
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
import argparse
# gif
import os
import pathlib
import shutil
import imageio
# plot
import matplotlib.pyplot as plt
import numpy as np

# -1. parse args
parser = argparse.ArgumentParser(description="results_analysis")
parser.add_argument("--tracking_uri", type=str,
                    default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
parser.add_argument("--experiment_name", type=str, default=None,
                    help='Name of the experiment on the mlflow server, e.g. "processing_comparison"')
parser.add_argument("--run_name", type=str, default=None,
                    help='Name of the run on the mlflow server, e.g. "proc_nn"')
parser.add_argument("--representation", type=str, default=None,
                    choices=["processing", "gradients"], help='The representation you want to retrieve ("processing" or "gradients")')
parser.add_argument("--step", type=str, default=None,
                    choices=["pre_debayer", "demosaic", "color_correct", "sharpening", "gaussian", "clipped", "gamma_correct", "rgb"],
                    help='The processing step you want to track ("pre_debayer" or "rgb")')  # TODO: include predictions and ground truths
parser.add_argument("--gif_name", type=str, default=None,
                    help='Name of the gif that will be saved. Note: .gif will be added later by the script')  # TODO: option to include filepath where the result should be written
# TODO: option to write results to an existing run on mlflow
parser.add_argument("--local_dir", type=str, default=None,
                    help='Name of the local dir to be created to store mlflow data')
parser.add_argument("--cleanup", type=bool, default=True,
                    help='Whether to delete the local dir again after the script was run')
parser.add_argument("--output", type=str, default=None,
                    choices=["gif", "train_vs_val_loss"],
                    help='Which output to generate')  # TODO: make this cleaner; at the moment it is confusing because each figure may need a different set of args and it is not clear how to manage that
# TODO: idea -> fix the types of args for each figure, which define the figure type, but parametrize those things that can reasonably vary
args = parser.parse_args()

# 0. mlflow basics
mlflow.set_tracking_uri(args.tracking_uri)

# 1. specify experiment_name, run_name, representation and step
# is done via parse_args

# 2. use get_experiment_by_name to get the experiment object
experiment = mlflow.get_experiment_by_name(args.experiment_name)

# 3. extract experiment_id
# experiment.experiment_id

# 4. use search_runs with experiment_id and run_name for a string search query
filter_string = "tags.mlflow.runName = '{}'".format(args.run_name)  # create the filter string using the runName tag to query mlflow
runs = mlflow.search_runs(experiment.experiment_id, filter_string=filter_string)  # returns a pandas data frame where each row is a run (if several exist under that name)
client = MlflowClient()  # TODO: look more into the options of client

if args.output == "gif":  # TODO: outsource these options to functions which are then loaded and can be called
    # 5. extract run from list
    # TODO: parent run and cv option for analysis
    if args.local_dir:
        local_dir = args.local_dir + "/artifacts"
    else:  # use the current working dir and make a subdir "artifacts" to store the data from mlflow
        local_dir = str(pathlib.Path().resolve()) + "/artifacts"
    if not os.path.isdir(local_dir):
        os.mkdir(local_dir)  # create the local_dir if it does not exist yet  # TODO: more advanced catching of existing files etc.
    dir = client.download_artifacts(runs["run_id"][0], "results", local_dir)  # TODO: parametrize this number [0] so the right run is selected

    # 6. get filenames in chronological sequence and write them to a gif
    dirs = [x[0] for x in os.walk(dir)]
    dirs = sorted(dirs, key=str.lower)[1:]  # sort chronologically and remove the parent dir from the list

    with imageio.get_writer(args.gif_name + '.gif', mode='I') as writer:  # https://imageio.readthedocs.io/en/stable/index.html#
        for epoch in dirs:  # extract the right file from each epoch
            for _, _, files in os.walk(epoch):
                for name in files:
                    if args.representation in name and args.step in name and "png" in name:
                        image = imageio.imread(epoch + "/" + name)
                        writer.append_data(image)

    # 7. cleanup the downloaded artifacts from the client file system
    if args.cleanup:
        shutil.rmtree(local_dir)  # delete the files downloaded from mlflow

elif args.output == "train_vs_val_loss":
    train_loss = client.get_metric_history(runs["run_id"][0], "train_loss")  # returns a list of metric entities https://www.mlflow.org/docs/latest/_modules/mlflow/entities/metric.html
    val_loss = client.get_metric_history(runs["run_id"][0], "val_loss")  # TODO: parametrize this number [0] so the right run is selected
    train_loss = sorted(train_loss, key=lambda m: m.step)  # sort the metric objects in the list according to their step property
    val_loss = sorted(val_loss, key=lambda m: m.step)
    plt.figure()
    for m_train, m_val in zip(train_loss, val_loss):
        plt.scatter(m_train.value, m_val.value, alpha=1 / (m_train.step + 1), color='blue')
    plt.savefig("scatter.png")  # TODO: parametrize filename
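Condensed, the numbered steps above follow one retrieval pattern: resolve the experiment by name, query runs via a runName filter, then pull artifacts or metric histories for a run id. A minimal sketch of that pattern (the tracking URI and names below are placeholders, not the project's actual values):

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri("http://localhost:5000")  # placeholder tracking server

    experiment = mlflow.get_experiment_by_name("processing_comparison")
    runs = mlflow.search_runs([experiment.experiment_id],
                              filter_string="tags.mlflow.runName = 'proc_nn'")
    run_id = runs["run_id"][0]  # first match; see the TODO above about selecting the right run

    client = MlflowClient()
    history = client.get_metric_history(run_id, "train_loss")  # list of Metric entities
    values = [m.value for m in sorted(history, key=lambda m: m.step)]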
models/classifier.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.optim
|
6 |
+
from torchvision.models import resnet18
|
7 |
+
from torchvision.utils import make_grid, save_image
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
|
12 |
+
import mlflow.pytorch
|
13 |
+
|
14 |
+
|
15 |
+
def resnet_model(model=resnet18, pretrained=True, in_channels=3, fc_out_features=2):
|
16 |
+
resnet = model(pretrained=pretrained)
|
17 |
+
# if not pretrained: # TODO: add case for in_channels=4
|
18 |
+
# resnet.conv1 = torch.nn.Conv2d(channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
19 |
+
resnet.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
|
20 |
+
return resnet
|
21 |
+
|
22 |
+
|
23 |
+
class LitModel(pl.LightningModule):
|
24 |
+
|
25 |
+
def __init__(self,
|
26 |
+
classifier,
|
27 |
+
loss,
|
28 |
+
lr=1e-3,
|
29 |
+
weight_decay=0,
|
30 |
+
loss_aux=None,
|
31 |
+
adv_training=False,
|
32 |
+
metrics=None,
|
33 |
+
processor=None,
|
34 |
+
augmentation=None,
|
35 |
+
is_segmentation_task=False,
|
36 |
+
augmentation_on_eval=False,
|
37 |
+
metrics_on_training=True,
|
38 |
+
freeze_classifier=False,
|
39 |
+
freeze_processor=False,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.classifier = classifier
|
44 |
+
self.processor = processor
|
45 |
+
|
46 |
+
self.lr = lr
|
47 |
+
self.weight_decay = weight_decay
|
48 |
+
self.loss_fn = loss
|
49 |
+
self.loss_aux_fn = loss_aux
|
50 |
+
self.adv_training = adv_training
|
51 |
+
self.metrics = metrics
|
52 |
+
self.augmentation = augmentation
|
53 |
+
self.is_segmentation_task = is_segmentation_task
|
54 |
+
self.augmentation_on_eval = augmentation_on_eval
|
55 |
+
self.metrics_on_training = metrics_on_training
|
56 |
+
|
57 |
+
self.freeze_classifier = freeze_classifier
|
58 |
+
self.freeze_processor = freeze_processor
|
59 |
+
|
60 |
+
if freeze_classifier:
|
61 |
+
pl.LightningModule.freeze(self.classifier)
|
62 |
+
if freeze_processor:
|
63 |
+
pl.LightningModule.freeze(self.processor)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
x = self.processor(x)
|
67 |
+
apply_augmentation_step = self.training or self.augmentation_on_eval
|
68 |
+
if self.augmentation is not None and apply_augmentation_step:
|
69 |
+
x = self.augmentation(x, retain_state=self.is_segmentation_task)
|
70 |
+
x = self.classifier(x)
|
71 |
+
return x
|
72 |
+
|
73 |
+
def update_step(self, batch, step_name):
|
74 |
+
x, y = batch
|
75 |
+
# debug(self.processor)
|
76 |
+
# debug(self.processor.parameters())
|
77 |
+
# debug.pause()
|
78 |
+
# print('type', type(self.processor).__name__)
|
79 |
+
|
80 |
+
logits = self(x)
|
81 |
+
|
82 |
+
apply_augmentation_mask = self.is_segmentation_task and (self.training or self.augmentation_on_eval)
|
83 |
+
if self.augmentation is not None and apply_augmentation_mask:
|
84 |
+
y = self.augmentation(y, mask_transform=True).contiguous()
|
85 |
+
|
86 |
+
loss = self.loss_fn(logits, y)
|
87 |
+
|
88 |
+
if self.loss_aux_fn is not None:
|
89 |
+
loss_aux = self.loss_aux_fn(x)
|
90 |
+
loss += loss_aux
|
91 |
+
|
92 |
+
self.log(f'{step_name}_loss', loss, on_step=False, on_epoch=True)
|
93 |
+
if self.loss_aux_fn is not None:
|
94 |
+
self.log(f'{step_name}_loss_aux', loss_aux, on_step=False, on_epoch=True)
|
95 |
+
|
96 |
+
if self.is_segmentation_task:
|
97 |
+
y_hat = F.logsigmoid(logits).exp().squeeze()
|
98 |
+
else:
|
99 |
+
y_hat = torch.argmax(logits, dim=1)
|
100 |
+
|
101 |
+
|
102 |
+
if self.metrics is not None:
|
103 |
+
for metric in self.metrics:
|
104 |
+
metric_name = metric.__name__ if hasattr(metric, '__name__') else type(metric).__name__
|
105 |
+
if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
|
106 |
+
m = metric(y_hat.cpu().detach(), y.cpu())
|
107 |
+
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
108 |
+
prog_bar=self.training or metric_name == 'accuracy')
|
109 |
+
if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
|
110 |
+
m = metric(y_hat.cpu().detach(), y.cpu())
|
111 |
+
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
112 |
+
prog_bar=self.training or metric_name == 'iou_score')
|
113 |
+
|
114 |
+
return loss
|
115 |
+
|
116 |
+
def training_step(self, batch, batch_idx):
|
117 |
+
return self.update_step(batch, 'train')
|
118 |
+
|
119 |
+
def validation_step(self, batch, batch_idx):
|
120 |
+
return self.update_step(batch, 'val')
|
121 |
+
|
122 |
+
def test_step(self, batch, batch_idx):
|
123 |
+
return self.update_step(batch, 'test')
|
124 |
+
|
125 |
+
def train(self, mode=True):
|
126 |
+
self.training = mode
|
127 |
+
# self.processor.train(False)
|
128 |
+
self.processor.train(mode=mode and not self.freeze_processor)
|
129 |
+
self.classifier.train(mode=mode and not self.freeze_classifier)
|
130 |
+
if self.adv_training and self.processor.batch_norm is not None: # don't update batchnorm in adversarial training
|
131 |
+
self.processor.batch_norm.track_running_stats = False
|
132 |
+
return self
|
133 |
+
|
134 |
+
def configure_optimizers(self):
|
135 |
+
self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
|
136 |
+
# parameters = [self.processor.additive_layer]
|
137 |
+
# self.optimizer = torch.optim.Adam(parameters, self.lr, weight_decay=self.weight_decay)
|
138 |
+
return self.optimizer
|
139 |
+
# self.scheduler = {
|
140 |
+
# 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
|
141 |
+
# self.optimizer, mode='min', factor=0.2, patience=2, min_lr=1e-6, verbose=True,
|
142 |
+
# ),
|
143 |
+
# 'monitor': 'val_loss',
|
144 |
+
# }
|
145 |
+
# return [self.optimizer], [self.scheduler]
|
146 |
+
|
147 |
+
def get_progress_bar_dict(self):
|
148 |
+
items = super().get_progress_bar_dict()
|
149 |
+
items.pop('v_num')
|
150 |
+
return items
|
151 |
+
|
152 |
+
|
153 |
+
class TrackImagesCallback(pl.callbacks.base.Callback):
|
154 |
+
def __init__(self, data_loader, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
|
155 |
+
super().__init__()
|
156 |
+
self.data_loader = data_loader
|
157 |
+
|
158 |
+
self.track_every_epoch = track_every_epoch
|
159 |
+
|
160 |
+
self.track_processing = track_processing
|
161 |
+
self.track_gradients = track_gradients
|
162 |
+
self.track_predictions = track_predictions
|
163 |
+
self.save_tensors = save_tensors
|
164 |
+
|
165 |
+
def callback_track_images(self, trainer, save_loc):
|
166 |
+
track_images(trainer.model,
|
167 |
+
self.data_loader,
|
168 |
+
track_processing=self.track_processing,
|
169 |
+
track_gradients=self.track_gradients,
|
170 |
+
track_predictions=self.track_predictions,
|
171 |
+
save_tensors=self.save_tensors,
|
172 |
+
save_loc=save_loc,
|
173 |
+
)
|
174 |
+
|
175 |
+
def on_fit_end(self, trainer, pl_module):
|
176 |
+
if not self.track_every_epoch:
|
177 |
+
save_loc = 'results'
|
178 |
+
self.callback_track_images(trainer, save_loc)
|
179 |
+
|
180 |
+
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
181 |
+
if self.track_every_epoch:
|
182 |
+
save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
|
183 |
+
self.callback_track_images(trainer, save_loc)
|
184 |
+
|
185 |
+
|
186 |
+
from utils.debug import debug
|
187 |
+
|
188 |
+
|
189 |
+
# @debug
|
190 |
+
def log_tensor(batch, path, save_tensors=True, nrow=8):
|
191 |
+
    if save_tensors:
        torch.save(batch, path)
        mlflow.log_artifact(path, os.path.dirname(path))

    img_path = path.replace('.pt', '.png')
    split = img_path.split('/')
    img_path = '/'.join(split[:-1]) + '/img_' + split[-1]  # prefix 'img_' to make it easier to find in mlflow

    grid = make_grid(batch, nrow=nrow).squeeze()
    save_image(grid, img_path)
    mlflow.log_artifact(img_path, os.path.dirname(path))


def track_images(model, data_loader, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):

    device = model.device
    processor = model.processor
    classifier = model.classifier

    if not hasattr(processor, 'stages'):  # 'static' or 'none' pipeline
        return

    os.makedirs(save_loc, exist_ok=True)

    # TODO: implement track_predictions

    # inputs_full = []
    labels_full = []
    logits_full = []
    stages_full = defaultdict(list)
    grads_full = defaultdict(list)

    for inputs, labels in data_loader:

        inputs, labels = inputs.to(device), labels.to(device)
        inputs.requires_grad = True

        processed_rgb = processor(inputs)

        if track_gradients or track_predictions:
            logits = classifier(processed_rgb)

            # NOTE: should zero grads for good measure
            loss = model.loss_fn(logits, labels)
            loss.backward()

            if track_predictions:
                labels_full.append(labels.cpu().detach())
                logits_full.append(logits.cpu().detach())
                # inputs_full.append(inputs.cpu().detach())

        for stage, batch in processor.stages.items():
            stages_full[stage].append(batch.cpu().detach())
            if track_gradients:
                grads_full[stage].append(batch.grad.cpu().detach())

    with torch.no_grad():

        stages = stages_full
        grads = grads_full

        if track_processing:
            for stage, batch in stages_full.items():
                stages[stage] = torch.cat(batch)

        if track_gradients:
            for stage, batch in grads_full.items():
                grads[stage] = torch.cat(batch)

        for stage_nr, stage_name in enumerate(stages):
            if track_processing:
                batch = stages[stage_name]
                log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
            if track_gradients:
                batch_grad = grads[stage_name]
                batch_grad = batch_grad.abs()
                batch_grad = (batch_grad - batch_grad.min()) / (batch_grad.max() - batch_grad.min())
                log_tensor(batch_grad, os.path.join(
                    save_loc, f'gradients_{stage_nr}_{stage_name}.pt'), save_tensors)

        # inputs = torch.cat(inputs_full)

        if track_predictions:  # and model.is_segmentation_task:
            labels = torch.cat(labels_full)
            logits = torch.cat(logits_full)
            masks = labels.unsqueeze(1)
            predictions = logits  # torch.sigmoid(logits).unsqueeze(1)
            # mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
            # log_tensor(mask_vis, os.path.join(save_loc, 'masks.pt'), save_tensors)
            log_tensor(masks, os.path.join(save_loc, 'targets.pt'), save_tensors)
            log_tensor(predictions, os.path.join(save_loc, 'preds.pt'), save_tensors)
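A minimal usage sketch for track_images (a sketch, assuming `model` is a LitModel wrapping a processor and classifier as set up in train.py, and `testset` is any (input, label) dataset; both names are illustrative, not part of this file):

    from torch.utils.data import DataLoader

    # log per-stage images and normalized gradient magnitudes for a small test batch
    test_loader = DataLoader(testset, batch_size=8, shuffle=False)
    track_images(model, test_loader, track_processing=True, track_gradients=True,
                 track_predictions=False, save_tensors=False, save_loc='results')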
perturbed-environment.yml
ADDED
@@ -0,0 +1,363 @@
name: perturbed
channels:
  - defaults
dependencies:
  - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
  - _libgcc_mutex=0.1=main
  - alabaster=0.7.12=py37_0
  - anaconda=2019.10=py37_0
  - anaconda-client=1.7.2=py37_0
  - anaconda-navigator=1.9.7=py37_0
  - anaconda-project=0.8.3=py_0
  - asn1crypto=1.0.1=py37_0
  - astroid=2.3.1=py37_0
  - astropy=3.2.2=py37h7b6447c_0
  - atomicwrites=1.3.0=py37_1
  - attrs=19.2.0=py_0
  - babel=2.7.0=py_0
  - backcall=0.1.0=py37_0
  - backports=1.0=py_2
  - backports.functools_lru_cache=1.5=py_2
  - backports.os=0.1.1=py37_0
  - backports.shutil_get_terminal_size=1.0.0=py37_2
  - backports.tempfile=1.0=py_1
  - backports.weakref=1.0.post1=py_1
  - beautifulsoup4=4.8.0=py37_0
  - bitarray=1.0.1=py37h7b6447c_0
  - bkcharts=0.2=py37_0
  - blas=1.0=mkl
  - bleach=3.1.0=py37_0
  - blosc=1.16.3=hd408876_0
  - bokeh=1.3.4=py37_0
  - boto=2.49.0=py37_0
  - bottleneck=1.2.1=py37h035aef0_1
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2019.8.28=0
  - cairo=1.14.12=h8948797_3
  - certifi=2019.9.11=py37_0
  - cffi=1.12.3=py37h2e261b9_0
  - chardet=3.0.4=py37_1003
  - click=7.0=py37_0
  - cloudpickle=1.2.2=py_0
  - clyent=1.2.2=py37_1
  - colorama=0.4.1=py37_0
  - conda-package-handling=1.6.0=py37h7b6447c_0
  - conda-verify=3.4.2=py_1
  - contextlib2=0.6.0=py_0
  - cryptography=2.7=py37h1ba5d50_0
  - curl=7.65.3=hbc83047_0
  - cycler=0.10.0=py37_0
  - cython=0.29.13=py37he6710b0_0
  - cytoolz=0.10.0=py37h7b6447c_0
  - dask=2.5.2=py_0
  - dask-core=2.5.2=py_0
  - dbus=1.13.6=h746ee38_0
  - decorator=4.4.0=py37_1
  - defusedxml=0.6.0=py_0
  - distributed=2.5.2=py_0
  - docutils=0.15.2=py37_0
  - entrypoints=0.3=py37_0
  - et_xmlfile=1.0.1=py37_0
  - expat=2.2.6=he6710b0_0
  - fastcache=1.1.0=py37h7b6447c_0
  - filelock=3.0.12=py_0
  - flask=1.1.1=py_0
  - fontconfig=2.13.0=h9420a91_0
  - freetype=2.9.1=h8a8886c_1
  - fribidi=1.0.5=h7b6447c_0
  - future=0.17.1=py37_0
  - get_terminal_size=1.0.0=haa9412d_0
  - gevent=1.4.0=py37h7b6447c_0
  - glib=2.56.2=hd408876_0
  - glob2=0.7=py_0
  - gmp=6.1.2=h6c8ec71_1
  - gmpy2=2.0.8=py37h10f8cd9_2
  - graphite2=1.3.13=h23475e2_0
  - greenlet=0.4.15=py37h7b6447c_0
  - gst-plugins-base=1.14.0=hbbd80ab_1
  - gstreamer=1.14.0=hb453b48_1
  - h5py=2.9.0=py37h7918eee_0
  - harfbuzz=1.8.8=hffaf4a1_0
  - hdf5=1.10.4=hb1b8bf9_0
  - heapdict=1.0.1=py_0
  - html5lib=1.0.1=py37_0
  - icu=58.2=h9c2bf20_1
  - idna=2.8=py37_0
  - imageio=2.6.0=py37_0
  - imagesize=1.1.0=py37_0
  - intel-openmp=2019.4=243
  - ipykernel=5.1.2=py37h39e3cac_0
  - ipython=7.8.0=py37h39e3cac_0
  - ipython_genutils=0.2.0=py37_0
  - ipywidgets=7.5.1=py_0
  - isort=4.3.21=py37_0
  - itsdangerous=1.1.0=py37_0
  - jbig=2.1=hdba287a_0
  - jdcal=1.4.1=py_0
  - jedi=0.15.1=py37_0
  - jeepney=0.4.1=py_0
  - jinja2=2.10.3=py_0
  - joblib=0.13.2=py37_0
  - jpeg=9b=h024ee3a_2
  - json5=0.8.5=py_0
  - jsonschema=3.0.2=py37_0
  - jupyter=1.0.0=py37_7
  - jupyter_client=5.3.3=py37_1
  - jupyter_console=6.0.0=py37_0
  - jupyter_core=4.5.0=py_0
  - jupyterlab=1.1.4=pyhf63ae98_0
  - jupyterlab_server=1.0.6=py_0
  - keyring=18.0.0=py37_0
  - kiwisolver=1.1.0=py37he6710b0_0
  - krb5=1.16.1=h173b8e3_7
  - lazy-object-proxy=1.4.2=py37h7b6447c_0
  - libarchive=3.3.3=h5d8350f_5
  - libcurl=7.65.3=h20c2e04_0
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - liblief=0.9.0=h7725739_2
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.16=h1bed415_0
  - libssh2=1.8.2=h1ba5d50_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.0.10=h2733197_2
  - libtool=2.4.6=h7b6447c_5
  - libuuid=1.0.3=h1bed415_2
  - libxcb=1.13=h1bed415_1
  - libxml2=2.9.9=hea5a465_1
  - libxslt=1.1.33=h7d1a2b0_0
  - llvmlite=0.29.0=py37hd408876_0
  - locket=0.2.0=py37_1
  - lxml=4.4.1=py37hefd8a0e_0
  - lz4-c=1.8.1.2=h14c3975_0
  - lzo=2.10=h49e0be7_2
  - markupsafe=1.1.1=py37h7b6447c_0
  - matplotlib=3.1.1=py37h5429711_0
  - mccabe=0.6.1=py37_1
  - mistune=0.8.4=py37h7b6447c_0
  - mkl=2019.4=243
  - mkl-service=2.3.0=py37he904b0f_0
  - mkl_fft=1.0.14=py37ha843d7b_0
  - mkl_random=1.1.0=py37hd6b4f25_0
  - mock=3.0.5=py37_0
  - more-itertools=7.2.0=py37_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.1=hdf1c602_3
  - mpmath=1.1.0=py37_0
  - msgpack-python=0.6.1=py37hfd86e86_1
  - multipledispatch=0.6.0=py37_0
  - navigator-updater=0.2.1=py37_0
  - nbconvert=5.6.0=py37_1
  - nbformat=4.4.0=py37_0
  - ncurses=6.1=he6710b0_1
  - networkx=2.3=py_0
  - nltk=3.4.5=py37_0
  - nose=1.3.7=py37_2
  - notebook=6.0.1=py37_0
  - numba=0.45.1=py37h962f231_0
  - numexpr=2.7.0=py37h9e4a6bb_0
  - numpy=1.17.2=py37haad9e8e_0
  - numpy-base=1.17.2=py37hde5b4d6_0
  - numpydoc=0.9.1=py_0
  - olefile=0.46=py37_0
  - openpyxl=3.0.0=py_0
  - openssl=1.1.1d=h7b6447c_2
  - packaging=19.2=py_0
  - pandoc=2.2.3.2=0
  - pandocfilters=1.4.2=py37_1
  - pango=1.42.4=h049681c_0
  - parso=0.5.1=py_0
  - partd=1.0.0=py_0
  - patchelf=0.9=he6710b0_3
  - path.py=12.0.1=py_0
  - pathlib2=2.3.5=py37_0
  - patsy=0.5.1=py37_0
  - pcre=8.43=he6710b0_0
  - pep8=1.7.1=py37_0
  - pexpect=4.7.0=py37_0
  - pickleshare=0.7.5=py37_0
  - pip=19.2.3=py37_0
  - pixman=0.38.0=h7b6447c_0
  - pkginfo=1.5.0.1=py37_0
  - pluggy=0.13.0=py37_0
  - ply=3.11=py37_0
  - prometheus_client=0.7.1=py_0
  - prompt_toolkit=2.0.10=py_0
  - psutil=5.6.3=py37h7b6447c_0
  - ptyprocess=0.6.0=py37_0
  - py=1.8.0=py37_0
  - py-lief=0.9.0=py37h7725739_2
  - pycodestyle=2.5.0=py37_0
  - pycosat=0.6.3=py37h14c3975_0
  - pycparser=2.19=py37_0
  - pycrypto=2.6.1=py37h14c3975_9
  - pycurl=7.43.0.3=py37h1ba5d50_0
  - pyflakes=2.1.1=py37_0
  - pygments=2.4.2=py_0
  - pylint=2.4.2=py37_0
  - pyodbc=4.0.27=py37he6710b0_0
  - pyopenssl=19.0.0=py37_0
  - pyparsing=2.4.2=py_0
  - pyqt=5.9.2=py37h05f1152_2
  - pyrsistent=0.15.4=py37h7b6447c_0
  - pysocks=1.7.1=py37_0
  - pytables=3.5.2=py37h71ec239_1
  - pytest=5.2.1=py37_0
  - pytest-arraydiff=0.3=py37h39e3cac_0
  - pytest-astropy=0.5.0=py37_0
  - pytest-doctestplus=0.4.0=py_0
  - pytest-openfiles=0.4.0=py_0
  - pytest-remotedata=0.3.2=py37_0
  - python=3.7.4=h265db76_1
  - python-dateutil=2.8.0=py37_0
  - python-libarchive-c=2.8=py37_13
  - pytz=2019.3=py_0
  - pyyaml=5.1.2=py37h7b6447c_0
  - pyzmq=18.1.0=py37he6710b0_0
  - qt=5.9.7=h5867ecd_1
  - qtawesome=0.6.0=py_0
  - qtconsole=4.5.5=py_0
  - qtpy=1.9.0=py_0
  - readline=7.0=h7b6447c_5
  - requests=2.22.0=py37_0
  - ripgrep=0.10.0=hc07d326_0
  - rope=0.14.0=py_0
  - ruamel_yaml=0.15.46=py37h14c3975_0
  - scikit-learn=0.21.3=py37hd81dba3_0
  - scipy=1.3.1=py37h7c811a0_0
  - seaborn=0.9.0=py37_0
  - secretstorage=3.1.1=py37_0
  - send2trash=1.5.0=py37_0
  - setuptools=41.4.0=py37_0
  - simplegeneric=0.8.1=py37_2
  - singledispatch=3.4.0.3=py37_0
  - sip=4.19.8=py37hf484d3e_0
  - six=1.12.0=py37_0
  - snappy=1.1.7=hbae5bb6_3
  - snowballstemmer=2.0.0=py_0
  - sortedcollections=1.1.2=py37_0
  - sortedcontainers=2.1.0=py37_0
  - soupsieve=1.9.3=py37_0
  - sphinx=2.2.0=py_0
  - sphinxcontrib=1.0=py37_1
  - sphinxcontrib-applehelp=1.0.1=py_0
  - sphinxcontrib-devhelp=1.0.1=py_0
  - sphinxcontrib-htmlhelp=1.0.2=py_0
  - sphinxcontrib-jsmath=1.0.1=py_0
  - sphinxcontrib-qthelp=1.0.2=py_0
  - sphinxcontrib-serializinghtml=1.1.3=py_0
  - sphinxcontrib-websupport=1.1.2=py_0
  - spyder=3.3.6=py37_0
  - spyder-kernels=0.5.2=py37_0
  - sqlalchemy=1.3.9=py37h7b6447c_0
  - sqlite=3.30.0=h7b6447c_0
  - statsmodels=0.10.1=py37hdd07704_0
  - sympy=1.4=py37_0
  - tbb=2019.4=hfd86e86_0
  - tblib=1.4.0=py_0
  - terminado=0.8.2=py37_0
  - testpath=0.4.2=py37_0
  - tk=8.6.8=hbc83047_0
  - toolz=0.10.0=py_0
  - tornado=6.0.3=py37h7b6447c_0
  - traitlets=4.3.3=py37_0
  - unicodecsv=0.14.1=py37_0
  - unixodbc=2.3.7=h14c3975_0
  - wcwidth=0.1.7=py37_0
  - webencodings=0.5.1=py37_1
  - werkzeug=0.16.0=py_0
  - wheel=0.33.6=py37_0
  - widgetsnbextension=3.5.1=py37_0
  - wrapt=1.11.2=py37h7b6447c_0
  - wurlitzer=1.0.3=py37_0
  - xlrd=1.2.0=py37_0
  - xlsxwriter=1.2.1=py_0
  - xlwt=1.3.0=py37_0
  - xz=5.2.4=h14c3975_4
  - yaml=0.1.7=had09818_2
  - zeromq=4.3.1=he6710b0_3
  - zict=1.0.0=py_0
  - zipp=0.6.0=py_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.3.7=h0b5b093_0
  - pip:
    - absl-py==0.12.0
    - aiohttp==3.7.4.post0
    - albumentations==0.5.2
    - alembic==1.4.1
    - arrow==0.17.0
    - async-timeout==3.0.1
    - b2sdk==1.4.0
    - boto3==1.17.36
    - botocore==1.20.36
    - cachetools==4.2.1
    - colour-demosaicing==0.1.6
    - colour-science==0.3.16
    - configparser==5.0.0
    - databricks-cli==0.10.0
    - docker==4.2.0
    - docopt==0.6.2
    - efficientnet-pytorch==0.6.3
    - fsspec==0.8.7
    - funcsigs==1.0.2
    - gitdb==4.0.4
    - gitpython==3.1.1
    - google-auth==1.28.0
    - google-auth-oauthlib==0.4.3
    - gorilla==0.3.0
    - grpcio==1.36.1
    - gunicorn==20.0.4
    - imgaug==0.4.0
    - importlib-metadata==3.7.3
    - jmespath==0.10.0
    - logfury==0.1.2
    - mako==1.1.2
    - markdown==3.3.4
    - mlflow==1.14.1
    - multidict==5.1.0
    - munch==2.5.0
    - oauthlib==3.1.0
    - opencv-python==4.5.1.48
    - opencv-python-headless==4.5.1.48
    - pandas==1.2.3
    - pillow==8.1.2
    - pipreqs==0.4.10
    - plotly==4.14.3
    - pretrainedmodels==0.7.4
    - prettytable==2.1.0
    - prometheus-flask-exporter==0.13.0
    - protobuf==3.11.3
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - python-editor==1.0.4
    - pytorch-lightning==1.2.5
    - pywavelets==1.1.1
    - querystring-parser==1.2.4
    - rawpy==0.16.0
    - requests-oauthlib==1.3.0
    - retrying==1.3.3
    - rsa==4.7.2
    - s3transfer==0.3.6
    - scikit-image==0.18.1
    - segmentation-models-pytorch==0.1.3
    - shapely==1.7.1
    - simplejson==3.17.0
    - smmap==3.0.2
    - sqlparse==0.3.1
    - tabulate==0.8.7
    - tensorboard==2.4.1
    - tensorboard-plugin-wit==1.8.0
    - tifffile==2021.3.17
    - timm==0.3.2
    - torch==1.8.0
    - torchmetrics==0.2.0
    - torchvision==0.9.0
    - tqdm==4.59.0
    - typing-extensions==3.7.4.3
    - urllib3==1.25.11
    - websocket-client==0.57.0
    - yarg==0.1.9
    - yarl==1.6.3
prefix: /home/nobis/anaconda3/envs/perturbed
processingpipeline/numpy_static_pipeline_show.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64edef77495ab24143430e7a5d880b6f211568371f37eab03e1b32fb2f5b8015
size 1906586
processingpipeline/pipeline.py
ADDED
@@ -0,0 +1,329 @@
"""
Raw Image Pipeline
"""
__author__ = "Marco Aversa"

import numpy as np

from rawpy import *  # XXX: no * imports!
from scipy import ndimage
from scipy import fftpack
from scipy.signal import convolve2d

from skimage.filters import unsharp_mask
from skimage.color import rgb2yuv, yuv2rgb, rgb2hsv, hsv2rgb
from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman, denoise_nl_means, denoise_bilateral, denoise_wavelet, estimate_sigma

import matplotlib.pyplot as plt

from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
                                demosaicing_CFA_Bayer_Malvar2004,
                                demosaicing_CFA_Bayer_Menon2007)

import torch

from utils.dataset import Subset
from torch.utils.data import DataLoader


class RawProcessingPipeline(object):

    """Applies the raw-processing pipeline from pipeline.py"""

    def __init__(self, camera_parameters, debayer='bilinear', sharpening='unsharp_masking', denoising='gaussian_denoising'):
        '''
        Args:
            camera_parameters (tuple): (black_level, white_balance, colour_matrix)
            debayer (str): debayering algorithm; choose from {'bilinear', 'malvar2004', 'menon2007'}
            sharpening (str): sharpening algorithm; choose from {'sharpening_filter', 'unsharp_masking'}
            denoising (str): denoising algorithm; choose from {'gaussian_denoising', 'median_denoising', 'fft_denoising'}
        '''

        self.camera_parameters = camera_parameters

        self.debayer = debayer
        self.sharpening = sharpening
        self.denoising = denoising

    def __call__(self, img):
        """
        Args:
            img (ndarray of dtype float32): image of size (H, W)
        Returns:
            img (tensor of dtype float): image of size (3, H, W)
        """
        black_level, white_balance, colour_matrix = self.camera_parameters
        img = processing(img, black_level, white_balance, colour_matrix,
                         debayer=self.debayer, sharpening=self.sharpening, denoising=self.denoising)
        img = img.transpose(2, 0, 1)

        return torch.Tensor(img)
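
# Usage sketch (illustrative, not executed here): given a dataset exposing
# `camera_parameters` as in utils.dataset, the pipeline plugs in as a torchvision-style
# transform and maps a (H, W) float32 raw ndarray to a (3, H, W) torch.Tensor:
#
#     pipeline = RawProcessingPipeline(camera_parameters=dataset.camera_parameters,
#                                      debayer='bilinear',
#                                      sharpening='unsharp_masking',
#                                      denoising='gaussian_denoising')
#     rgb = pipeline(raw_img)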

def processing(img, black_level, white_balance, colour_matrix, debayer="bilinear", sharpening="unsharp_masking",
               sharp_radius=1.0, sharp_amount=1.0, denoising="median_denoising", median_kernel_size=3,
               gaussian_sigma=0.5, fft_fraction=0.3, weight_chambolle=0.01, weight_bregman=100,
               sigma_bilateral=0.6, gamma=2.2, bits=16):
    """Apply the processing pipeline to a raw image

    Args:
        img (ndarray): raw image
        black_level (list): per-channel black level to subtract
        white_balance (None, ndarray): white balance array (if None, the default camera white balance is used)
        colour_matrix (None, ndarray): colour matrix (if None, the default camera colour matrix is used) - size: 3x3
        debayer (str): debayer algorithm
        gamma (float): exponent for the non-linear gamma correction

    Returns:
        img (ndarray): post-processed image
    """

    # Remove black level
    img = remove_blacklv(img, black_level)

    # Apply demosaicing - we don't have access to these 3 functions
    if debayer == "bilinear":
        img = demosaicing_CFA_Bayer_bilinear(img)
    if debayer == "malvar2004":
        img = demosaicing_CFA_Bayer_Malvar2004(img)
    if debayer == "menon2007":
        img = demosaicing_CFA_Bayer_Menon2007(img)

    # White balance correction

    # Sunny images white balance array -> 2 < r < 2.8, g = 1.0, 1.3 < b < 1.6
    # Tungsten images white balance array -> 1.3 < r < 1.7, g = 1.0, 2.2 < b < 2.8
    # Shade images white balance array -> 2.4 < r < 3.2, g = 1.0, 1.1 < b < 1.3

    img = wb_correction(img, white_balance)

    # Colour correction
    img = colour_correction(img, colour_matrix)

    # Sharpening
    if sharpening == "sharpening_filter":  # fixed sharpening
        img = sharpening_filter(img)
    if sharpening == "unsharp_masking":  # the higher radius and amount, the stronger the sharpening
        img = unsharp_masking(img, radius=sharp_radius, amount=sharp_amount, multichannel=True)

    # Denoising
    if denoising == "median_denoising":
        img = median_denoising(img, size=median_kernel_size)
    if denoising == "gaussian_denoising":
        img = gaussian_denoising(img, sigma=gaussian_sigma)
    if denoising == "fft_denoising":  # fft_fraction in [0.0001, 0.5]
        img = fft_denoising(img, keep_fraction=fft_fraction, row_cut=False, column_cut=True)

    # We don't have access to these 3 functions
    if denoising == "tv_chambolle":  # the lower the weight, the weaker the denoising
        img = denoise_tv_chambolle(img, weight=weight_chambolle, eps=0.0002, n_iter_max=200, multichannel=True)
    if denoising == "tv_bregman":  # the lower the weight, the stronger the denoising
        img = denoise_tv_bregman(img, weight=weight_bregman, max_iter=100,
                                 eps=0.001, isotropic=True, multichannel=True)
    # if denoising == "wavelet":
    #     img = denoise_wavelet(img.copy(), sigma=None, wavelet='db1', mode='soft', wavelet_levels=None, multichannel=True,
    #                           convert2ycbcr=False, method='BayesShrink', rescale_sigma=True)
    if denoising == "bilateral":  # the higher sigma_spatial, the stronger the denoising
        img = denoise_bilateral(img, win_size=None, sigma_color=None, sigma_spatial=sigma_bilateral,
                                bins=10000, mode='constant', cval=0, multichannel=True)

    # Gamma correction
    img = np.clip(img, 0, 1)
    img = adjust_gamma(img, gamma=gamma)

    return img
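
# Direct-call sketch (mirrors the validation block in torch_pipeline.py; `numpy_raw`
# and `camera_parameters` are assumed inputs, e.g. from get_camera_parameters below):
#
#     rgb = processing(numpy_raw, *camera_parameters,
#                      sharpening='sharpening_filter', denoising='gaussian_denoising')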

def get_camera_parameters(rawpyImg):
    black_level = rawpyImg.black_level_per_channel
    white_balance = rawpyImg.camera_whitebalance[:3]
    colour_matrix = rawpyImg.color_matrix[:, :3].flatten().tolist()

    return black_level, white_balance, colour_matrix


def remove_blacklv(rawImg, black_level):
    rawImg[0::2, 0::2] -= black_level[0]  # R
    rawImg[0::2, 1::2] -= black_level[1]  # G
    rawImg[1::2, 0::2] -= black_level[2]  # G
    rawImg[1::2, 1::2] -= black_level[3]  # B

    return rawImg


def wb_correction(img, white_balance):
    return img * white_balance


def colour_correction(img, colour_matrix):
    colour_matrix = np.array(colour_matrix).reshape(3, 3)
    return np.einsum('ijk,lk->ijl', img, colour_matrix)


def unsharp_masking(img, radius=1.0, amount=1.0,
                    multichannel=False, preserve_range=True):

    img = rgb2yuv(img)
    img[:, :, 0] = unsharp_mask(img[:, :, 0], radius=radius, amount=amount,
                                multichannel=multichannel, preserve_range=preserve_range)
    img = yuv2rgb(img)
    return img


def sharpening_filter(image, iterations=1, kernel=np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])):

    # https://towardsdatascience.com/image-processing-with-python-blurring-and-sharpening-for-beginners-3bcebec0583a

    img_yuv = rgb2yuv(image)

    for i in range(iterations):
        img_yuv[:, :, 0] = convolve2d(img_yuv[:, :, 0], kernel, 'same', boundary='fill', fillvalue=0)

    final_image = yuv2rgb(img_yuv)

    return final_image


def median_denoising(img, size=3):

    img = rgb2yuv(img)
    img[:, :, 0] = ndimage.median_filter(img[:, :, 0], size)
    img = yuv2rgb(img)

    return img


def gaussian_denoising(img, sigma=0.5):

    img = rgb2yuv(img)
    img[:, :, 0] = ndimage.gaussian_filter(img[:, :, 0], sigma)
    img = yuv2rgb(img)

    return img


def fft_denoising(img, keep_fraction=0.3, row_cut=False, column_cut=True):
    """keep_fraction = 0.5 --> same image as input
    keep_fraction --> 0 --> remove all details"""
    # http://scipy-lectures.org/intro/scipy/auto_examples/solutions/plot_fft_image_denoise.html

    im_fft = fftpack.fft2(img)

    # Make im_fft2 a copy of the original transform. Numpy arrays have a copy
    # method for this purpose.
    im_fft2 = im_fft.copy()

    # Set r and c to be the number of rows and columns of the array.
    r, c, _ = im_fft2.shape

    # Set to zero all rows with indices between r*keep_fraction and r*(1-keep_fraction):
    if row_cut:
        im_fft2[int(r * keep_fraction):int(r * (1 - keep_fraction))] = 0

    # Similarly with the columns:
    if column_cut:
        im_fft2[:, int(c * keep_fraction):int(c * (1 - keep_fraction))] = 0

    # Reconstruct the denoised image from the filtered spectrum; keep only the
    # real part for display.
    im_new = fftpack.ifft2(im_fft2).real

    return im_new


def adjust_gamma(img, gamma=1.0):
    invGamma = 1.0 / gamma
    img = (img ** invGamma)
    return img


def show_img(img, title="no_title", size=12, histo=True, bins=300, bits=16, x_range=-1):
    """Plot an image and its histogram

    Args:
        img (ndarray): image to plot
        title (str): title of the plot
        histo (bool): True - plot per-channel histograms of the image. False - plot the histogram as a continuous curve
        bins (int): number of bins of the histogram
        size (int): figure size
        bits (int): number of bits per pixel in the ndarray
        x_range (list): maximum x range of the histogram (if -1, all x values are taken)
    """
    shape = img.shape

    fig = plt.figure(figsize=(size, size))

    # show original image
    fig.add_subplot(221)
    if len(shape) > 2 and img.max() > 255:
        img_to_show = (img.copy() * 255. / (2**bits - 1)).astype(int)
    else:
        img_to_show = img.copy().astype(int)
    plt.imshow(img_to_show)
    if title != "no_title":
        plt.title(title)

    fig.add_subplot(222)

    if len(shape) > 2:
        if histo:
            plt.hist(img[:, :, 0].flatten(), bins=bins, label="Channel1", color="red", alpha=0.5)
            plt.hist(img[:, :, 1].flatten(), bins=bins, label="Channel2", color="green", alpha=0.5)
            plt.hist(img[:, :, 2].flatten(), bins=bins, label="Channel3", color="blue", alpha=0.5)
            if x_range != -1:
                plt.xlim([x_range[0], x_range[1]])
        else:
            h1, b1 = np.histogram(img[:, :, 0].flatten(), bins=bins)
            h2, b2 = np.histogram(img[:, :, 1].flatten(), bins=bins)
            h3, b3 = np.histogram(img[:, :, 2].flatten(), bins=bins)
            plt.plot(b1[:-1], h1, label="Channel1", color="red", alpha=0.5)
            plt.plot(b2[:-1], h2, label="Channel2", color="green", alpha=0.5)
            plt.plot(b3[:-1], h3, label="Channel3", color="blue", alpha=0.5)

        plt.legend()
    else:
        if histo:
            plt.hist(img.flatten(), bins=bins)
            if x_range != -1:
                plt.xlim([x_range[0], x_range[1]])
        else:
            h, b = np.histogram(img.flatten(), bins=bins)
            plt.plot(b[:-1], h)

    plt.xlabel("Intensities")
    plt.ylabel("Counts")

    plt.show()


def get_statistics(dataset, train_indices, transform=None):
    """Calculates the mean and the standard deviation of a given sub train set of a dataset

    Args:
        dataset (Subset of DroneDataset):
        train_indices (tensor): indices corresponding to a subset of the dataset
        transform (Compose): list of transformations compatible with Compose, applied before the calculation
    Returns:
        mean (tensor of dtype float): size (C, 1, 1)
        std (tensor of dtype float): size (C, 1, 1)
    """

    trainset = Subset(dataset, indices=train_indices, transform=transform)
    dataloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False)
    dataiter = iter(dataloader)

    images, labels = dataiter.next()

    if len(images.shape) == 3:
        mean, std = torch.mean(images, axis=(0, 1, 2)), torch.std(images, axis=(0, 1, 2))
        return mean, std
    else:
        mean, std = torch.mean(images, axis=(0, 2, 3))[:, None, None], torch.std(images, axis=(0, 2, 3))[:, None, None]
        return mean, std
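A usage sketch for get_statistics (a sketch, assuming `dataset` and `train_indices` as produced by utils.dataset; the transform argument is optional):

    import torchvision.transforms as T

    # per-channel mean/std of the training split, used to build a Normalize transform
    mean, std = get_statistics(dataset, train_indices,
                               transform=RawProcessingPipeline(dataset.camera_parameters))
    normalize = T.Normalize(mean.flatten(), std.flatten())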
processingpipeline/torch_pipeline.py
ADDED
@@ -0,0 +1,313 @@
import os
from numpy.lib.function_base import interp
import torch
import torch.nn as nn
if not os.path.exists('README.md'):
    os.chdir('..')

from processingpipeline.pipeline import processing as default_processing
from utils.base import np2torch, torch2np

import segmentation_models_pytorch as smp

from utils.debug import debug

K_G = torch.Tensor([[0, 1, 0],
                    [1, 4, 1],
                    [0, 1, 0]]) / 4

K_RB = torch.Tensor([[1, 2, 1],
                     [2, 4, 2],
                     [1, 2, 1]]) / 4

M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
                            [-0.14714119, -0.28886916, 0.43601035],
                            [0.61497538, -0.51496512, -0.10001026]])
M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
                            [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
                            [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])

K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
                       [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
                       [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])
K_SHARP = torch.Tensor([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]])
DEFAULT_CAMERA_PARAMS = (
    [0., 0., 0., 0.],
    [1., 1., 1.],
    [1., 0., 0., 0., 1., 0., 0., 0., 1.],
)


class RawToRGB(nn.Module):
    def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
        super().__init__()
        self.stages = None
        self.buffer = None
        self.reduce_size = reduce_size
        self.out_channels = out_channels
        self.track_stages = track_stages
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}

        rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
        self.stages['demosaic'] = rgb
        if self.normalize_mosaic:
            rgb = self.normalize_mosaic(rgb)

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb


class NNProcessing(nn.Module):
    def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
        super().__init__()
        self.stages = None
        self.buffer = None
        self.track_stages = track_stages
        self.model = smp.UnetPlusPlus(
            encoder_name='resnet34',
            encoder_depth=3,
            decoder_channels=[256, 128, 64],
            in_channels=3,
            classes=3,
        )
        self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3)
        self.normalize_mosaic = normalize_mosaic

    def forward(self, raw):
        self.stages = {}
        self.buffer = {}
        # self.stages['raw'] = raw
        rgb = raw2rgb(raw)
        if self.normalize_mosaic:
            rgb = self.normalize_mosaic(rgb)
        self.stages['demosaic'] = rgb
        rgb = self.model(rgb)
        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)
        self.stages['rgb'] = rgb

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb


class ParametrizedProcessing(nn.Module):
    def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True, noise_layer=False):
        super().__init__()
        self.stages = None
        self.buffer = None
        self.track_stages = track_stages

        black_level, white_balance, colour_matrix = camera_parameters
        self.register_buffer('black_level', torch.as_tensor(black_level))
        self.register_buffer('colour_correction',
                             torch.as_tensor(white_balance).reshape(1, 3)
                             * torch.as_tensor(colour_matrix).reshape(3, 3))
        self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
        self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())

        self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))

        self.debayer = Debayer()

        self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()

        self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
        self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()

        self.batch_norm = nn.BatchNorm2d(3) if batch_norm_output else None

        # if noise_layer:
        #     for param in self.parameters():
        #         param.requires_grad = False

        self.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256))
                                           ) if noise_layer else None  # XXX: can this be 0?

    def forward(self, raw):
        assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"

        self.stages = {}
        self.buffer = {}

        # self.stages['raw'] = raw

        rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
        rgb = rgb.contiguous()
        self.stages['demosaic'] = rgb

        rgb = self.debayer(rgb)
        # self.stages['debayer'] = rgb

        rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
        self.stages['color_correct'] = rgb

        yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
        yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])

        if self.track_stages:  # keep stage in computational graph for grad information
            rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
            self.stages['sharpening'] = rgb
            yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()

        yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
        rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
        self.stages['gaussian'] = rgb

        rgb = torch.clip(rgb, 1e-5, 1)
        self.stages['clipped'] = rgb

        rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
        self.stages['gamma_correct'] = rgb

        if self.additive_layer is not None:
            # rgb = rgb + 0 * self.additive_layer
            rgb = rgb + self.additive_layer
            self.stages['noise'] = rgb

        if self.batch_norm is not None:
            rgb = self.batch_norm(rgb)

        if self.track_stages and raw.requires_grad:
            for stage in self.stages.values():
                stage.retain_grad()

        self.buffer['processed_rgb'] = rgb

        return rgb
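
# Usage sketch (illustrative): with the identity camera parameters defined above, the
# differentiable pipeline maps a (B, H, W) raw batch to a (B, 3, H, W) RGB batch and
# exposes intermediate stages for tracking:
#
#     proc = ParametrizedProcessing(DEFAULT_CAMERA_PARAMS, track_stages=True)
#     rgb = proc(torch.rand(2, 256, 256))   # rgb.shape == (2, 3, 256, 256)
#     list(proc.stages)                     # ['demosaic', 'color_correct', 'sharpening', ...]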

class Debayer(nn.Conv2d):
    def __init__(self):
        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)  # default_pipeline uses 'replicate'
        self.weight.data.fill_(0)
        self.weight.data[0, 0] = K_RB.clone()
        self.weight.data[1, 1] = K_G.clone()
        self.weight.data[2, 2] = K_RB.clone()


def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
    """transform a raw image with 1 channel to rgb with 3 channels
    Args:
        raw (Tensor): raw Tensor of shape (B, H, W)
        black_level (iterable, optional): RGGB black level values to subtract
        reduce_size (bool, optional): if False, the output image has the same height and width
            as the raw input, i.e. (B, C, H, W); empty values are filled with zeros.
            If True, the output dimensions are reduced by half to (B, C, H//2, W//2)
            and the two green channels are averaged.
        out_channels (int, optional): number of output channels. One of {3, 4}.
    """
    assert out_channels in [3, 4]
    if black_level is None:
        black_level = [0, 0, 0, 0]
    Bch, H, W = raw.shape
    R = raw[:, 0::2, 0::2] - black_level[0]   # R
    G1 = raw[:, 0::2, 1::2] - black_level[1]  # G
    G2 = raw[:, 1::2, 0::2] - black_level[2]  # G
    B = raw[:, 1::2, 1::2] - black_level[3]   # B
    if reduce_size:
        rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
        if out_channels == 3:
            rgb[:, 0, :, :] = R
            rgb[:, 1, :, :] = (G1 + G2) / 2
            rgb[:, 2, :, :] = B
        elif out_channels == 4:
            rgb[:, 0, :, :] = R
            rgb[:, 1, :, :] = G1
            rgb[:, 2, :, :] = G2
            rgb[:, 3, :, :] = B
    else:
        rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
        if out_channels == 3:
            rgb[:, 0, 0::2, 0::2] = R
            rgb[:, 1, 0::2, 1::2] = G1
            rgb[:, 1, 1::2, 0::2] = G2
            rgb[:, 2, 1::2, 1::2] = B
        elif out_channels == 4:
            rgb[:, 0, 0::2, 0::2] = R
            rgb[:, 1, 0::2, 1::2] = G1
            rgb[:, 2, 1::2, 0::2] = G2
            rgb[:, 3, 1::2, 1::2] = B
    return rgb
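
# Shape sketch for raw2rgb on a (B, H, W) raw mosaic:
#
#     raw = torch.rand(1, 4, 4)
#     raw2rgb(raw, reduce_size=True).shape    # torch.Size([1, 3, 2, 2]), greens averaged
#     raw2rgb(raw, reduce_size=False).shape   # torch.Size([1, 3, 4, 4]), zero-filled mosaic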

# pipeline validation
if __name__ == "__main__":

    import torch
    import numpy as np

    if not os.path.exists('README.md'):
        os.chdir('..')

    import matplotlib.pyplot as plt
    from utils.dataset import get_dataset
    from utils.base import np2torch, torch2np

    from utils.debug import debug
    from processingpipeline.pipeline import processing as default_processing

    raw_dataset = get_dataset('DS')
    loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
    batch_raw, batch_mask = next(iter(loader))

    # torch proc
    camera_parameters = raw_dataset.camera_parameters
    black_level = camera_parameters[0]

    proc = ParametrizedProcessing(camera_parameters)

    batch_rgb = proc(batch_raw)
    rgb = batch_rgb[0]

    # numpy proc
    raw_img = batch_raw[0]
    numpy_raw = torch2np(raw_img)

    default_rgb = default_processing(numpy_raw, *camera_parameters,
                                     sharpening='sharpening_filter', denoising='gaussian_denoising')

    rgb_valid = np2torch(default_rgb)

    print("pipeline norm difference:", (rgb - rgb_valid).norm().item())

    rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
    rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()

    plt.figure(figsize=(16, 8))
    plt.subplot(151)
    plt.title('Raw')
    plt.imshow(torch2np(raw_img))
    plt.subplot(152)
    plt.title('RGB Mosaic')
    plt.imshow(torch2np(rgb_mosaic))
    plt.subplot(153)
    plt.title('RGB Reduced')
    plt.imshow(torch2np(rgb_reduced))
    plt.subplot(154)
    plt.title('Torch Pipeline')
    plt.imshow(torch2np(rgb))
    plt.subplot(155)
    plt.title('Default Pipeline')
    plt.imshow(torch2np(rgb_valid))
    plt.show()

    # assert rgb.allclose(rgb_valid)
sanity_checks_and_statistics.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53f62c6ce9a6656a31c3e0ae1deded2e4f9818cd891381dbe1030dd5edc5f278
size 6103871
show_classification_results.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62c5dbc4bb22ecd26bc691c1f574b6bcf07b7cd48f62668506955df3513afe55
size 10556940
show_results.sh
ADDED
@@ -0,0 +1,16 @@
#!/bin/bash

datasets='Microscopy Drone'
augmentations='weak strong none'

for augment in $augmentations
do
    for data in $datasets
    do

        python show_results.py \
            --dataset $data \
            --augmentation $augment

    done
done
train.py
ADDED
@@ -0,0 +1,420 @@
import os
import sys
import copy
import argparse

import torch
import torch.nn as nn

import mlflow.pytorch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import torchvision.transforms as T
from pytorch_lightning.metrics.functional import accuracy
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
from utils.debug import debug
from utils.augmentation import get_augmentation
from utils.dataset import Subset, get_dataset, k_fold

from processingpipeline.pipeline import RawProcessingPipeline
from processingpipeline.torch_pipeline import raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing

from models.classifier import log_tensor, resnet_model, LitModel, TrackImagesCallback

import segmentation_models_pytorch as smp

from utils.pytorch_ssim import SSIM

# args to set up task
parser = argparse.ArgumentParser(description="classification_task")
parser.add_argument("--tracking_uri", type=str,
                    default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
parser.add_argument("--processor_uri", type=str, default=None,
                    help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
parser.add_argument("--classifier_uri", type=str, default=None,
                    help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
parser.add_argument("--state_dict_uri", type=str,
                    default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training)')

parser.add_argument("--experiment_name", type=str,
                    default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
parser.add_argument("--run_name", type=str,
                    default='test run', help='Specify the name of your run')

parser.add_argument("--log_model", type=str2bool, default=True, help='Enables model logging')
parser.add_argument("--save_locally", action='store_true',
                    help='Model will be saved locally if action is taken')  # TODO: bypass mlflow

parser.add_argument("--track_processing", action='store_true',
                    help='Save images after each transformation of the pipeline for the test set')
parser.add_argument("--track_processing_gradients", action='store_true',
                    help='Save images of gradients after each transformation of the pipeline for the test set')
parser.add_argument("--track_save_tensors", action='store_true',
                    help='Save the torch tensors after each transformation of the pipeline for the test set')
parser.add_argument("--track_predictions", action='store_true',
                    help='Save images after each transformation of the pipeline for the test set + input gradient')
parser.add_argument("--track_n_images", default=5,
                    help='Track the first n elements of the dataset. Only used for args.track_processing=True')
parser.add_argument("--track_every_epoch", action='store_true', help='Track images every epoch or once after training')

# args to create dataset
parser.add_argument("--seed", type=int, default=1, help='Global seed')
parser.add_argument("--dataset", type=str, default='Microscopy',
                    choices=["Drone", "DroneSegmentation", "Microscopy"], help='Select dataset')

parser.add_argument("--n_splits", type=int, default=1, help='Number of splits used for training')
parser.add_argument("--train_size", type=float, default=0.8, help='Fraction of training points in dataset')

# args for training
parser.add_argument("--lr", type=float, default=1e-5, help="learning rate used for training")
parser.add_argument("--epochs", type=int, default=3, help="number of epochs")
parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
parser.add_argument("--augmentation", type=str, default='none',
                    choices=["none", "weak", "strong"], help="Applies augmentation to training")
parser.add_argument("--augmentation_on_valid_epoch", action='store_true',
                    help='Track images every epoch or once after training')  # TODO: implement; should actually be disabled by default for 'val' and 'test'
parser.add_argument("--check_val_every_n_epoch", type=int, default=1)

# args to specify the processing
parser.add_argument("--processing_mode", type=str, default="parametrized",
                    choices=["parametrized", "static", "neural_network", "none"],
                    help="Which type of raw to rgb processing should be used")

# args to specify model
parser.add_argument("--classifier_network", type=str, default='ResNet18',
                    help='Type of pretrained network')  # TODO: implement different choices
parser.add_argument("--classifier_pretrained", action='store_true',
                    help='Whether to use a pre-trained model or not')
parser.add_argument("--smp_encoder", type=str, default='resnet34', help='segmentation model encoder')

parser.add_argument("--freeze_processor", action='store_true', help="Freeze raw to rgb processing model weights")
parser.add_argument("--freeze_classifier", action='store_true', help="Freeze classification model weights")

# args to specify static pipeline transformations
parser.add_argument("--sp_debayer", type=str, default='bilinear',
                    choices=['bilinear', 'malvar2004', 'menon2007'], help="Specify algorithm used as debayer")
parser.add_argument("--sp_sharpening", type=str, default='sharpening_filter',
                    choices=['sharpening_filter', 'unsharp_masking'], help="Specify algorithm used for sharpening")
parser.add_argument("--sp_denoising", type=str, default='gaussian_denoising',
                    choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help="Specify algorithm used for denoising")

# args to choose training mode
parser.add_argument("--adv_training", action='store_true', help="Enable adversarial training")
parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxiliary loss")
parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
                    help="Type of adversarial auxiliary regularization loss")

parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)

parser.add_argument('--test_run', action='store_true')

if 'ipykernel_launcher' in sys.argv[0]:
    args = parser.parse_args([
        '--dataset=Microscopy',
        '--epochs=100',
        '--augmentation=strong',
        '--lr=1e-5',
        '--freeze_processor',
        # '--track_processing',
        # '--test_run',
        # '--track_predictions',
        # '--track_every_epoch',
        # '--adv_training',
        # '--adv_aux_weight=100',
        # '--adv_aux_loss=l2',
        # '--log_model=',
    ])
else:
    args = parser.parse_args()
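
# Example invocations (flags as defined above; see also train.sh):
#
#     python train.py --dataset=Microscopy --epochs=100 --augmentation=strong --lr=1e-5 --freeze_processor
#     python train.py --dataset=Drone --processing_mode=static --sp_denoising=median_denoising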

def run_train(args):

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    training_mode = 'adversarial' if args.adv_training else 'default'

    # set tracking uri; this is the address of the mlflow server where light experimental data will be stored
    mlflow.set_tracking_uri(args.tracking_uri)
    mlflow.set_experiment(args.experiment_name)
    os.environ["AWS_ACCESS_KEY_ID"] = ''  # TODO: add your AWS access key if you want to write your results to our collaborative lab server
    os.environ["AWS_SECRET_ACCESS_KEY"] = ''  # TODO: add your AWS secret access key if you want to write your results to our collaborative lab server

    # dataset

    dataset = get_dataset(args.dataset)

    print(f'dataset: {type(dataset).__name__}[{len(dataset)}]')
    print(f'task: {dataset.task}')
    print(f'mode: {training_mode} training')
    print(f'# cross-validation subsets: {args.n_splits}')
    pl.seed_everything(args.seed)
    idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)

    with mlflow.start_run(run_name=args.run_name) as parent_run:

        for k_iter, idxs in enumerate(idxs_kfold):

            print(f"K_fold subset: {k_iter+1}/{args.n_splits}")

            if args.processing_mode == 'static':
                if args.dataset == "Drone" or args.dataset == "DroneSegmentation":
                    mean = torch.tensor([0.35, 0.36, 0.35])
                    std = torch.tensor([0.12, 0.11, 0.12])
                elif args.dataset == "Microscopy":
                    mean = torch.tensor([0.91, 0.84, 0.94])
                    std = torch.tensor([0.08, 0.12, 0.05])

                dataset.transform = T.Compose([RawProcessingPipeline(
                    camera_parameters=dataset.camera_parameters,
                    debayer=args.sp_debayer,
                    sharpening=args.sp_sharpening,
                    denoising=args.sp_denoising,
                ), T.Normalize(mean, std)])
                # XXX: Not clean

                processor = nn.Identity()

            if args.processor_uri is not None and args.processing_mode != 'none':
                print('Fetching processor: ', end='')
                model = fetch_from_mlflow(args.processor_uri, use_cache=args.cache_downloaded_models)
                processor = model.processor
                for param in processor.parameters():
                    param.requires_grad = True
                model.processor = None
                del model
            else:
                print(f'processing_mode: {args.processing_mode}')
                normalize_mosaic = None  # normalize after raw has been passed to raw2rgb
                if args.dataset == "Microscopy":
                    mosaic_mean = [0.5663, 0.1401, 0.0731]
                    mosaic_std = [0.097, 0.0423, 0.008]
                    normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)

                track_stages = args.track_processing or args.track_processing_gradients
                if args.processing_mode == 'parametrized':
                    processor = ParametrizedProcessing(
                        camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
                        noise_layer=args.adv_training,  # XXX: Remove?
                    )
                elif args.processing_mode == 'neural_network':
                    processor = NNProcessing(track_stages=track_stages,
                                             normalize_mosaic=normalize_mosaic, batch_norm_output=True)
                elif args.processing_mode == 'none':
                    processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
                                         normalize_mosaic=normalize_mosaic)

            if args.classifier_uri:  # fetch classifier
                print('Fetching classifier: ', end='')
                model = fetch_from_mlflow(args.classifier_uri, use_cache=args.cache_downloaded_models)
                classifier = model.classifier
                model.classifier = None
                del model
            else:
                if dataset.task == 'classification':
                    classifier = resnet_model(
                        model=resnet18,
                        pretrained=args.classifier_pretrained,
                        in_channels=3,
                        fc_out_features=len(dataset.classes)
                    )
                else:
                    # XXX: add other network choices to args.smp_network (FPN) and args.network
                    classifier = smp.UnetPlusPlus(
                        encoder_name=args.smp_encoder,
                        encoder_depth=5,
                        encoder_weights='imagenet',
                        in_channels=3,
                        classes=1,
                        activation=None,
                    )

            if args.freeze_processor and len(list(iter(processor.parameters()))) == 0:
                print('Note: freezing processor without parameters.')
            assert not (args.freeze_processor and args.freeze_classifier), 'Likely no parameters to train.'

            if dataset.task == 'classification':
                loss = nn.CrossEntropyLoss()
                metrics = [accuracy]
            else:
                # loss = utils.base.smp_get_loss(args.smp_loss)  # XXX: add other losses to args.smp_loss
                loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
                metrics = [smp.utils.metrics.IoU()]

            loss_aux = None

            if args.adv_training:

                assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
                assert args.freeze_classifier, "Classifier should be frozen for adversarial training"
                assert not args.freeze_processor, "Processor should not be frozen for adversarial training"

                processor_default = copy.deepcopy(processor)
                processor_default.track_stages = False
                processor_default.eval()
                processor_default.to(DEVICE)
                # debug(processor_default)

                def l2_regularization(x, y):
                    return (x - y).norm()

                if args.adv_aux_loss == 'l2':
                    regularization = l2_regularization
                elif args.adv_aux_loss == 'ssim':
                    regularization = SSIM(window_size=11)
                else:
                    raise NotImplementedError(args.adv_aux_loss)

                class AuxLoss(nn.Module):
                    def __init__(self, loss_aux, weight=1):
                        super().__init__()
                        self.loss_aux = loss_aux
                        self.weight = weight

                    def forward(self, x):
                        x_reference = processor_default(x)
                        x_processed = processor.buffer['processed_rgb']
                        return self.weight * self.loss_aux(x_reference, x_processed)

                class WeightedLoss(nn.Module):
                    def __init__(self, loss, weight=1):
                        super().__init__()
                        self.loss = loss
                        self.weight = weight

                    def forward(self, x, y):
                        return self.weight * self.loss(x, y)

                    def __repr__(self):
                        return f'{self.weight} * {get_name(self.loss)}'

293 |
+
loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=-1)
|
294 |
+
# loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
|
295 |
+
loss_aux = AuxLoss(
|
296 |
+
loss_aux=regularization,
|
297 |
+
weight=args.adv_aux_weight,
|
298 |
+
)
|
299 |
+
|
300 |
+
augmentation = get_augmentation(args.augmentation)
|
301 |
+
|
302 |
+
model = LitModel(
|
303 |
+
classifier=classifier,
|
304 |
+
processor=processor,
|
305 |
+
loss=loss,
|
306 |
+
loss_aux=loss_aux,
|
307 |
+
adv_training=args.adv_training,
|
308 |
+
metrics=metrics,
|
309 |
+
augmentation=augmentation,
|
310 |
+
is_segmentation_task=dataset.task == 'segmentation',
|
311 |
+
freeze_classifier=args.freeze_classifier,
|
312 |
+
freeze_processor=args.freeze_processor,
|
313 |
+
)
|
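
            # For orientation: in adversarial mode the processor is optimized to
            # *increase* the classifier loss (CrossEntropy weighted by -1), while
            # the auxiliary term, scaled by args.adv_aux_weight, penalizes drift of
            # the processed image away from the frozen default pipeline's output
            # (L2 or SSIM, via --adv_aux_loss). How the two terms are combined is
            # defined in LitModel (imported at the top of this file).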

            # get train_set_dict
            if args.state_dict_uri:
                state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
                train_indices = state_dict['train_indices']
                valid_indices = state_dict['valid_indices']
            else:
                train_indices = idxs[0]
                valid_indices = idxs[1]
                state_dict = vars(args).copy()

            track_indices = list(range(args.track_n_images))

            if dataset.task == 'classification':
                state_dict['classes'] = dataset.classes
            state_dict['device'] = DEVICE
            state_dict['train_indices'] = train_indices
            state_dict['valid_indices'] = valid_indices
            state_dict['elements in train set'] = len(train_indices)
            state_dict['elements in test set'] = len(valid_indices)

            if args.test_run:
                train_indices = train_indices[:args.batch_size]
                valid_indices = valid_indices[:args.batch_size]

            train_set = Subset(dataset, indices=train_indices)
            valid_set = Subset(dataset, indices=valid_indices)
            track_set = Subset(dataset, indices=track_indices)

            train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=16, shuffle=True)
            valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
            track_loader = DataLoader(track_set, batch_size=args.batch_size, num_workers=16, shuffle=False)

            with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:

                # mlflow.pytorch.autolog(silent=True)

                if k_iter == 0:
                    display_mlflow_run_info(child_run)

                mlflow.pytorch.log_state_dict(state_dict, artifact_path=None)

                hparams = {
                    'dataset': args.dataset,
                    'processing_mode': args.processing_mode,
                    'training_mode': training_mode,
                }
                if training_mode == 'adversarial':
                    hparams['adv_aux_weight'] = args.adv_aux_weight
                    hparams['adv_aux_loss'] = args.adv_aux_loss

                mlflow.log_params(hparams)

                with open('results/state_dict.txt', 'w') as f:
                    f.write('python ' + ' '.join(sys.argv) + '\n')
                    f.write('\n'.join([f'{k}={v}' for k, v in state_dict.items()]))
                mlflow.log_artifact('results/state_dict.txt', artifact_path=None)

                mlf_logger = pl.loggers.MLFlowLogger(experiment_name=args.experiment_name,
                                                     tracking_uri=args.tracking_uri)
                mlf_logger._run_id = child_run.info.run_id

                callbacks = []
                if args.track_processing:
                    callbacks += [TrackImagesCallback(track_loader,
                                                      track_every_epoch=args.track_every_epoch,
                                                      track_processing=args.track_processing,
                                                      track_gradients=args.track_processing_gradients,
                                                      track_predictions=args.track_predictions,
                                                      save_tensors=args.track_save_tensors)]

                # if True:  # args.save_best:
                #     if dataset.task == 'classification':
                #         # checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode='max')
                #         checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max")
                #     else:
                #         checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
                #     callbacks += [checkpoint_callback]

                trainer = pl.Trainer(
                    gpus=1 if DEVICE == 'cuda' else 0,
                    min_epochs=args.epochs,
                    max_epochs=args.epochs,
                    logger=mlf_logger,
                    callbacks=callbacks,
                    check_val_every_n_epoch=args.check_val_every_n_epoch,
                    # checkpoint_callback=True,
                )

                if args.log_model:
                    mlflow.pytorch.autolog(log_every_n_epoch=10)
                    print(f'model_uri="{mlflow.get_artifact_uri()}/model"')

                trainer.fit(
                    model,
                    train_dataloader=train_loader,
                    val_dataloaders=valid_loader,
                )

                # if args.adv_training:
                #     for (name, p1), p2 in zip(processor.named_parameters(), processor_default.cpu().parameters()):
                #         print(f"param '{name}' diff: {p2 - p1}, l2: {(p2-p1).norm().item()}")

    return model


if __name__ == '__main__':
    model = run_train(args)
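
train.py consumes k_fold (imported near the top of the file, presumably from utils/splitting.py) as an iterable of (train_indices, valid_indices) pairs. Purely for orientation, a minimal sketch satisfying the same interface, using sklearn's KFold in place of the repository's own splitter, would be:

    import numpy as np
    from sklearn.model_selection import KFold

    def k_fold(dataset, n_splits, seed, train_size=None):
        # yield (train_indices, valid_indices) pairs, matching how train.py
        # unpacks idxs[0] / idxs[1]; train_size handling is omitted in this sketch
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
        yield from kf.split(np.arange(len(dataset)))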
train.sh
ADDED
@@ -0,0 +1,81 @@
#!/bin/bash

# # Parametrized Training
# 100 epochs, frozen processor: http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com/#/experiments/49/runs/2803f44514e34a0f87d591520706e876
# model_uri="s3://mlflow-artifacts-601883093460/49/2803f44514e34a0f87d591520706e876/artifacts/model"

# used for training the current model to 100% train and 80% val accuracy
# python train.py \
#     --experiment_name parametrized \
#     --classifier_uri "${model_uri}" \
#     --run_name par_full_kurt \
#     --dataset Microscopy \
#     --lr 1e-5 \
#     --epochs 50 \
#     --freeze_classifier \

#     --freeze_processor \

# # Adversarial Training

# python train.py \
#     --experiment_name adversarial \
#     --run_name adv_frozen_processor \
#     --classifier_uri "${model_uri}" \
#     --dataset Microscopy \
#     --adv_training \
#     --lr 1e-3 \
#     --epochs 7 \
#     --freeze_classifier \
#     --track_processing \
#     --track_every_epoch \
#     --log_model=False \
#     --adv_aux_weight=0.1 \
#     --adv_aux_loss "l2" \

#     --adv_aux_weight=2e-5 \
#     --adv_aux_weight=1.9e-5 \

# Cross pipeline training (Segmentation/Classification)

# Static Pipeline Script

# datasets="Microscopy Drone DroneSegmentation"
datasets="DroneSegmentation"
augmentations="weak strong none"

demosaicings="bilinear malvar2004 menon2007"
sharpenings="sharpening_filter unsharp_masking"
denoisings="median_denoising gaussian_denoising"

for augment in $augmentations
do
    for data in $datasets
    do
        for demosaicing in $demosaicings
        do
            for sharpening in $sharpenings
            do
                for denoising in $denoisings
                do

                    python train.py \
                        --experiment_name ABtesting \
                        --run_name "$data"_"$demosaicing"_"$sharpening"_"$denoising"_"$augment" \
                        --dataset "$data" \
                        --batch_size 4 \
                        --lr 1e-5 \
                        --epochs 100 \
                        --sp_debayer "$demosaicing" \
                        --sp_sharpening "$sharpening" \
                        --sp_denoising "$denoising" \
                        --processing_mode "static" \
                        --augmentation "$augment" \
                        --n_splits 5

                done
            done
        done
    done
done
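
With the values currently set, the nested loops launch one 100-epoch run per static-pipeline configuration: 1 dataset x 3 demosaicings x 2 sharpenings x 2 denoisings x 3 augmentations = 36 runs. The same grid can be enumerated in Python, e.g. to preview the run names before launching:

    from itertools import product

    datasets = ['DroneSegmentation']
    demosaicings = ['bilinear', 'malvar2004', 'menon2007']
    sharpenings = ['sharpening_filter', 'unsharp_masking']
    denoisings = ['median_denoising', 'gaussian_denoising']
    augmentations = ['weak', 'strong', 'none']

    configs = list(product(datasets, demosaicings, sharpenings, denoisings, augmentations))
    print(len(configs))          # 36
    print('_'.join(configs[0]))  # first run name, as built by --run_name above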
utils/Cperturb.py
ADDED
@@ -0,0 +1,475 @@
'''
Code extracted from the paper:

@article{hendrycks2019robustness,
  title={Benchmarking Neural Network Robustness to Common Corruptions and Perturbations},
  author={Dan Hendrycks and Thomas Dietterich},
  journal={Proceedings of the International Conference on Learning Representations},
  year={2019}
}

The code is modified to fit with our model
'''

import os
import os.path
import time
import torch
import torchvision.datasets as dset
import torchvision.transforms as trn
import torch.utils.data as data
import numpy as np

from PIL import Image


# /////////////// Distortion Helpers ///////////////

import skimage as sk
from skimage.filters import gaussian
from io import BytesIO
from wand.image import Image as WandImage
from wand.api import library as wandlibrary
import wand.color as WandColor
import ctypes
from PIL import Image as PILImage
import cv2
from scipy.ndimage import zoom as scizoom
from scipy.ndimage.interpolation import map_coordinates
import warnings

warnings.simplefilter("ignore", UserWarning)


def disk(radius, alias_blur=0.1, dtype=np.float32):
    if radius <= 8:
        L = np.arange(-8, 8 + 1)
        ksize = (3, 3)
    else:
        L = np.arange(-radius, radius + 1)
        ksize = (5, 5)
    X, Y = np.meshgrid(L, L)
    aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
    aliased_disk /= np.sum(aliased_disk)

    # supersample disk to antialias
    return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)


# Tell Python about the C method
wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p,  # wand
                                              ctypes.c_double,  # radius
                                              ctypes.c_double,  # sigma
                                              ctypes.c_double)  # angle


# Extend wand.image.Image class to include method signature
class MotionImage(WandImage):
    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
        wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)


# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
def plasma_fractal(mapsize=32, wibbledecay=3):
    """
    Generate a heightmap using diamond-square algorithm.
    Return square 2d array, side length 'mapsize', of floats in range 0-255.
    'mapsize' must be a power of two.
    """
    assert (mapsize & (mapsize - 1) == 0)
    maparray = np.empty((mapsize, mapsize), dtype=np.float_)
    maparray[0, 0] = 0
    stepsize = mapsize
    wibble = 100

    def wibbledmean(array):
        return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)

    def fillsquares():
        """For each square of points stepsize apart,
        calculate middle value as mean of points + wibble"""
        cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
        squareaccum += np.roll(squareaccum, shift=-1, axis=1)
        maparray[stepsize // 2:mapsize:stepsize,
                 stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)

    def filldiamonds():
        """For each diamond of points stepsize apart,
        calculate middle value as mean of points + wibble"""
        mapsize = maparray.shape[0]
        drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
        ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
        lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
        ltsum = ldrsum + lulsum
        maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
        tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
        tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
        ttsum = tdrsum + tulsum
        maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)

    while stepsize >= 2:
        fillsquares()
        filldiamonds()
        stepsize //= 2
        wibble /= wibbledecay

    maparray -= maparray.min()
    return maparray / maparray.max()


def clipped_zoom(img, zoom_factor):
    h = img.shape[0]
    # ceil crop height (= crop width)
    ch = int(np.ceil(h / zoom_factor))

    top = (h - ch) // 2
    img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
    # trim off any extra pixels
    trim_top = (img.shape[0] - h) // 2

    return img[trim_top:trim_top + h, trim_top:trim_top + h]


# /////////////// End Distortion Helpers ///////////////


# /////////////// Distortions ///////////////

class Distortions:
    def __init__(self, severity=1, transform='identity'):
        self.severity = severity
        self.transform = transform

    def __call__(self, img):
        assert torch.is_tensor(img), 'Input data needs to be a torch.tensor'
        assert len(img.shape) == 3, 'Input image should be RGB'
        img = self.torch2np(img)
        t = getattr(self, self.transform)
        img = t(img, self.severity)
        return self.np2torch(img).float()

    def np2torch(self, x):
        return torch.tensor(x).permute(2, 0, 1)

    def torch2np(self, x):
        return np.array(x.permute(1, 2, 0))

    def identity(self, x, severity=1):
        return x

    def gaussian_noise(self, x, severity=1):
        c = [0.04, 0.06, .08, .09, .10][severity - 1]
        return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1)

    def shot_noise(self, x, severity=1):
        c = [500, 250, 100, 75, 50][severity - 1]
        return np.clip(np.random.poisson(x * c) / c, 0, 1)

    def impulse_noise(self, x, severity=1):
        c = [.01, .02, .03, .05, .07][severity - 1]

        x = sk.util.random_noise(x, mode='s&p', amount=c)
        return np.clip(x, 0, 1)

    def speckle_noise(self, x, severity=1):
        c = [.06, .1, .12, .16, .2][severity - 1]
        return np.clip(x + x * np.random.normal(size=x.shape, scale=c), 0, 1)

    def gaussian_blur(self, x, severity=1):
        c = [.4, .6, 0.7, .8, 1][severity - 1]

        x = gaussian(x, sigma=c, multichannel=True)
        return np.clip(x, 0, 1)

    def glass_blur(self, x, severity=1):
        # sigma, max_delta, iterations
        c = [(0.05, 1, 1), (0.25, 1, 1), (0.4, 1, 1), (0.25, 1, 2), (0.4, 1, 2)][severity - 1]

        x = gaussian(x, sigma=c[0], multichannel=True)

        # locally shuffle pixels
        for i in range(c[2]):
            for h in range(32 - c[1], c[1], -1):
                for w in range(32 - c[1], c[1], -1):
                    dx, dy = np.random.randint(-c[1], c[1], size=(2,))
                    h_prime, w_prime = h + dy, w + dx
                    # swap
                    x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime], x[h, w]

        return np.clip(gaussian(x, sigma=c[0], multichannel=True), 0, 1)

    def defocus_blur(self, x, severity=1):
        c = [(0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (1, 0.2), (1.5, 0.1)][severity - 1]
        kernel = disk(radius=c[0], alias_blur=c[1])

        channels = []
        for d in range(3):
            channels.append(cv2.filter2D(x[:, :, d], -1, kernel))
        channels = np.array(channels).transpose((1, 2, 0))  # 3x32x32 -> 32x32x3

        return np.clip(channels, 0, 1)

    def motion_blur(self, x, severity=1):
        c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]

        output = BytesIO()
        x.save(output, format='PNG')
        x = MotionImage(blob=output.getvalue())

        x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))

        x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8),
                         cv2.IMREAD_UNCHANGED)

        if x.shape != (32, 32):
            return np.clip(x[..., [2, 1, 0]], 0, 1)  # BGR to RGB
        else:  # greyscale to RGB
            return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 1)

    def zoom_blur(self, x, severity=1):
        c = [np.arange(1, 1.06, 0.01), np.arange(1, 1.11, 0.01), np.arange(1, 1.16, 0.01),
             np.arange(1, 1.21, 0.01), np.arange(1, 1.26, 0.01)][severity - 1]
        out = np.zeros_like(x)
        for zoom_factor in c:
            out += clipped_zoom(x, zoom_factor)

        x = (x + out) / (len(c) + 1)
        return np.clip(x, 0, 1)

    def fog(self, x, severity=1):
        c = [(.2, 3), (.5, 3), (0.75, 2.5), (1, 2), (1.5, 1.75)][severity - 1]
        max_val = x.max()
        x += c[0] * plasma_fractal(wibbledecay=c[1])[:32, :32][..., np.newaxis]
        return np.clip(x * max_val / (max_val + c[0]), 0, 1)

    def frost(self, x, severity=1):
        c = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1]
        idx = np.random.randint(5)
        filename = ['./frost1.png', './frost2.png', './frost3.png', './frost4.jpg', './frost5.jpg', './frost6.jpg'][idx]
        frost = cv2.imread(filename)
        frost = cv2.resize(frost, (0, 0), fx=0.2, fy=0.2)
        # randomly crop and convert to rgb
        x_start, y_start = np.random.randint(0, frost.shape[0] - 32), np.random.randint(0, frost.shape[1] - 32)
        frost = frost[x_start:x_start + 32, y_start:y_start + 32][..., [2, 1, 0]]

        return np.clip(c[0] * np.array(x) + c[1] * frost, 0, 1)

    def snow(self, x, severity=1):
        c = [(0.1, 0.2, 1, 0.6, 8, 3, 0.95),
             (0.1, 0.2, 1, 0.5, 10, 4, 0.9),
             (0.15, 0.3, 1.75, 0.55, 10, 4, 0.9),
             (0.25, 0.3, 2.25, 0.6, 12, 6, 0.85),
             (0.3, 0.3, 1.25, 0.65, 14, 12, 0.8)][severity - 1]

        snow_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])  # [:2] for monochrome

        snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
        snow_layer[snow_layer < c[3]] = 0

        snow_layer = PILImage.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
        output = BytesIO()
        snow_layer.save(output, format='PNG')
        snow_layer = MotionImage(blob=output.getvalue())

        snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))

        snow_layer = cv2.imdecode(np.fromstring(snow_layer.make_blob(), np.uint8),
                                  cv2.IMREAD_UNCHANGED) / (2**16 - 1)
        snow_layer = snow_layer[..., np.newaxis]

        x = c[6] * x + (1 - c[6]) * np.maximum(x, cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(32, 32, 1) * 1.5 + 0.5)
        return np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1)

    def spatter(self, x, severity=1):
        c = [(0.62, 0.1, 0.7, 0.7, 0.5, 0),
             (0.65, 0.1, 0.8, 0.7, 0.5, 0),
             (0.65, 0.3, 1, 0.69, 0.5, 0),
             (0.65, 0.1, 0.7, 0.69, 0.6, 1),
             (0.65, 0.1, 0.5, 0.68, 0.6, 1)][severity - 1]

        liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])

        liquid_layer = gaussian(liquid_layer, sigma=c[2])
        liquid_layer[liquid_layer < c[3]] = 0
        if c[5] == 0:
            liquid_layer = (liquid_layer * (2**16 - 1)).astype(np.uint8)
            dist = (2**16 - 1) - cv2.Canny(liquid_layer, 50, 150)
            dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
            _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
            dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
            dist = cv2.equalizeHist(dist)
            # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
            # ker -= np.mean(ker)
            ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
            dist = cv2.filter2D(dist, cv2.CV_8U, ker)
            dist = cv2.blur(dist, (3, 3)).astype(np.float32)

            m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
            m /= np.max(m, axis=(0, 1))
            m *= c[4]

            # water is pale turquoise
            color = np.concatenate((175 / 255. * np.ones_like(m[..., :1]),
                                    238 / 255. * np.ones_like(m[..., :1]),
                                    238 / 255. * np.ones_like(m[..., :1])), axis=2)

            color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
            x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)

            return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * (2**16 - 1)
        else:
            m = np.where(liquid_layer > c[3], 1, 0)
            m = gaussian(m.astype(np.float32), sigma=c[4])
            m[m < 0.8] = 0
            # m = np.abs(m) ** (1/c[4])

            # mud brown
            color = np.concatenate((63 / 255. * np.ones_like(x[..., :1]),
                                    42 / 255. * np.ones_like(x[..., :1]),
                                    20 / 255. * np.ones_like(x[..., :1])), axis=2)

            color *= m[..., np.newaxis]
            x *= (1 - m[..., np.newaxis])

            return np.clip(x + color, 0, 1)

    def contrast(self, x, severity=1):
        c = [.75, .5, .4, .3, 0.15][severity - 1]
        means = np.mean(x, axis=(0, 1), keepdims=True)
        return np.clip((x - means) * c + means, 0, 1)

    def brightness(self, x, severity=1):
        c = [.05, .1, .15, .2, .3][severity - 1]

        x = sk.color.rgb2hsv(x)
        x[:, :, 2] = np.clip(x[:, :, 2] + c, 0, 1)
        x = sk.color.hsv2rgb(x)

        return np.clip(x, 0, 1)

    def saturate(self, x, severity=1):
        c = [(0.3, 0), (0.1, 0), (1.5, 0), (2, 0.1), (2.5, 0.2)][severity - 1]

        x = sk.color.rgb2hsv(x)
        x[:, :, 1] = np.clip(x[:, :, 1] * c[0] + c[1], 0, 1)
        x = sk.color.hsv2rgb(x)

        return np.clip(x, 0, 1)

    def jpeg_compression(self, x, severity=1):
        c = [80, 65, 58, 50, 40][severity - 1]

        output = BytesIO()
        x.save(output, 'JPEG', quality=c)
        x = PILImage.open(output)

        return x

    def pixelate(self, x, severity=1):
        c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]

        x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
        x = x.resize((32, 32), PILImage.BOX)

        return x

    # mod of https://gist.github.com/erniejunior/601cdf56d2b424757de5
    def elastic_transform(self, image, severity=1):
        IMSIZE = 32
        c = [(IMSIZE * 0, IMSIZE * 0, IMSIZE * 0.08),
             (IMSIZE * 0.05, IMSIZE * 0.2, IMSIZE * 0.07),
             (IMSIZE * 0.08, IMSIZE * 0.06, IMSIZE * 0.06),
             (IMSIZE * 0.1, IMSIZE * 0.04, IMSIZE * 0.05),
             (IMSIZE * 0.1, IMSIZE * 0.03, IMSIZE * 0.03)][severity - 1]

        shape = image.shape
        shape_size = shape[:2]

        # random affine
        center_square = np.float32(shape_size) // 2
        square_size = min(shape_size) // 3
        pts1 = np.float32([center_square + square_size,
                           [center_square[0] + square_size, center_square[1] - square_size],
                           center_square - square_size])
        pts2 = pts1 + np.random.uniform(-c[2], c[2], size=pts1.shape).astype(np.float32)
        M = cv2.getAffineTransform(pts1, pts2)
        image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)

        dx = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
                       c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
        dy = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
                       c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
        dx, dy = dx[..., np.newaxis], dy[..., np.newaxis]

        x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
        indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
        return np.clip(map_coordinates(image, indices, order=1, mode='reflect').reshape(shape), 0, 1)


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import tifffile as tiff

    os.chdir('..')  # os.system('cd ..') would not change the interpreter's working directory

    img = tiff.imread('/home/marco/perturbed-minds/perturbed-minds/data/microscopy/images/rgb_scale100/Ma190c_lame1_zone1_composite_Mcropped_1.tiff')
    img = np.array(img) / (2**16 - 1)
    img = torch.tensor(img).permute(2, 0, 1)

    def identity(x, sev):
        return x

    if not os.path.exists('results/Cimages'):
        os.makedirs('results/Cimages')

    transformations = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
                       'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']

    # glass_blur, defocus_blur, motion_blur, fog, frost, snow, spatter, jpeg_compression, pixelate,

    plt.figure()
    plt.imshow(img.permute(1, 2, 0))
    plt.title('identity')
    plt.show()
    plt.savefig('results/Cimages/1_identity.png')

    for i, t in enumerate(transformations):

        fig = plt.figure(figsize=(25, 5))
        columns = 5
        rows = 1

        for sev in range(1, 6):
            dist = Distortions(severity=sev, transform=t)
            fig.add_subplot(rows, columns, sev)
            plt.imshow(dist(img).permute(1, 2, 0))
            plt.title(f'{t} {sev}')
            plt.xticks([], [])
            plt.yticks([], [])
        plt.show()
        plt.savefig(f'results/Cimages/{i+2}_{t}.png')
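
Distortions is selected by method name and operates on CxHxW float tensors in [0, 1], so it can be appended to an existing transform chain. A minimal sketch (the input tensor img is assumed; note that motion_blur, jpeg_compression and pixelate instead expect PIL images, and several constants are hard-coded for 32x32 inputs):

    from utils.Cperturb import Distortions

    perturb = Distortions(severity=3, transform='gaussian_noise')
    perturbed = perturb(img)  # img: torch.Tensor of shape (3, H, W), values in [0, 1]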
utils/augmentation.py
ADDED
@@ -0,0 +1,132 @@
import random
import numpy as np

import torch
import torchvision.transforms as T


class RandomRotate90():  # Note: not the same as T.RandomRotation(90)
    def __call__(self, x):
        x = x.rot90(random.randint(0, 3), dims=(-1, -2))
        return x

    def __repr__(self):
        return self.__class__.__name__


class AddGaussianNoise():
    def __init__(self, std=0.01):
        self.std = std

    def __call__(self, x):
        # noise = torch.randn_like(x) * self.std
        # out = x + noise
        # debug(x)
        # debug(noise)
        # debug(out)
        return x + torch.randn_like(x) * self.std

    def __repr__(self):
        return self.__class__.__name__ + f'(std={self.std})'


def set_global_seed(seed):
    torch.random.manual_seed(seed)
    np.random.seed(seed % (2**32 - 1))
    random.seed(seed)


class ComposeState(T.Compose):
    def __init__(self, transforms):
        self.transforms = []
        self.mask_transforms = []

        for t in transforms:
            apply_for_mask = True
            if isinstance(t, tuple):
                t, apply_for_mask = t
            self.transforms.append(t)
            if apply_for_mask:
                self.mask_transforms.append(t)

        self.seed = None

    # @debug
    def __call__(self, x, retain_state=False, mask_transform=False):
        if self.seed is not None:  # retain previous state
            set_global_seed(self.seed)
        if retain_state:  # save state for next call
            self.seed = self.seed or torch.seed()
            set_global_seed(self.seed)
        else:
            self.seed = None  # reset / ignore state

        transforms = self.transforms if not mask_transform else self.mask_transforms
        for t in transforms:
            x = t(x)
        return x


augmentation_weak = ComposeState([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    RandomRotate90(),
])


augmentation_strong = ComposeState([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomApply([T.RandomRotation(90)], p=0.5),
    # (transform, apply_to_mask=True)
    (T.RandomApply([AddGaussianNoise(std=0.0005)], p=0.5), False),
    (T.RandomAdjustSharpness(0.5, p=0.5), False),
])


def get_augmentation(type):
    if type == 'none':
        return None
    if type == 'weak':
        return augmentation_weak
    if type == 'strong':
        return augmentation_strong


if __name__ == '__main__':
    import os
    if not os.path.exists('README.md'):
        os.chdir('..')

    # from utils.debug import debug
    from utils.dataset import get_dataset
    import matplotlib.pyplot as plt

    dataset = get_dataset('DS')  # drone segmentation
    img, mask = dataset[10]
    mask = (mask + 0.2) / 1.2

    plt.figure(figsize=(14, 8))
    plt.subplot(121)
    plt.imshow(img)
    plt.subplot(122)
    plt.imshow(mask)
    plt.suptitle('no augmentation')
    plt.show()

    from utils.base import np2torch, torch2np
    img, mask = np2torch(img), np2torch(mask)

    # from utils.augmentation import get_augmentation
    augmentation = get_augmentation('strong')

    set_global_seed(1)

    for i in range(1, 4):
        plt.figure(figsize=(14, 8))
        plt.subplot(121)
        plt.imshow(torch2np(augmentation(img.unsqueeze(0), retain_state=True)).squeeze())
        plt.subplot(122)
        plt.imshow(torch2np(augmentation(mask.unsqueeze(0), mask_transform=True)).squeeze())
        plt.suptitle(f'augmentation test {i}')
        plt.show()
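
The retain_state / mask_transform pair is what keeps image and mask geometrically aligned: the first call draws and stores an RNG seed, the second replays it while skipping transforms flagged with apply_for_mask=False. Condensed from the __main__ demo above (img and mask as loaded there):

    aug = get_augmentation('weak')
    img_aug = aug(img.unsqueeze(0), retain_state=True)      # draw and remember the random state
    mask_aug = aug(mask.unsqueeze(0), mask_transform=True)  # replay the same flips/rotation on the mask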
utils/base.py
ADDED
@@ -0,0 +1,330 @@
"""
Utilities for other scripts
"""

import os
import shutil

import random

import torch
import torch.nn as nn
import mlflow
from mlflow.tracking import MlflowClient
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp

from IPython.display import display, Markdown

from b2sdk.v1 import *

import argparse


class SmartFormatter(argparse.HelpFormatter):

    def _split_lines(self, text, width):
        if text.startswith('R|'):
            return text[2:].splitlines()
        # this is the RawTextHelpFormatter._split_lines
        return argparse.HelpFormatter._split_lines(self, text, width)


def str2bool(string):
    return string == 'True'


def np2torch(nparray):
    """Convert numpy array to torch tensor.
    For arrays with more than 3 channels, it is better to use an input array in the format BxHxWxC

    Args:
        nparray (ndarray): BxHxWxC
    Returns:
        tensor (torch.Tensor): BxCxHxW"""

    tensor = torch.Tensor(nparray)

    if tensor.ndim == 2:
        return tensor
    if tensor.ndim == 3:
        height, width, channels = tensor.shape
        if channels <= 3:  # single image with channels last (HxWxC)
            return tensor.permute(2, 0, 1)

    if tensor.ndim == 4:  # batch of images with channels last (BxHxWxC)
        return tensor.permute(0, 3, 1, 2)

    return tensor


def torch2np(torchtensor):
    """Convert torch tensor to numpy array.
    For tensors with more than 3 channels or a batch, it is better to use an input tensor in the format BxCxHxW

    Args:
        torchtensor (torch.Tensor): BxCxHxW
    Returns:
        ndarray (ndarray): BxHxWxC"""

    ndarray = torchtensor.detach().cpu().numpy().astype(np.float32)

    if ndarray.ndim == 3:  # single image with channels first (CxHxW)
        channels, height, width = ndarray.shape
        if channels <= 3:
            return ndarray.transpose(1, 2, 0)

    if ndarray.ndim == 4:  # batch of images with channels first (BxCxHxW)
        return ndarray.transpose(0, 2, 3, 1)

    return ndarray


def set_random_seed(seed):
    np.random.seed(seed)  # cpu vars
    torch.manual_seed(seed)  # cpu vars
    random.seed(seed)  # Python
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # gpu vars
        torch.backends.cudnn.deterministic = True  # needed
        torch.backends.cudnn.benchmark = False


def normalize(img):
    """Normalize images

    Args:
        img (ndarray): image to normalize --> size: (Height, Width, Channels)
    Returns:
        img (ndarray): normalized image
        mu (ndarray): per-channel mean
        sigma (ndarray): per-channel standard deviation
    """

    img = img.astype(float)

    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]

    height, width, channels = img.shape

    mu, sigma = np.empty(channels), np.empty(channels)

    for ch in range(channels):
        temp_mu = img[:, :, ch].mean()
        temp_sigma = img[:, :, ch].std()

        img[:, :, ch] = (img[:, :, ch] - temp_mu) / (temp_sigma + 1e-4)

        mu[ch] = temp_mu
        sigma[ch] = temp_sigma

    return img, mu, sigma


def b2_list_files(folder=''):
    bucket = get_b2_bucket()
    for file_info, _ in bucket.ls(folder, show_versions=False):
        print(file_info.file_name)


def get_b2_bucket():
    bucket_name = 'perturbed-minds'
    application_key_id = '003d6b042de536a0000000004'
    application_key = 'K003E5Cr+BAYlvSHfg2ynLtvS5aNq78'
    info = InMemoryAccountInfo()
    b2_api = B2Api(info)
    b2_api.authorize_account('production', application_key_id, application_key)
    bucket = b2_api.get_bucket_by_name(bucket_name)
    return bucket


def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=True):
    """Downloads a folder from the b2 bucket and optionally cleans
    up files that are no longer on the server

    Args:
        b2_dir (str): path to folder on the b2 server
        local_dir (str): path to folder on the local machine
        force_download (bool, optional): force the download; if set to `False`,
            files with matching names on the local machine will be skipped
        mirror_folder (bool, optional): if set to `True`, files that are found in
            the local directory, but are not on the server, will be deleted
    """
    bucket = get_b2_bucket()

    if not os.path.exists(local_dir):
        os.makedirs(local_dir)
    elif not force_download:
        return

    download_files = [file_info.file_name.split(b2_dir + '/')[-1]
                      for file_info, _ in bucket.ls(b2_dir, show_versions=False)]

    for file_name in download_files:
        if file_name.endswith('/.bzEmpty'):  # subdirectory, download recursively
            subdir = file_name.replace('/.bzEmpty', '')
            if len(subdir) > 0:
                b2_subdir = os.path.join(b2_dir, subdir)
                local_subdir = os.path.join(local_dir, subdir)
                if b2_subdir != b2_dir:
                    b2_download_folder(b2_subdir, local_subdir, force_download=force_download,
                                       mirror_folder=mirror_folder)
        else:  # file
            b2_file = os.path.join(b2_dir, file_name)
            local_file = os.path.join(local_dir, file_name)
            if not os.path.exists(local_file) or force_download:
                print(f"downloading b2://{b2_file} -> {local_file}")
                bucket.download_file_by_name(b2_file, DownloadDestLocalFile(local_file))

    if mirror_folder:  # remove all files that are not on the b2 server anymore
        for i, file in enumerate(download_files):
            if file.endswith('/.bzEmpty'):  # subdirectory
                download_files[i] = file.replace('/.bzEmpty', '')
        for file_name in os.listdir(local_dir):
            if file_name not in download_files:
                local_file = os.path.join(local_dir, file_name)
                print(f"deleting {local_file}")
                if os.path.isdir(local_file):
                    shutil.rmtree(local_file)
                else:
                    os.remove(local_file)


def get_name(obj):
    return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__


def get_mlflow_model_by_name(experiment_name, run_name,
                             tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
                             download_model=True):

    # 0. mlflow basics
    mlflow.set_tracking_uri(tracking_uri)
    os.environ["AWS_ACCESS_KEY_ID"] = ''  # TODO: add your AWS access key if you want to write your results to our collaborative lab server
    os.environ["AWS_SECRET_ACCESS_KEY"] = ''  # TODO: add your AWS secret access key if you want to write your results to our collaborative lab server

    # 1. use get_experiment_by_name to get the experiment object
    experiment = mlflow.get_experiment_by_name(experiment_name)

    # 2. use search_runs with experiment_id for a string search query
    if os.path.isfile('cache/runs_names.pkl'):
        runs = pd.read_pickle('cache/runs_names.pkl')
        if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
            runs = fetch_runs_list_mlflow(experiment)  # returns a pandas DataFrame where each row is a run (if several exist under that name)
    else:
        runs = fetch_runs_list_mlflow(experiment)  # returns a pandas DataFrame where each row is a run (if several exist under that name)

    # 3. get the selected run among all runs inside the selected experiment
    run = runs.loc[runs['tags.mlflow.runName'] == run_name]

    # 4. check that there is exactly one run with that name
    assert len(run) == 1, "More than one run with this name"
    index_run = run.index[0]
    artifact_uri = run.loc[index_run, 'artifact_uri']

    # 5. load the state_dict of your run
    state_dict = mlflow.pytorch.load_state_dict(artifact_uri)

    # 6. load the model of your run
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    # model = mlflow.pytorch.load_model(os.path.join(
    #     artifact_uri, "model"), map_location=torch.device(DEVICE))
    model = fetch_from_mlflow(os.path.join(
        artifact_uri, "model"), use_cache=True, download_model=download_model)

    return state_dict, model


def data_loader_mean_and_std(data_loader, transform=None):
    means = []
    stds = []
    for x, y in data_loader:
        if transform is not None:
            x = transform(x)
        means.append(x.mean(dim=(0, 2, 3)).unsqueeze(0))
        stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
    return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)


def fetch_runs_list_mlflow(experiment):
    runs = mlflow.search_runs(experiment.experiment_id)
    runs.to_pickle('cache/runs_names.pkl')  # where to save it, usually as a .pkl
    return runs


def fetch_from_mlflow(uri, use_cache=True, download_model=True):
    cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
    if use_cache and os.path.exists(cache_loc):
        print(f'loading cached model from {cache_loc} ...')
        return torch.load(cache_loc)
    else:
        print(f'fetching model from {uri} ...')
        model = mlflow.pytorch.load_model(uri)
        os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
        if download_model:
            torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)
        return model


def display_mlflow_run_info(run):
    uri = mlflow.get_tracking_uri()
    experiment_id = run.info.experiment_id
    experiment_name = mlflow.get_experiment(experiment_id).name
    run_id = run.info.run_id
    run_name = run.data.tags['mlflow.runName']
    experiment_url = f'{uri}/#/experiments/{experiment_id}'
    run_url = f'{experiment_url}/runs/{run_id}'

    print(f'view results at {run_url}')
    display(Markdown(
        f"[<a href='{experiment_url}'>experiment {experiment_id} '{experiment_name}'</a>]"
        f" > "
        f"[<a href='{run_url}'>run '{run_name}' {run_id}</a>]"
    ))
    print('')


def get_train_test_indices_drone(df, frac, seed=None):
    """Split indices of a DataFrame with binary and balanced labels into balanced subindices

    Args:
        df (pd.DataFrame): {0,1}-labeled data
        frac (float): fraction of indices in the first subset
        seed (int): random seed used as random state in np.random and as argument for random.seed()
    Returns:
        train_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
        test_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
    """

    split_idx = int(len(df) * frac / 2)
    df_with = df[df['label'] == 1]
    df_without = df[df['label'] == 0]

    np.random.seed(seed)
    df_with_train = df_with.sample(n=split_idx, random_state=seed)
    df_with_test = df_with.drop(df_with_train.index)

    df_without_train = df_without.sample(n=split_idx, random_state=seed)
    df_without_test = df_without.drop(df_without_train.index)

    train_indices = list(df_without_train.index) + list(df_with_train.index)
    test_indices = list(df_without_test.index) + list(df_with_test.index)

    """
    print('fraction of 1-label in train set: {}'.format(len(df_with_train)/(len(df_with_train) + len(df_without_train))))
    print('fraction of 1-label in test set: {}'.format(len(df_with_test)/(len(df_with_test) + len(df_without_test))))
    """

    return train_indices, test_indices


def smp_get_loss(loss):
    # NOTE: args.dice_weight / args.bce_weight are expected to be in scope
    # (set by the calling script) when "DicyBCE" is requested.
    if loss == "Dice":
        return smp.losses.DiceLoss(mode='binary', from_logits=True)
    if loss == "BCE":
        return nn.BCELoss()
    elif loss == "BCEWithLogits":
        return smp.losses.BCEWithLogitsLoss()
    elif loss == "DicyBCE":
        from pytorch_toolbelt import losses as ptbl
        return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
                              nn.BCELoss(),
                              first_weight=args.dice_weight,
                              second_weight=args.bce_weight)
utils/dataset.py
ADDED
@@ -0,0 +1,622 @@
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import rawpy
|
4 |
+
import random
|
5 |
+
from PIL import Image
|
6 |
+
import tifffile as tiff
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
13 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
14 |
+
|
15 |
+
if not os.path.exists('README.md'): # set pwd to root
|
16 |
+
os.chdir('..')
|
17 |
+
|
18 |
+
from utils.splitting import split_img
|
19 |
+
from utils.base import np2torch, torch2np, b2_download_folder
|
20 |
+
|
21 |
+
IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
|
22 |
+
|
23 |
+
|
24 |
+
def get_dataset(name, I_ratio=1.0):
|
25 |
+
# DroneDataset
|
26 |
+
if name in ('DC', 'Drone', 'DroneClassification', 'DroneDatasetClassificationTiled'):
|
27 |
+
return DroneDatasetClassificationTiled(I_ratio=I_ratio)
|
28 |
+
if name in ('DS', 'DroneSegmentation', 'DroneDatasetSegmentationTiled'):
|
29 |
+
return DroneDatasetSegmentationTiled(I_ratio=I_ratio)
|
30 |
+
|
31 |
+
# MicroscopyDataset
|
32 |
+
if name in ('M', 'Microscopy', 'MicroscopyDataset'):
|
33 |
+
return MicroscopyDataset(I_ratio=I_ratio)
|
34 |
+
|
35 |
+
# for testing
|
36 |
+
if name in ('DSF', 'DroneDatasetSegmentationFull'):
|
37 |
+
return DroneDatasetSegmentationFull(I_ratio=I_ratio)
|
38 |
+
if name in ('MRGB', 'MicroscopyRGB', 'MicroscopyDatasetRGB'):
|
39 |
+
return MicroscopyDatasetRGB(I_ratio=I_ratio)
|
40 |
+
|
41 |
+
raise ValueError(name)
|
42 |
+
|
43 |
+
|
def load_image(path):
    file_type = path.split('.')[-1].lower()
    if file_type == 'dng':
        img = rawpy.imread(path).raw_image_visible
    elif file_type == 'tiff' or file_type == 'tif':
        img = np.array(tiff.imread(path), dtype=np.float32)
    else:
        img = np.array(Image.open(path), dtype=np.float32)
    return img


def list_images_in_dir(path):
    image_list = [os.path.join(path, img_name)
                  for img_name in sorted(os.listdir(path))
                  if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
    return image_list


class ImageFolderDataset(Dataset):
    """Creates a classification dataset of the images in `img_dir`, paired one-to-one
    with the given class labels (in the sorted order of the image files).

    Args:
        img_dir (str): path to image folder
        labels (list): one class label per image
        transform (callable, optional): transformation to apply to the image
        bits (int, optional): normalize image by dividing by 2^bits - 1
    """

    task = 'classification'

    def __init__(self, img_dir, labels, transform=None, bits=1):

        self.img_dir = img_dir
        self.labels = labels

        self.images = list_images_in_dir(img_dir)

        assert len(self.images) == len(self.labels)

        self.transform = transform
        self.bits = bits

    def __repr__(self):
        rep = f"{type(self).__name__}: ImageFolderDataset[{len(self.images)}]"
        for n, (img, label) in enumerate(zip(self.images, self.labels)):
            rep += f'\nimage: {img}\tlabel: {label}'
            if n > 10:
                rep += '\n...'
                break
        return rep

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):

        label = self.labels[idx]

        img = load_image(self.images[idx])
        img = img / (2**self.bits - 1)
        if self.transform is not None:
            img = self.transform(img)

        if len(img.shape) == 2:
            assert img.shape == (256, 256), f"Invalid size for {self.images[idx]}"
        else:
            assert img.shape == (3, 256, 256), f"Invalid size for {self.images[idx]}"

        return img, label


class ImageFolderDatasetSegmentation(Dataset):
    """Creates a dataset of images in `img_dir` and corresponding masks in `mask_dir`.
    Corresponding mask files need to contain the filename of the image.
    Files are expected to be of the same filetype.

    Args:
        img_dir (str): path to image folder
        mask_dir (str): path to mask folder
        transform (callable, optional): transformation to apply to image and mask
        bits (int, optional): normalize image by dividing by 2^bits - 1
    """

    task = 'segmentation'

    def __init__(self, img_dir, mask_dir, transform=None, bits=1):

        self.img_dir = img_dir
        self.mask_dir = mask_dir

        self.images = list_images_in_dir(img_dir)
        self.masks = list_images_in_dir(mask_dir)

        check_image_folder_consistency(self.images, self.masks)

        self.transform = transform
        self.bits = bits

    def __repr__(self):
        rep = f"{type(self).__name__}: ImageFolderDatasetSegmentation[{len(self.images)}]"
        for n, (img, mask) in enumerate(zip(self.images, self.masks)):
            rep += f'\nimage: {img}\tmask: {mask}'
            if n > 10:
                rep += '\n...'
                break
        return rep

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):

        img = load_image(self.images[idx])
        mask = load_image(self.masks[idx])

        img = img / (2**self.bits - 1)
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            img = self.transform(img)

        return img, mask


class MultiIntensity(Dataset):
    """Wrap datasets with different intensities into one merged dataset.

    Args:
        datasets (list): list of datasets to wrap
        transform (callable, optional): transformation to apply to the merged samples
    """

    def __init__(self, datasets, transform=None):
        self.dataset = datasets[0]
        self.transform = transform  # was read in __getitem__ but never set

        for d in range(1, len(datasets)):
            self.dataset.images = self.dataset.images + datasets[d].images
            self.dataset.labels = self.dataset.labels + datasets[d].labels

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        return f"Subset [{len(self.dataset)}] of " + repr(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class Subset(Dataset):
    """Define a subset of a dataset by only selecting given indices.

    Args:
        dataset (Dataset): full dataset
        indices (list): subset indices
    """

    def __init__(self, dataset, indices=None, transform=None):
        self.dataset = dataset
        self.indices = indices if indices is not None else range(len(dataset))
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __repr__(self):
        return f"Subset [{len(self)}] of " + repr(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
    """Dataset consisting of full-sized numpy images and masks. Images are normalized to range [0, 1].
    """

    black_level = [0.0625, 0.0626, 0.0625, 0.0626]
    white_balance = [2.86653646, 1., 1.73079425]
    colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
                     1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
    camera_parameters = black_level, white_balance, colour_matrix

    def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):

        assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]

        img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
        mask_dir = 'data/drone/masks_full'

        download_drone_dataset(force_download)  # XXX: zip files and add checksum? date?

        super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=bits)


class DroneDatasetSegmentationTiled(ImageFolderDatasetSegmentation):
    """Dataset consisting of tiled numpy images and masks. Images are in range [0, 1].

    Args:
        tile_size (int, optional): size of the tiled images. Defaults to 256.
    """

    camera_parameters = DroneDatasetSegmentationFull.camera_parameters

    def __init__(self, I_ratio=1.0, transform=None):

        tile_size = 256

        img_dir = f'data/drone/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}'
        mask_dir = f'data/drone/masks_tiles_{tile_size}'

        if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
            dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
            print("tiling dataset..")
            create_tiles_dataset(dataset_full, img_dir, mask_dir, tile_size=tile_size)

        super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=16)


class DroneDatasetClassificationTiled(ImageFolderDataset):
    """Binary car / no-car classification tiles derived from the drone segmentation data."""

    camera_parameters = DroneDatasetSegmentationFull.camera_parameters

    def __init__(self, I_ratio=1.0, transform=None):

        random_state = 72
        tile_size = 256
        thr = 0.01

        img_dir = f'data/drone/classification/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}_thr_{thr}'
        mask_dir = f'data/drone/classification/masks_tiles_{tile_size}_thr_{thr}'
        df_path = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'

        if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
            dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
            print("tiling dataset..")
            create_tiles_dataset_binary(dataset_full, img_dir, mask_dir, random_state, thr, tile_size=tile_size)

        self.classes = ['car', 'no car']
        self.df = pd.read_csv(df_path)
        labels = self.df['label'].to_list()

        super().__init__(img_dir=img_dir, labels=labels, transform=transform, bits=16)

        # the csv fixes the image order; override the folder listing from super().__init__
        images, class_labels = read_label_csv(self.df)
        self.images = [os.path.join(self.img_dir, image) for image in images]
        self.labels = class_labels


class MicroscopyDataset(ImageFolderDataset):
    """MicroscopyDataset raw images

    Args:
        I_ratio (float): original image rescaled by this factor; possible values [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
        transform (callable, optional): transformation to apply to the image
        bits (int, optional): normalize image by dividing by 2^bits - 1
        force_download (bool, optional): re-download the data even if it is already present
    """

    black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
    white_balance = [-0.6567, 1.9673, 3.5304]
    colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]

    camera_parameters = black_level, white_balance, colour_matrix

    dataset_mean = [0.91, 0.84, 0.94]
    dataset_std = [0.08, 0.12, 0.05]

    def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):

        assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]

        download_microscopy_dataset(force_download=force_download)

        self.img_dir = f'data/microscopy/images/raw_scale{int(I_ratio*100):03d}'
        self.transform = transform
        self.bits = bits

        self.label_file = 'data/microscopy/labels/Ma190c_annotations.dat'

        self.valid_classes = ['BAS', 'EBO', 'EOS', 'KSC', 'LYA', 'LYT', 'MMZ', 'MOB',
                              'MON', 'MYB', 'MYO', 'NGB', 'NGS', 'PMB', 'PMO', 'UNC']

        self.invalid_files = ['Ma190c_lame3_zone13_composite_Mcropped_2.tiff', ]

        images, class_labels = read_label_file(self.label_file)

        # filter classes with low appearance
        self.valid_classes = [class_label for class_label in self.valid_classes
                              if class_labels.count(class_label) > 4]

        # remove invalid classes and invalid files from (images, class_labels)
        images, class_labels = list(zip(*[
            (image, class_label)
            for image, class_label in zip(images, class_labels)
            if class_label in self.valid_classes and image not in self.invalid_files
        ]))

        self.classes = list(sorted({*class_labels}))

        # store full path
        self.images = [os.path.join(self.img_dir, image) for image in images]

        # reindex labels
        self.labels = [self.classes.index(class_label) for class_label in class_labels]


class MicroscopyDatasetRGB(MicroscopyDataset):
    """MicroscopyDataset RGB images

    Args:
        I_ratio (float): original image rescaled by this factor; possible values [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
        transform (callable, optional): transformation to apply to the image
        bits (int, optional): normalize image by dividing by 2^bits - 1
        force_download (bool, optional): re-download the data even if it is already present
    """
    camera_parameters = None

    dataset_mean = None
    dataset_std = None

    def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
        super().__init__(I_ratio=I_ratio, transform=transform, bits=bits, force_download=force_download)
        self.images = [image.replace('raw', 'rgb') for image in self.images]  # XXX: hack


def read_label_file(label_file_path):

    images = []
    class_labels = []

    with open(label_file_path, "rb") as data:
        for line in data:
            file_name, class_label = line.decode("utf-8").split()
            image = file_name + '.tiff'
            images.append(image)
            class_labels.append(class_label)

    return images, class_labels


def read_label_csv(df):

    images = []
    class_labels = []

    for file_name, label in zip(df['file name'], df['label']):
        image = file_name + '.tif'
        images.append(image)
        class_labels.append(int(label))
    return images, class_labels


def download_drone_dataset(force_download):
    b2_download_folder('drone/images', 'data/drone/images_full', force_download=force_download)
    b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=force_download)
    unzip_drone_images()


def download_microscopy_dataset(force_download):
    b2_download_folder('Data histopathology/WhiteCellsImages',
                       'data/microscopy/images', force_download=force_download)
    b2_download_folder('Data histopathology/WhiteCellsLabels',
                       'data/microscopy/labels', force_download=force_download)
    unzip_microscopy_images()


def unzip_microscopy_images():

    if os.path.isfile('data/microscopy/labels/.bzEmpty'):
        os.remove('data/microscopy/labels/.bzEmpty')

    for file in os.listdir('data/microscopy/images'):
        if file.endswith(".zip"):
            zf = zipfile.ZipFile(os.path.join('data/microscopy/images', file))
            zf.extractall('data/microscopy/images')
            os.remove(os.path.join('data/microscopy/images', file))


def unzip_drone_images():

    if os.path.isfile('data/drone/masks_full/.bzEmpty'):
        os.remove('data/drone/masks_full/.bzEmpty')

    for file in os.listdir('data/drone/images_full'):
        if file.endswith(".zip"):
            zf = zipfile.ZipFile(os.path.join('data/drone/images_full', file))
            zf.extractall('data/drone/images_full')
            os.remove(os.path.join('data/drone/images_full', file))


def create_tiles_dataset(dataset, img_dir, mask_dir, tile_size=256):
    for folder in [img_dir, mask_dir]:
        if not os.path.exists(folder):
            os.makedirs(folder)
    for n, (img, mask) in enumerate(dataset):
        tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        tiled_img, tiled_mask = class_detection(tiled_img, tiled_mask)  # remove tiles without cars in them
        for i, (sub_img, sub_mask) in enumerate(zip(tiled_img, tiled_mask)):
            tile_id = f"{n:02d}_{i:05d}"
            Image.fromarray(sub_img).save(os.path.join(img_dir, tile_id + '.tif'))
            Image.fromarray(sub_mask > 0).save(os.path.join(mask_dir, tile_id + '.png'))


def create_tiles_dataset_binary(dataset, img_dir, mask_dir, random_state, thr, tile_size=256):

    for folder in [img_dir, mask_dir]:
        if not os.path.exists(folder):
            os.makedirs(folder)

    ids = []
    labels = []

    for n, (img, mask) in enumerate(dataset):
        tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))

        X_with, X_without, Y_with, Y_without = binary_class_detection(
            tiled_img, tiled_mask, random_state, thr)  # creates balanced arrays with class and without class

        i = -1  # keeps tile ids consecutive even when X_with is empty
        for i, (sub_X_with, sub_Y_with) in enumerate(zip(X_with, Y_with)):
            tile_id = f"{n:02d}_{i:05d}"
            ids.append(tile_id)
            labels.append(0)  # 0 = 'car' (matches DroneDatasetClassificationTiled.classes)
            Image.fromarray(sub_X_with).save(os.path.join(img_dir, tile_id + '.tif'))
            Image.fromarray(sub_Y_with > 0).save(os.path.join(mask_dir, tile_id + '.png'))
        for j, (sub_X_without, sub_Y_without) in enumerate(zip(X_without, Y_without)):
            tile_id = f"{n:02d}_{i+1+j:05d}"
            ids.append(tile_id)
            labels.append(1)  # 1 = 'no car'
            Image.fromarray(sub_X_without).save(os.path.join(img_dir, tile_id + '.tif'))
            Image.fromarray(sub_Y_without > 0).save(os.path.join(mask_dir, tile_id + '.png'))
            # Image.fromarray(sub_mask).save(os.path.join(mask_dir, tile_id + '.png'))

    df = pd.DataFrame({'file name': ids, 'label': labels})

    df_loc = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
    df.to_csv(df_loc)

    return


def class_detection(X, Y):
    """Keep only the sub-images whose target contains the class.

    Args:
        X (ndarray): input images
        Y (ndarray): targets with segmentation maps ({0,1} values, 1 where the class is present)
    Returns:
        X_with_class (ndarray): input regions with the selected class
        Y_with_class (ndarray): target regions with the selected class
    """

    with_class = []
    without_class = []
    for i, img in enumerate(Y):
        if img.mean() == 0:
            without_class.append(i)
        else:
            with_class.append(i)

    X_with_class = np.delete(X, without_class, 0)
    Y_with_class = np.delete(Y, without_class, 0)

    return X_with_class, Y_with_class


def binary_class_detection(X, Y, random_seed, thr):
    """Split sub-images into those with and those without the selected class by
    calculating the mean of each sub-mask; sub-images with 0 < sub_mask.mean() <= thr
    are disregarded. The two returned sets are balanced to equal size, with the
    surplus of the larger class selected at random according to `random_seed`.

    Args:
        X (ndarray): input images
        Y (ndarray): targets with segmentation maps ({0,1} values, 1 where the class is present)
        random_seed (None or int): seed for the random selection used when balancing
        thr (float): sub-images are not considered if 0 < sub_target.mean() <= thr
    Returns:
        X_with_class (ndarray): input regions with the selected class
        X_without_class (ndarray): input regions without the selected class
        Y_with_class (ndarray): target regions with the selected class
        Y_without_class (ndarray): target regions without the selected class
    """

    with_class = []
    without_class = []
    no_class = []

    for i, img in enumerate(Y):
        m = img.mean()
        if m == 0:
            without_class.append(i)
        else:
            if m > thr:
                with_class.append(i)
            else:
                no_class.append(i)

    N = len(with_class)
    M = len(without_class)
    random.seed(random_seed)
    # Balancing trick: surplus indices of the larger class are appended to the
    # *other* class's list, so they end up in both lists and are deleted from
    # both outputs below, leaving two sets of equal size.
    if N <= M:
        random.shuffle(without_class)
        with_class.extend(without_class[:M - N])
    else:
        random.shuffle(with_class)
        without_class.extend(with_class[:N - M])

    X_with_class = np.delete(X, without_class + no_class, 0)
    X_without_class = np.delete(X, with_class + no_class, 0)
    Y_with_class = np.delete(Y, without_class + no_class, 0)
    Y_without_class = np.delete(Y, with_class + no_class, 0)

    return X_with_class, X_without_class, Y_with_class, Y_without_class

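To make the balancing trick concrete, a small worked example with hypothetical counts:

    # Suppose thr filtering leaves 3 tiles with cars and 5 without:
    #   with_class = [0, 2, 4], without_class = [1, 3, 5, 6, 7]  -> N=3, M=5
    # Two shuffled without-indices (say 5 and 1) are appended to with_class,
    # so they appear in *both* lists and are deleted from both outputs:
    #   X_with_class    = np.delete(X, without_class + no_class, 0)  -> 3 tiles
    #   X_without_class = np.delete(X, with_class + no_class, 0)     -> 5 - 2 = 3 tiles
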
def make_dataloader(dataset, batch_size, shuffle=True):

    X, Y = dataset  # expects a pair of ndarrays, not a torch Dataset

    X, Y = np2torch(X), np2torch(Y)

    dataset = TensorDataset(X, Y)
    dataset = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return dataset

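A sketch of driving make_dataloader with an (X, Y) array pair (hypothetical shapes; the exact tensor layout of the batches depends on how np2torch in utils.base converts arrays):

    X = np.zeros((32, 3, 256, 256), dtype=np.float32)  # 32 RGB tiles
    Y = np.zeros(32, dtype=np.int64)                   # one label per tile
    loader = make_dataloader((X, Y), batch_size=8)
    for xb, yb in loader:
        print(xb.shape, yb.shape)  # 8-sample batches, layout per np2torch
        break
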
def check_image_folder_consistency(images, masks):
    file_type_images = images[0].split('.')[-1].lower()
    file_type_masks = masks[0].split('.')[-1].lower()
    assert len(images) == len(masks), "images / masks length mismatch"
    for img_file, mask_file in zip(images, masks):
        img_name = img_file.split('/')[-1].split('.')[0]
        assert img_name in mask_file, f"image {img_file} corresponds to {mask_file}?"
        assert img_file.split('.')[-1].lower() == file_type_images, \
            f"image file {img_file} file type mismatch. Should be: {file_type_images}"
        assert mask_file.split('.')[-1].lower() == file_type_masks, \
            f"mask file {mask_file} file type mismatch. Should be: {file_type_masks}"


def k_fold(dataset, n_splits: int, seed: int, train_size: float):
    """Split the dataset into subsets for cross-validation.

    Args:
        dataset (Dataset): dataset to split
        n_splits (int): number of re-shuffling & splitting iterations
        seed (int): seed for the splitting
        train_size (float): should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the train split
    Returns:
        idxs (list): indices for splitting the dataset; the list contains n_splits pairs of train/test indices
    """
    if hasattr(dataset, 'labels'):
        x = dataset.images
        y = dataset.labels
    elif hasattr(dataset, 'masks'):
        x = dataset.images
        y = dataset.masks

    idxs = []

    if dataset.task == 'classification':
        sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)

        for idxs_train, idxs_test in sss.split(x, y):
            idxs.append((idxs_train.tolist(), idxs_test.tolist()))

    elif dataset.task == 'segmentation':
        np.random.seed(seed)  # seed the permutations so the splits are reproducible, as documented
        for n in range(n_splits):
            split_idx = int(len(dataset) * train_size)
            indices = np.random.permutation(len(dataset))
            idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))

    return idxs
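
A sketch of a cross-validation loop built from k_fold and Subset (assuming the dataset download works):

    ds = get_dataset('M')
    for train_idx, test_idx in k_fold(ds, n_splits=5, seed=72, train_size=0.8):
        train_set = Subset(ds, train_idx)
        test_set = Subset(ds, test_idx)
        # ... train and evaluate one fold here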
utils/debug.py
ADDED
@@ -0,0 +1,371 @@
import torch
import numpy as np
import inspect
from functools import reduce, wraps
from collections.abc import Iterable
from IPython import embed

try:
    get_ipython()  # pylint: disable=undefined-variable
    interactive_notebook = True
except NameError:
    interactive_notebook = False

_NONE = "__UNSET_VARIABLE__"


def debug_init():
    debug.disable = False
    debug.silent = False
    debug.verbose = 2
    debug.expand_ignore = ["DataLoader", "Dataset", "Subset"]
    debug.max_expand = 10
    debug.show_tensor = False
    debug.raise_exception = True
    debug.full_stack = True
    debug.restore_defaults_on_exception = not interactive_notebook
    debug._indent = 0
    debug._stack = ""

    debug.embed = embed
    debug.show = debug_show
    debug.pause = debug_pause


def debug_pause():
    input("Press Enter to continue...")


def debug(*args, assert_true=False):
    """Decorator for debugging functions and tensors.
    Will throw an exception as soon as a nan is encountered.
    If used on iterables, these will be expanded and also searched for nans.
    Usage:
        debug(x)
    Or:
        @debug
        def function():
            ...
    If used as a function wrapper, all arguments will be searched and printed.
    """

    single_arg = len(args) == 1

    if debug.disable:
        return args[0] if single_arg else None

    try:
        call_line = ''.join(inspect.stack()[1][4]).strip()
    except Exception:
        call_line = '...'
    used_as_wrapper = 'def ' == call_line[:4]
    expect_return_arg = single_arg and 'debug' in call_line and call_line.split('debug')[0].strip() != ''
    is_func = single_arg and hasattr(args[0], '__call__')

    if is_func and (used_as_wrapper or expect_return_arg):
        func = args[0]
        sig_parameters = inspect.signature(func).parameters
        sig_argnames = [p.name for p in sig_parameters.values()]
        sig_defaults = {
            k: v.default
            for k, v in sig_parameters.items()
            if v.default is not inspect.Parameter.empty
        }

        @wraps(func)
        def _func(*args, **kwargs):
            if debug.disable:
                return func(*args, **kwargs)

            if debug._indent == 0:
                debug._stack = ""
            stack_before = debug._stack
            indent = ' ' * 4 * debug._indent
            debug._indent += 1

            args_kw = dict(zip(sig_argnames, args))
            defaults = {k: v for k, v in sig_defaults.items()
                        if k not in kwargs
                        if k not in args_kw}
            all_args = {**args_kw, **kwargs, **defaults}

            func_name = None
            if hasattr(func, '__name__'):
                func_name = func.__name__
            elif hasattr(func, '__class__'):
                func_name = func.__class__.__name__

            if func_name is None:
                func_name = '... ' + call_line + '...'
            else:
                func_name = '@' + func_name + '()'

            _debug_log('', indent=indent)
            _debug_log(func_name, indent=indent)

            debug._last_call = func
            debug._last_args = all_args
            debug._last_args_sig = sig_argnames

            for argtype, params in [("args", args_kw.items()),
                                    ("kwargs", kwargs.items()),
                                    ("defaults", defaults.items())]:
                if params:
                    _debug_log(f"{argtype}:", indent=indent + ' ' * 6)
                    for argname, arg in params:
                        if argname == 'self':
                            # _debug_log(f"- self: ...", indent=indent + ' ' * 8)
                            pass
                        else:
                            _debug_log(f"- {argname}: ", arg, indent + ' ' * 8, assert_true)
            try:
                out = func(*args, **kwargs)
            except Exception:
                _debug_crash_save()
                debug._stack = ""
                debug._indent = 0
                raise
            debug.out = out
            _debug_log("returned: ", out, indent, assert_true)
            _debug_log('', indent=indent)
            debug._indent -= 1
            if not debug.full_stack:
                debug._stack = stack_before
            return out
        return _func
    else:
        if debug._indent == 0:
            debug._stack = ""
        argname = ')'.join('('.join(call_line.split('(')[1:]).split(')')[:-1])
        if assert_true:
            argname = ','.join(argname.split(',')[:-1])
            _debug_log(f"assert{{{argname}}} ", args[0], ' ' * 4 * debug._indent, assert_true)
        else:
            for arg in args:
                _debug_log(f"{{{argname}}} = ", arg, ' ' * 4 * debug._indent, assert_true)
        if expect_return_arg:
            return args[0]
        return

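A short usage sketch of the debug helper (illustrative only; the printed format is approximated):

    @debug
    def scale(t, factor=2.0):
        return t * factor

    out = scale(torch.ones(3))  # logs "@scale()", its args, and the returned tensor,
                                # raising as soon as a NaN/Inf entry is encountered
    debug(out)                  # inspect a single tensor in place
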
def is_iterable(x):
    return isinstance(x, Iterable) or hasattr(x, '__getitem__') and not isinstance(x, str)


def ndarray_repr(t, assert_all=False):
    exception_encountered = False
    info = []
    shape = tuple(t.shape)
    single_entry = shape == () or shape == (1,)
    if single_entry:
        info.append(f"[{t.item():.4f}]")
    else:
        info.append(f"({', '.join(map(repr, shape))})")
    invalid_sum = (~np.isfinite(t)).sum().item()
    if invalid_sum:
        info.append(
            f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
        exception_encountered = True
    if debug.verbose > 1:
        if not invalid_sum and not single_entry:
            info.append(f"|x|={np.linalg.norm(t):.1f}")
        if t.size:
            info.append(f"x in [{t.min():.1f}, {t.max():.1f}]")
    if debug.verbose and t.dtype != np.float64:  # np.float is removed in NumPy >= 1.24
        info.append(f"dtype={str(t.dtype)}".replace("'", ''))
    if assert_all:
        assert_val = t.all()
        if not assert_val:
            exception_encountered = True
    if assert_all and not exception_encountered:
        output = "passed"
    else:
        if assert_all and not assert_val:
            output = f"ndarray({info[0]})"
        else:
            output = f"ndarray({', '.join(info)})"
    if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
        if debug.restore_defaults_on_exception:
            debug.raise_exception = False
            debug.silent = False
        debug.x = t
        msg = output
        debug._stack += output
        if debug._stack and '\n' in debug._stack:
            msg += '\nSTACK: ' + debug._stack
        if assert_all:
            assert assert_val, "Assert did not pass on " + msg
        raise Exception("Invalid entries encountered in " + msg)
    return output


def tensor_repr(t, assert_all=False):
    exception_encountered = False
    info = []
    shape = tuple(t.shape)
    single_entry = shape == () or shape == (1,)
    if single_entry:
        info.append(f"[{t.item():.3f}]")
    else:
        info.append(f"({', '.join(map(repr, shape))})")
    invalid_sum = (~torch.isfinite(t)).sum().item()
    if invalid_sum:
        info.append(
            f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
        exception_encountered = True
    if debug.verbose and t.requires_grad:
        info.append('req_grad')
    if debug.verbose > 2:
        if t.is_leaf:
            info.append('leaf')
        if hasattr(t, 'retains_grad') and t.retains_grad:
            info.append('retains_grad')
    has_grad = (t.is_leaf or hasattr(t, 'retains_grad') and t.retains_grad) and t.grad is not None
    if has_grad:
        grad_invalid_sum = (~torch.isfinite(t.grad)).sum().item()
        if grad_invalid_sum:
            info.append(
                f"GRAD {grad_invalid_sum} INVALID ENTR{'Y' if grad_invalid_sum == 1 else 'IES'}")
            exception_encountered = True
    if debug.verbose > 1:
        if not invalid_sum and not single_entry:
            info.append(f"|x|={t.float().norm():.1f}")
        if t.numel():
            info.append(f"x in [{t.min():.2f}, {t.max():.2f}]")
        if has_grad and not grad_invalid_sum:
            if single_entry:
                info.append(f"grad={t.grad.float().item():.3f}")
            else:
                info.append(f"|grad|={t.grad.float().norm():.1f}")
    if debug.verbose and t.dtype != torch.float:
        info.append(f"dtype={str(t.dtype).split('.')[-1]}")
    if debug.verbose and t.device.type != 'cpu':
        info.append(f"device={t.device.type}")
    if assert_all:
        assert_val = t.all()
        if not assert_val:
            exception_encountered = True
    if assert_all and not exception_encountered:
        output = "passed"
    else:
        if assert_all and not assert_val:
            output = f"tensor({info[0]})"
        else:
            output = f"tensor({', '.join(info)})"
    if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
        if debug.restore_defaults_on_exception:
            debug.raise_exception = False
            debug.silent = False
        debug.x = t
        msg = output
        debug._stack += output
        if debug._stack and '\n' in debug._stack:
            msg += '\nSTACK: ' + debug._stack
        if assert_all:
            assert assert_val, "Assert did not pass on " + msg
        raise Exception("Invalid entries encountered in " + msg)
    return output


def _debug_crash_save():
    if debug._indent:
        debug.args = debug._last_args
        debug.func = debug._last_call

        @wraps(debug.func)
        def _recall(*args, **kwargs):
            call_args = {**debug.args, **kwargs, **dict(zip(debug._last_args_sig, args))}
            return debug(debug.func)(**call_args)

        def print_stack(stack=debug._stack):
            print('\nSTACK: ' + stack)
        debug.stack = print_stack

        debug.recall = _recall
        debug._indent = 0


def _debug_log(output, var=_NONE, indent='', assert_true=False, expand=True):
    debug._stack += indent + output
    if not debug.silent:
        print(indent + output, end='')
    if var is not _NONE:
        type_str = type(var).__name__.lower()
        if var is None:
            _debug_log('None')
        elif isinstance(var, str):
            _debug_log(f"'{var}'")
        elif type_str == 'ndarray':
            _debug_log(ndarray_repr(var, assert_true))
            if debug.show_tensor:
                _debug_show_print(var, indent=indent + 4 * ' ')
        # elif type_str in ('tensor', 'parameter'):
        elif type_str == 'tensor':
            _debug_log(tensor_repr(var, assert_true))
            if debug.show_tensor:
                _debug_show_print(var, indent=indent + 4 * ' ')
        elif hasattr(var, 'named_parameters'):
            _debug_log(type_str)
            params = list(var.named_parameters())
            _debug_log(f"{type_str}[{len(params)}] {{")
            for k, v in params:
                _debug_log(f"'{k}': ", v, indent + 6 * ' ')
            _debug_log(indent + 4 * ' ' + '}')
        elif is_iterable(var):
            expand = debug.expand_ignore != '*' and expand
            if expand:
                if isinstance(debug.expand_ignore, str):
                    if type_str == str(debug.expand_ignore).lower():
                        expand = False
                elif is_iterable(debug.expand_ignore):
                    for ignore in debug.expand_ignore:
                        if type_str == ignore.lower():
                            expand = False
            if hasattr(var, '__len__'):
                length = len(var)
            else:
                var = list(var)
                length = len(var)
            if expand and length > 0:
                _debug_log(f"{type_str}[{length}] {{")
                if isinstance(var, dict):
                    for k, v in var.items():
                        _debug_log(f"'{k}': ", v, indent + 6 * ' ', assert_true)
                else:
                    i = 0
                    for k, i in zip(var, range(debug.max_expand)):
                        _debug_log('- ', k, indent + 6 * ' ', assert_true)
                    if i < length - 1:
                        _debug_log('- ' + ' ' * 6 + '...', indent=indent + 6 * ' ')
                _debug_log(indent + 4 * ' ' + '}')
            else:
                _debug_log(f"{type_str}[{length}]")
        else:
            _debug_log(str(var))
    else:
        debug._stack += '\n'
        if not debug.silent:
            print()


def debug_show(x):
    assert is_iterable(x)
    debug(x)
    _debug_show_print(x, indent=' ' * 4 * debug._indent)


def _debug_show_print(x, indent=''):
    is_tensor = type(x).__name__ in ('Tensor', 'ndarray')
    if is_tensor:
        x = x.flatten()
        if type(x).__name__ == 'Tensor' and x.dim() == 0:
            return
    n_samples = min(10, len(x))
    di = len(x) // n_samples
    var = list(x[i * di] for i in range(n_samples))
    if is_tensor or type(var[0]) == float:
        var = [round(float(v), 4) for v in var]
    _debug_log('--> ', str(var), indent, expand=False)


debug_init()
utils/mutual_entropy.py
ADDED
@@ -0,0 +1,193 @@
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from scipy.signal import convolve2d


def mse(x, y):
    return ((x - y)**2).mean()


def gaussian_noise_entropies(t1, bins=20):
    all_MI = []
    all_mse = []
    for sigma in np.linspace(0, 100, 201):
        t2 = np.random.normal(t1.copy(), scale=sigma, size=t1.shape)
        hist_2d, x_edges, y_edges = np.histogram2d(
            t1.ravel(),
            t2.ravel(),
            bins=bins)
        all_mse.append(mse(t1, t2))
        MI = mutual_information(hist_2d)
        all_MI.append(MI)

    return np.array(all_MI), np.array(all_mse)


def shifts_entropies(t1, bins=20):
    all_MI = []
    all_mse = []
    for N in np.linspace(1, 50, 50):
        N = int(N)
        temp_t2 = t1[:-N].copy()
        temp_t1 = t1[N:].copy()
        hist_2d, x_edges, y_edges = np.histogram2d(
            temp_t1.ravel(),  # the original referenced t2 here, which is undefined in this scope
            temp_t2.ravel(),
            bins=bins)
        MI = mutual_information(hist_2d)

        all_mse.append(mse(temp_t1, temp_t2))
        all_MI.append(MI)

    return np.array(all_MI), np.array(all_mse)


def mutual_information(hgram):
    """Mutual information for a joint histogram."""
    # Convert bin counts to probability values
    pxy = hgram / float(np.sum(hgram))
    px = np.sum(pxy, axis=1)  # marginal for x over y
    py = np.sum(pxy, axis=0)  # marginal for y over x
    px_py = px[:, None] * py[None, :]  # broadcast to multiply marginals
    # Now we can do the calculation using the pxy, px_py 2D arrays
    nzs = pxy > 0  # only non-zero pxy values contribute to the sum
    return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))

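mutual_information computes I(X;Y) = sum_xy p(x,y) * log( p(x,y) / (p(x) p(y)) ) from the joint histogram. A quick sanity check (illustrative values):

    t = np.random.rand(64, 64)
    h_same, _, _ = np.histogram2d(t.ravel(), t.ravel(), bins=20)
    print(mutual_information(h_same))   # equals the marginal entropy of t (diagonal joint)
    h_indep = np.ones((20, 20))         # uniform joint: the two variables are independent
    print(mutual_information(h_indep))  # 0.0
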
def entropy(image, bins=20):
    image = image.ravel()
    hist, bin_edges = np.histogram(image, bins=bins)
    hist = hist / hist.sum()
    entropy_term = np.where(hist != 0, hist * np.log(hist), 0)
    entropy = -np.sum(entropy_term)

    return entropy


# Gray Image
# t1 = np.array(Image.open("img.png"))[:,:,0].astype(float)

# Colour Image
t1 = np.array(Image.open("img.png").resize((255, 255)))

perturb = "gauss"
show_image = True
bins = 20

print(perturb)

# Identity
if perturb == "identity":
    t2 = t1
    title = "Identity"
    image1 = "Clean"
    image2 = "Clean"

# Poisson Noise on t2
if perturb == "poisson":
    t2 = np.random.poisson(t1)
    title = "Poisson Noise"
    image1 = "Clean"
    image2 = "Noisy"

# Gaussian Noise on t2
if perturb == "gauss":
    print(np.shape(t1))
    sigma = 50.0
    t2 = np.random.normal(t1.copy(), scale=sigma, size=t1.shape)
    if "grad" in locals():
        title = f"Gaussian Noise, grad= True, sigma = {sigma:.2f}"
    else:
        title = f"Gaussian Noise, sigma = {sigma:.2f}"
    image1 = "Clean"
    image2 = "Noisy"

if perturb == "box":
    sigma = 50.0
    mean = np.mean(t1)
    print(np.shape(t1))
    t2 = t1.copy()
    t2[30:220, 50:120, :] = mean
    title = "Box with mean pixels"
    image1 = "Clean"
    image2 = "Noisy"


# Shift t2 on y axis
if perturb == "shift":
    N = 30
    t2 = t1[:-N]
    t1 = t1[N:]
    title = "y shift"
    image1 = "Clean"
    image2 = "Shifted"

t2 = np.clip(t2, 0, 255).astype(int)

print("Correlation Coefficient: ", np.corrcoef(t1.ravel(), t2.ravel())[0, 1])

# 2D Histogram
hist_2d, x_edges, y_edges = np.histogram2d(
    t1.ravel(),
    t2.ravel(),
    bins=bins)

MI = mutual_information(hist_2d)

print("Mutual Information", MI)
print("Mean squared error:", mse(t1, t2))

if show_image:
    plt.figure()
    plt.imshow(np.hstack((t2, t1)))
    plt.title(title)

    plt.figure()

    plt.plot(t1.ravel(), t2.ravel(), '.')
    plt.xlabel(image1)
    plt.ylabel(image2)
    plt.title('I1 vs I2')

    plt.figure()
    plt.imshow(hist_2d.T / hist_2d.max(), origin='lower')
    plt.xlabel(image1)
    plt.ylabel(image2)
    plt.xticks(ticks=np.linspace(0, bins - 1, 10), labels=np.linspace(x_edges.min(), x_edges.max(), 10).astype(int))
    plt.yticks(ticks=np.linspace(0, bins - 1, 10), labels=np.linspace(y_edges.min(), y_edges.max(), 10).astype(int))
    plt.title('p(x,y)')
    plt.colorbar()

    # Show log histogram, avoiding divide by 0
    plt.figure(figsize=(4, 4))
    hist_2d_log = np.zeros(hist_2d.shape)
    non_zeros = hist_2d != 0
    hist_2d_log[non_zeros] = np.log(hist_2d[non_zeros])
    plt.imshow(hist_2d_log.T / hist_2d_log.max(), origin='lower')
    plt.xlabel(image1)
    plt.ylabel(image2)
    plt.xticks(ticks=np.linspace(0, bins - 1, 10), labels=np.linspace(x_edges.min(), x_edges.max(), 10).astype(int))
    plt.yticks(ticks=np.linspace(0, bins - 1, 10), labels=np.linspace(y_edges.min(), y_edges.max(), 10).astype(int))
    plt.title('log(p(x,y))')
    plt.colorbar()
    plt.show()

if perturb == "shift":
    mi_array, mse_array = shifts_entropies(t1, bins=bins)
    plt.figure()
    plt.plot(np.linspace(0, 50, 50), mi_array)
    plt.xlabel("y shift")
    plt.ylabel("Mutual Information")
    plt.figure()
    plt.plot(np.linspace(0, 50, 50), mse_array)
    plt.xlabel("y shift")
    plt.ylabel("Mean Squared Error")
    plt.show()

if perturb == "gauss":
    mi_array, mse_array = gaussian_noise_entropies(t1, bins=bins)
    plt.figure()
    plt.plot(np.linspace(0, 100, 201), mi_array)
    plt.xlabel("sigma")
    plt.ylabel("Mutual Information")
    plt.figure()
    plt.plot(np.linspace(0, 100, 201), mse_array)
    plt.xlabel("sigma")
    plt.ylabel("Mean Squared Error")
    plt.show()
utils/pytorch_ssim.py
ADDED
@@ -0,0 +1,75 @@
"""https://github.com/Po-Hsun-Su/pytorch-ssim"""

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
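
A minimal sketch of the ssim helper on NCHW float tensors in [0, 1]:

    img1 = torch.rand(1, 3, 64, 64)
    img2 = torch.rand(1, 3, 64, 64)
    print(ssim(img1, img1).item())  # ~1.0 for identical images
    print(ssim(img1, img2).item())  # well below 1.0 for unrelated noise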
utils/show_dataset.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd09ffb969b9a0a5b414b892614b5b9e48fa32721ea9a14e9e0951160e8f92e4
size 2115545
utils/splitting.py
ADDED
@@ -0,0 +1,137 @@
"""
Split images in blocks and vice versa
"""

import random
import numpy as np

import torch

from skimage.util.shape import view_as_windows


def split_img(imgs, ROIs=(3, 3), step=(1, 1)):
    """Split the imgs into regions of size ROIs.

    Args:
        imgs (ndarray): images which you want to split
        ROIs (tuple): size of the split sub-regions (ROIs = regions of interest)
        step (tuple): step from one sub-region to the next one (in the x, y axis)

    Returns:
        ndarray: split subimages.
        The size is (x_num_subROIs*y_num_subROIs, **) where:
            x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
            y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]

    Example:
        >>> from utils.splitting import split_img
        >>> imgs_split = split_img(imgs, ROIs=(5, 5), step=(2, 3))
    """

    if len(ROIs) > 2:
        raise ValueError("ROIs is a 2 element list")  # the original printed and returned None

    if len(step) > 2:
        raise ValueError("step is a 2 element list")

    if not isinstance(imgs, np.ndarray):
        raise TypeError("imgs should be a ndarray")

    if len(imgs.shape) == 2:  # single image with one channel (HxW)
        splitted = view_as_windows(imgs, (ROIs[0], ROIs[1]), (step[0], step[1]))
        return splitted.reshape((-1, ROIs[0], ROIs[1]))

    if len(imgs.shape) == 3:
        _, _, channels = imgs.shape
        if channels <= 3:  # single image with more channels (HxWxC)
            splitted = view_as_windows(imgs, (ROIs[0], ROIs[1], channels), (step[0], step[1], channels))
            return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
        else:  # more images with one channel
            splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1]), (1, step[0], step[1]))
            return splitted.reshape((-1, ROIs[0], ROIs[1]))

    if len(imgs.shape) == 4:  # more images with more channels (BxHxWxC)
        _, _, _, channels = imgs.shape
        splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1], channels), (1, step[0], step[1], channels))
        return splitted.reshape((-1, ROIs[0], ROIs[1], channels))

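A minimal sketch of split_img for the non-overlapping tiling used by create_tiles_dataset above:

    img = np.zeros((512, 512, 3), dtype=np.float32)
    tiles = split_img(img, ROIs=(256, 256), step=(256, 256))
    print(tiles.shape)  # (4, 256, 256, 3)
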
def join_blocks(splitted, final_shape):
    """Join blocks to re-obtain the original split image.

    Args:
        splitted (tensor): image split into blocks, size = (N_blocks, Channels, Height, Width)
        final_shape (tuple): size of the reconstructed image (Height, Width)
    Returns:
        tensor: image restored from blocks, size = (Channels, Height, Width)
    """
    n_blocks, channels, ROI_height, ROI_width = splitted.shape

    rows = final_shape[0] // ROI_height
    columns = final_shape[1] // ROI_width

    final_img = torch.empty(rows, channels, ROI_height, ROI_width * columns)
    for r in range(rows):
        stackblocks = splitted[r * columns]
        for c in range(1, columns):
            stackblocks = torch.cat((stackblocks, splitted[r * columns + c]), axis=2)
        final_img[r] = stackblocks

    joined_img = final_img[0]

    for i in np.arange(1, len(final_img)):
        joined_img = torch.cat((joined_img, final_img[i]), axis=1)

    return joined_img


def random_ROI(X, Y, ROIs=(512, 512)):
    """Return a random region for each input/target pair of images in the dataset.

    Args:
        X (ndarray): input of your dataset --> size: (B,H,W,C)
        Y (ndarray): target of your dataset --> size: (B,H,W,C)
        ROIs (tuple): size of the random region (ROIs = regions of interest)

    Returns:
        X_cut (ndarray): input of your dataset --> size: (Batch, ROIs[0], ROIs[1], Channels)
        Y_cut (ndarray): target of your dataset --> size: (Batch, ROIs[0], ROIs[1], Channels)

    Example:
        >>> from utils.splitting import random_ROI
        >>> X, Y = random_ROI(X, Y, ROIs=(10, 10))
    """

    # the indexing below assumes channels-last (B,H,W,C) layout;
    # the original unpacked the shape as (batch, channels, height, width)
    batch, height, width, channels = X.shape

    X_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
    Y_cut = np.empty((batch, ROIs[0], ROIs[1], channels))

    for i in np.arange(len(X)):
        x_size = int(random.random() * (height - (ROIs[0] + 1)))
        y_size = int(random.random() * (width - (ROIs[1] + 1)))
        X_cut[i] = X[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
        Y_cut[i] = Y[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
    return X_cut, Y_cut


def one2many_random_ROI(X, Y, datasize=1000, ROIs=(512, 512)):
    """Return a dataset of N subimages obtained from random regions of the same image.

    Args:
        X (ndarray): input of your dataset --> size: (1,H,W,C)
        Y (ndarray): target of your dataset --> size: (1,H,W,C)
        datasize (int): number of random ROIs to generate
        ROIs (tuple): size of the random region (ROIs = regions of interest)

    Returns:
        X_cut (ndarray): input of your dataset --> size: (Datasize, ROIs[0], ROIs[1], Channels)
        Y_cut (ndarray): target of your dataset --> size: (Datasize, ROIs[0], ROIs[1], Channels)
    """

    batch, height, width, channels = X.shape

    X_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
    Y_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))

    for i in np.arange(datasize):
        X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
    return X_cut, Y_cut