Rodrigo_Cobo commited on
Commit
cc6c676
1 Parent(s): 9e2cd5a

added thesis

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ __pycache__/*
2
+ Scripts/*
3
+ Include/*
4
+ Lib/*
5
+ logs/*
6
+ WiggleGAN_mod.py
7
+ WiggleGAN_noCycle.py
Images/Input-Test/1.png ADDED
Images/Input-Test/10.png ADDED
Images/Input-Test/11.png ADDED
Images/Input-Test/12.png ADDED
Images/Input-Test/2.png ADDED
Images/Input-Test/3.png ADDED
Images/Input-Test/4.png ADDED
Images/Input-Test/6.png ADDED
Images/Input-Test/7.png ADDED
Images/Input-Test/8.png ADDED
Images/Input-Test/9.png ADDED
WiggleGAN.py ADDED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import utils, torch, time, os, pickle
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.cuda as cu
5
+ import torch.optim as optim
6
+ import pickle
7
+ from torchvision import transforms
8
+ from torchvision.utils import save_image
9
+ from utils import augmentData, RGBtoL, LtoRGB
10
+ from PIL import Image
11
+ from dataloader import dataloader
12
+ from torch.autograd import Variable
13
+ import matplotlib.pyplot as plt
14
+ import random
15
+ from datetime import date
16
+ from statistics import mean
17
+ from architectures import depth_generator_UNet, \
18
+ depth_discriminator_noclass_UNet
19
+
20
+
21
+ class WiggleGAN(object):
22
+ def __init__(self, args):
23
+ # parameters
24
+ self.epoch = args.epoch
25
+ self.sample_num = 100
26
+ self.nCameras = args.cameras
27
+ self.batch_size = args.batch_size
28
+ self.save_dir = args.save_dir
29
+ self.result_dir = args.result_dir
30
+ self.dataset = args.dataset
31
+ self.log_dir = args.log_dir
32
+ self.gpu_mode = args.gpu_mode
33
+ self.model_name = args.gan_type
34
+ self.input_size = args.input_size
35
+ self.class_num = (args.cameras - 1) * 2 # un calculo que hice en paint
36
+ self.sample_num = self.class_num ** 2
37
+ self.imageDim = args.imageDim
38
+ self.epochVentaja = args.epochV
39
+ self.cantImages = args.cIm
40
+ self.visdom = args.visdom
41
+ self.lambdaL1 = args.lambdaL1
42
+ self.depth = args.depth
43
+ self.name_wiggle = args.name_wiggle
44
+
45
+ self.clipping = args.clipping
46
+ self.WGAN = False
47
+ if (self.clipping > 0):
48
+ self.WGAN = True
49
+
50
+ self.seed = str(random.randint(0, 99999))
51
+ self.seed_load = args.seedLoad
52
+ self.toLoad = False
53
+ if (self.seed_load != "-0000"):
54
+ self.toLoad = True
55
+
56
+ self.zGenFactor = args.zGF
57
+ self.zDisFactor = args.zDF
58
+ self.bFactor = args.bF
59
+ self.CR = False
60
+ if (self.zGenFactor > 0 or self.zDisFactor > 0 or self.bFactor > 0):
61
+ self.CR = True
62
+
63
+ self.expandGen = args.expandGen
64
+ self.expandDis = args.expandDis
65
+
66
+ self.wiggleDepth = args.wiggleDepth
67
+ self.wiggle = False
68
+ if (self.wiggleDepth > 0):
69
+ self.wiggle = True
70
+
71
+
72
+
73
+ # load dataset
74
+
75
+ self.onlyGen = args.lrD <= 0
76
+
77
+ if not self.wiggle:
78
+ self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='train',
79
+ trans=not self.CR)
80
+
81
+ self.data_Validation = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
82
+ split='validation')
83
+
84
+ self.dataprint = self.data_Validation.__iter__().__next__()
85
+
86
+ data = self.data_loader.__iter__().__next__().get('x_im')
87
+
88
+
89
+ if not self.onlyGen:
90
+ self.D = depth_discriminator_noclass_UNet(input_dim=3, output_dim=1, input_shape=data.shape,
91
+ class_num=self.class_num,
92
+ expand_net=self.expandDis, depth = self.depth, wgan = self.WGAN)
93
+ self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
94
+
95
+ self.data_Test = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='test')
96
+ self.dataprint_test = self.data_Test.__iter__().__next__()
97
+
98
+ # networks init
99
+
100
+ self.G = depth_generator_UNet(input_dim=4, output_dim=3, class_num=self.class_num, expand_net=self.expandGen, depth = self.depth)
101
+ self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
102
+
103
+
104
+ if self.gpu_mode:
105
+ self.G.cuda()
106
+ if not self.wiggle and not self.onlyGen:
107
+ self.D.cuda()
108
+ self.BCE_loss = nn.BCELoss().cuda()
109
+ self.CE_loss = nn.CrossEntropyLoss().cuda()
110
+ self.L1 = nn.L1Loss().cuda()
111
+ self.MSE = nn.MSELoss().cuda()
112
+ self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss().cuda()
113
+ else:
114
+ self.BCE_loss = nn.BCELoss()
115
+ self.CE_loss = nn.CrossEntropyLoss()
116
+ self.MSE = nn.MSELoss()
117
+ self.L1 = nn.L1Loss()
118
+ self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
119
+
120
+ print('---------- Networks architecture -------------')
121
+ utils.print_network(self.G)
122
+ if not self.wiggle and not self.onlyGen:
123
+ utils.print_network(self.D)
124
+ print('-----------------------------------------------')
125
+
126
+ temp = torch.zeros((self.class_num, 1))
127
+ for i in range(self.class_num):
128
+ temp[i, 0] = i
129
+
130
+ temp_y = torch.zeros((self.sample_num, 1))
131
+ for i in range(self.class_num):
132
+ temp_y[i * self.class_num: (i + 1) * self.class_num] = temp
133
+
134
+ self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
135
+ if self.gpu_mode:
136
+ self.sample_y_ = self.sample_y_.cuda()
137
+
138
+ if (self.toLoad):
139
+ self.load()
140
+
141
+ def train(self):
142
+
143
+ if self.visdom:
144
+ random.seed(time.time())
145
+ today = date.today()
146
+
147
+ vis = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
148
+ visValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
149
+ visEpoch = utils.VisdomLineTwoPlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
150
+ visImages = utils.VisdomImagePlotter(env_name='Cobo_depth_Images_' + str(today) + '_' + self.seed)
151
+ visImagesTest = utils.VisdomImagePlotter(env_name='Cobo_depth_ImagesTest_' + str(today) + '_' + self.seed)
152
+
153
+ visLossGTest = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
154
+ visLossGValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
155
+
156
+ visLossDTest = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
157
+ visLossDValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
158
+
159
+ self.train_hist = {}
160
+ self.epoch_hist = {}
161
+ self.details_hist = {}
162
+ self.train_hist['D_loss_train'] = []
163
+ self.train_hist['G_loss_train'] = []
164
+ self.train_hist['D_loss_Validation'] = []
165
+ self.train_hist['G_loss_Validation'] = []
166
+ self.train_hist['per_epoch_time'] = []
167
+ self.train_hist['total_time'] = []
168
+
169
+ self.details_hist['G_T_Comp_im'] = []
170
+ self.details_hist['G_T_BCE_fake_real'] = []
171
+ self.details_hist['G_T_Cycle'] = []
172
+ self.details_hist['G_zCR'] = []
173
+
174
+ self.details_hist['G_V_Comp_im'] = []
175
+ self.details_hist['G_V_BCE_fake_real'] = []
176
+ self.details_hist['G_V_Cycle'] = []
177
+
178
+ self.details_hist['D_T_BCE_fake_real_R'] = []
179
+ self.details_hist['D_T_BCE_fake_real_F'] = []
180
+ self.details_hist['D_zCR'] = []
181
+ self.details_hist['D_bCR'] = []
182
+
183
+ self.details_hist['D_V_BCE_fake_real_R'] = []
184
+ self.details_hist['D_V_BCE_fake_real_F'] = []
185
+
186
+ self.epoch_hist['D_loss_train'] = []
187
+ self.epoch_hist['G_loss_train'] = []
188
+ self.epoch_hist['D_loss_Validation'] = []
189
+ self.epoch_hist['G_loss_Validation'] = []
190
+
191
+ ##Para poder tomar el promedio por epoch
192
+ iterIniTrain = 0
193
+ iterFinTrain = 0
194
+
195
+ iterIniValidation = 0
196
+ iterFinValidation = 0
197
+
198
+ maxIter = self.data_loader.dataset.__len__() // self.batch_size
199
+ maxIterVal = self.data_Validation.dataset.__len__() // self.batch_size
200
+
201
+ if (self.WGAN):
202
+ one = torch.tensor(1, dtype=torch.float).cuda()
203
+ mone = one * -1
204
+ else:
205
+ self.y_real_ = torch.ones(self.batch_size, 1)
206
+ self.y_fake_ = torch.zeros(self.batch_size, 1)
207
+ if self.gpu_mode:
208
+ self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
209
+
210
+ print('training start!!')
211
+ start_time = time.time()
212
+
213
+ for epoch in range(self.epoch):
214
+
215
+ if (epoch < self.epochVentaja):
216
+ ventaja = True
217
+ else:
218
+ ventaja = False
219
+
220
+ self.G.train()
221
+
222
+ if not self.onlyGen:
223
+ self.D.train()
224
+ epoch_start_time = time.time()
225
+
226
+
227
+ # TRAIN!!!
228
+ for iter, data in enumerate(self.data_loader):
229
+
230
+ x_im = data.get('x_im')
231
+ x_dep = data.get('x_dep')
232
+ y_im = data.get('y_im')
233
+ y_dep = data.get('y_dep')
234
+ y_ = data.get('y_')
235
+
236
+ # x_im = imagenes normales
237
+ # x_dep = profundidad de images
238
+ # y_im = imagen con el angulo cambiado
239
+ # y_ = angulo de la imagen = tengo que tratar negativos
240
+
241
+ # Aumento mi data
242
+ if (self.CR):
243
+ x_im_aug, y_im_aug = augmentData(x_im, y_im)
244
+ x_im_vanilla = x_im
245
+
246
+ if self.gpu_mode:
247
+ x_im_aug, y_im_aug = x_im_aug.cuda(), y_im_aug.cuda()
248
+
249
+ if iter >= maxIter:
250
+ break
251
+
252
+ if self.gpu_mode:
253
+ x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()
254
+
255
+ # update D network
256
+ if not ventaja and not self.onlyGen:
257
+
258
+ for p in self.D.parameters(): # reset requires_grad
259
+ p.requires_grad = True # they are set to False below in netG update
260
+
261
+ self.D_optimizer.zero_grad()
262
+
263
+ # Real Images
264
+ D_real, D_features_real = self.D(y_im, x_im, y_dep, y_) ## Es la funcion forward `` g(z) x
265
+
266
+ # Fake Images
267
+ G_, G_dep = self.G( y_, x_im, x_dep)
268
+ D_fake, D_features_fake = self.D(G_, x_im, G_dep, y_)
269
+
270
+ # Losses
271
+ # GAN Loss
272
+ if (self.WGAN): # de WGAN
273
+ D_loss_real_fake_R = - torch.mean(D_real)
274
+ D_loss_real_fake_F = torch.mean(D_fake)
275
+ #D_loss_real_fake_R = - D_loss_real_fake_R_positive
276
+
277
+ else: # de Gan normal
278
+ D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
279
+ D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)
280
+
281
+ D_loss = D_loss_real_fake_F + D_loss_real_fake_R
282
+
283
+ if self.CR:
284
+
285
+ # Fake Augmented Images bCR
286
+ x_im_aug_bCR, G_aug_bCR = augmentData(x_im_vanilla, G_.data.cpu())
287
+
288
+ if self.gpu_mode:
289
+ G_aug_bCR, x_im_aug_bCR = G_aug_bCR.cuda(), x_im_aug_bCR.cuda()
290
+
291
+ D_fake_bCR, D_features_fake_bCR = self.D(G_aug_bCR, x_im_aug_bCR, G_dep, y_)
292
+ D_real_bCR, D_features_real_bCR = self.D(y_im_aug, x_im_aug, y_dep, y_)
293
+
294
+ # Fake Augmented Images zCR
295
+ G_aug_zCR, G_dep_aug_zCR = self.G(y_, x_im_aug, x_dep)
296
+ D_fake_aug_zCR, D_features_fake_aug_zCR = self.D(G_aug_zCR, x_im_aug, G_dep_aug_zCR, y_)
297
+
298
+ # bCR Loss (*)
299
+ D_loss_real = self.MSE(D_features_real, D_features_real_bCR)
300
+ D_loss_fake = self.MSE(D_features_fake, D_features_fake_bCR)
301
+ D_bCR = (D_loss_real + D_loss_fake) * self.bFactor
302
+
303
+ # zCR Loss
304
+ D_zCR = self.MSE(D_features_fake, D_features_fake_aug_zCR) * self.zDisFactor
305
+
306
+ D_CR_losses = D_bCR + D_zCR
307
+ #D_CR_losses.backward(retain_graph=True)
308
+
309
+ D_loss += D_CR_losses
310
+
311
+ self.details_hist['D_bCR'].append(D_bCR.detach().item())
312
+ self.details_hist['D_zCR'].append(D_zCR.detach().item())
313
+ else:
314
+ self.details_hist['D_bCR'].append(0)
315
+ self.details_hist['D_zCR'].append(0)
316
+
317
+ self.train_hist['D_loss_train'].append(D_loss.detach().item())
318
+ self.details_hist['D_T_BCE_fake_real_R'].append(D_loss_real_fake_R.detach().item())
319
+ self.details_hist['D_T_BCE_fake_real_F'].append(D_loss_real_fake_F.detach().item())
320
+ if self.visdom:
321
+ visLossDTest.plot('Discriminator_losses',
322
+ ['D_T_BCE_fake_real_R','D_T_BCE_fake_real_F', 'D_bCR', 'D_zCR'], 'train',
323
+ self.details_hist)
324
+ #if self.WGAN:
325
+ # D_loss_real_fake_F.backward(retain_graph=True)
326
+ # D_loss_real_fake_R_positive.backward(mone)
327
+ #else:
328
+ # D_loss_real_fake.backward()
329
+ D_loss.backward()
330
+
331
+ self.D_optimizer.step()
332
+
333
+ #WGAN
334
+ if (self.WGAN):
335
+ for p in self.D.parameters():
336
+ p.data.clamp_(-self.clipping, self.clipping) #Segun paper si el valor es muy chico lleva al banishing gradient
337
+ # Si se aplicaria la mejora en las WGANs tendiramos que sacar los batch normalizations de la red
338
+
339
+
340
+ # update G network
341
+ self.G_optimizer.zero_grad()
342
+
343
+ G_, G_dep = self.G(y_, x_im, x_dep)
344
+
345
+ if not ventaja and not self.onlyGen:
346
+ for p in self.D.parameters():
347
+ p.requires_grad = False # to avoid computation
348
+
349
+ # Fake images
350
+ D_fake, _ = self.D(G_, x_im, G_dep, y_)
351
+
352
+ if (self.WGAN):
353
+ G_loss_fake = -torch.mean(D_fake) #de WGAN
354
+ else:
355
+ G_loss_fake = self.BCEWithLogitsLoss(D_fake, self.y_real_)
356
+
357
+ # loss between images (*)
358
+ #G_join = torch.cat((G_, G_dep), 1)
359
+ #y_join = torch.cat((y_im, y_dep), 1)
360
+
361
+ G_loss_Comp = self.L1(G_, y_im)
362
+ if self.depth:
363
+ G_loss_Comp += self.L1(G_dep, y_dep)
364
+
365
+ G_loss_Dif_Comp = G_loss_Comp * self.lambdaL1
366
+
367
+ reverse_y = - y_ + 1
368
+ reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
369
+ G_loss_Cycle = self.L1(reverse_G, x_im)
370
+ if self.depth:
371
+ G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
372
+ G_loss_Cycle = G_loss_Cycle * self.lambdaL1/2
373
+
374
+
375
+ if (self.CR):
376
+ # Fake images augmented
377
+
378
+ G_aug, G_dep_aug = self.G(y_, x_im_aug, x_dep)
379
+ D_fake_aug, _ = self.D(G_aug, x_im, G_dep_aug, y_)
380
+
381
+ if (self.WGAN):
382
+ G_loss_fake = - (torch.mean(D_fake)+torch.mean(D_fake_aug))/2
383
+ else:
384
+ G_loss_fake = ( self.BCEWithLogitsLoss(D_fake, self.y_real_) +
385
+ self.BCEWithLogitsLoss(D_fake_aug,self.y_real_)) / 2
386
+
387
+ # loss between images (*)
388
+ #y_aug_join = torch.cat((y_im_aug, y_dep), 1)
389
+ #G_aug_join = torch.cat((G_aug, G_dep_aug), 1)
390
+
391
+ G_loss_Comp_Aug = self.L1(G_aug, y_im_aug)
392
+ if self.depth:
393
+ G_loss_Comp_Aug += self.L1(G_dep_aug, y_dep)
394
+ G_loss_Dif_Comp = (G_loss_Comp + G_loss_Comp_Aug)/2 * self.lambdaL1
395
+
396
+
397
+ G_loss = G_loss_fake + G_loss_Dif_Comp + G_loss_Cycle
398
+
399
+ self.details_hist['G_T_BCE_fake_real'].append(G_loss_fake.detach().item())
400
+ self.details_hist['G_T_Comp_im'].append(G_loss_Dif_Comp.detach().item())
401
+ self.details_hist['G_T_Cycle'].append(G_loss_Cycle.detach().item())
402
+ self.details_hist['G_zCR'].append(0)
403
+
404
+
405
+ else:
406
+
407
+ G_loss = self.L1(G_, y_im)
408
+ if self.depth:
409
+ G_loss += self.L1(G_dep, y_dep)
410
+ G_loss = G_loss * self.lambdaL1
411
+ self.details_hist['G_T_Comp_im'].append(G_loss.detach().item())
412
+ self.details_hist['G_T_BCE_fake_real'].append(0)
413
+ self.details_hist['G_T_Cycle'].append(0)
414
+ self.details_hist['G_zCR'].append(0)
415
+
416
+ G_loss.backward()
417
+ self.G_optimizer.step()
418
+ self.train_hist['G_loss_train'].append(G_loss.detach().item())
419
+ if self.onlyGen:
420
+ self.train_hist['D_loss_train'].append(0)
421
+
422
+ iterFinTrain += 1
423
+
424
+ if self.visdom:
425
+ visLossGTest.plot('Generator_losses',
426
+ ['G_T_Comp_im', 'G_T_BCE_fake_real', 'G_zCR','G_T_Cycle'],
427
+ 'train', self.details_hist)
428
+
429
+ vis.plot('loss', ['D_loss_train', 'G_loss_train'], 'train', self.train_hist)
430
+
431
+ ##################Validation####################################
432
+ with torch.no_grad():
433
+
434
+ self.G.eval()
435
+ if not self.onlyGen:
436
+ self.D.eval()
437
+
438
+ for iter, data in enumerate(self.data_Validation):
439
+
440
+ # Aumento mi data
441
+ x_im = data.get('x_im')
442
+ x_dep = data.get('x_dep')
443
+ y_im = data.get('y_im')
444
+ y_dep = data.get('y_dep')
445
+ y_ = data.get('y_')
446
+ # x_im = imagenes normales
447
+ # x_dep = profundidad de images
448
+ # y_im = imagen con el angulo cambiado
449
+ # y_ = angulo de la imagen = tengo que tratar negativos
450
+
451
+ # x_im = torch.Tensor(list(x_im))
452
+ # x_dep = torch.Tensor(x_dep)
453
+ # y_im = torch.Tensor(y_im)
454
+ # print(y_.shape[0])
455
+ if iter == maxIterVal:
456
+ # print ("Break")
457
+ break
458
+ # print (y_.type(torch.LongTensor).unsqueeze(1))
459
+
460
+
461
+ # print("y_vec_", y_vec_)
462
+ # print ("z_", z_)
463
+
464
+ if self.gpu_mode:
465
+ x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()
466
+ # D network
467
+
468
+ if not ventaja and not self.onlyGen:
469
+ # Real Images
470
+ D_real, _ = self.D(y_im, x_im, y_dep,y_) ## Es la funcion forward `` g(z) x
471
+
472
+ # Fake Images
473
+ G_, G_dep = self.G(y_, x_im, x_dep)
474
+ D_fake, _ = self.D(G_, x_im, G_dep, y_)
475
+ # Losses
476
+ # GAN Loss
477
+ if (self.WGAN): # de WGAN
478
+ D_loss_real_fake_R = - torch.mean(D_real)
479
+ D_loss_real_fake_F = torch.mean(D_fake)
480
+
481
+ else: # de Gan normal
482
+ D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
483
+ D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)
484
+
485
+ D_loss_real_fake = D_loss_real_fake_F + D_loss_real_fake_R
486
+
487
+ D_loss = D_loss_real_fake
488
+
489
+ self.train_hist['D_loss_Validation'].append(D_loss.item())
490
+ self.details_hist['D_V_BCE_fake_real_R'].append(D_loss_real_fake_R.item())
491
+ self.details_hist['D_V_BCE_fake_real_F'].append(D_loss_real_fake_F.item())
492
+ if self.visdom:
493
+ visLossDValidation.plot('Discriminator_losses',
494
+ ['D_V_BCE_fake_real_R','D_V_BCE_fake_real_F'], 'Validation',
495
+ self.details_hist)
496
+
497
+ # G network
498
+
499
+ G_, G_dep = self.G(y_, x_im, x_dep)
500
+
501
+ if not ventaja and not self.onlyGen:
502
+ # Fake images
503
+ D_fake,_ = self.D(G_, x_im, G_dep, y_)
504
+
505
+ #Loss GAN
506
+ if (self.WGAN):
507
+ G_loss = -torch.mean(D_fake) # porWGAN
508
+ else:
509
+ G_loss = self.BCEWithLogitsLoss(D_fake, self.y_real_) #de GAN NORMAL
510
+
511
+ self.details_hist['G_V_BCE_fake_real'].append(G_loss.item())
512
+
513
+ #Loss comparation
514
+ #G_join = torch.cat((G_, G_dep), 1)
515
+ #y_join = torch.cat((y_im, y_dep), 1)
516
+
517
+ G_loss_Comp = self.L1(G_, y_im)
518
+ if self.depth:
519
+ G_loss_Comp += self.L1(G_dep, y_dep)
520
+ G_loss_Comp = G_loss_Comp * self.lambdaL1
521
+
522
+ reverse_y = - y_ + 1
523
+ reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
524
+ G_loss_Cycle = self.L1(reverse_G, x_im)
525
+ if self.depth:
526
+ G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
527
+ G_loss_Cycle = G_loss_Cycle * self.lambdaL1/2
528
+
529
+ G_loss += G_loss_Comp + G_loss_Cycle
530
+
531
+
532
+ self.details_hist['G_V_Comp_im'].append(G_loss_Comp.item())
533
+ self.details_hist['G_V_Cycle'].append(G_loss_Cycle.detach().item())
534
+
535
+ else:
536
+ G_loss = self.L1(G_, y_im)
537
+ if self.depth:
538
+ G_loss += self.L1(G_dep, y_dep)
539
+ G_loss = G_loss * self.lambdaL1
540
+ self.details_hist['G_V_Comp_im'].append(G_loss.item())
541
+ self.details_hist['G_V_BCE_fake_real'].append(0)
542
+ self.details_hist['G_V_Cycle'].append(0)
543
+
544
+ self.train_hist['G_loss_Validation'].append(G_loss.item())
545
+ if self.onlyGen:
546
+ self.train_hist['D_loss_Validation'].append(0)
547
+
548
+
549
+ iterFinValidation += 1
550
+ if self.visdom:
551
+ visLossGValidation.plot('Generator_losses', ['G_V_Comp_im', 'G_V_BCE_fake_real','G_V_Cycle'],
552
+ 'Validation', self.details_hist)
553
+ visValidation.plot('loss', ['D_loss_Validation', 'G_loss_Validation'], 'Validation',
554
+ self.train_hist)
555
+
556
+ ##Vis por epoch
557
+
558
+ if ventaja or self.onlyGen:
559
+ self.epoch_hist['D_loss_train'].append(0)
560
+ self.epoch_hist['D_loss_Validation'].append(0)
561
+ else:
562
+ #inicioTr = (epoch - self.epochVentaja) * (iterFinTrain - iterIniTrain)
563
+ #inicioTe = (epoch - self.epochVentaja) * (iterFinValidation - iterIniValidation)
564
+ self.epoch_hist['D_loss_train'].append(mean(self.train_hist['D_loss_train'][iterIniTrain: -1]))
565
+ self.epoch_hist['D_loss_Validation'].append(mean(self.train_hist['D_loss_Validation'][iterIniValidation: -1]))
566
+
567
+ self.epoch_hist['G_loss_train'].append(mean(self.train_hist['G_loss_train'][iterIniTrain:iterFinTrain]))
568
+ self.epoch_hist['G_loss_Validation'].append(
569
+ mean(self.train_hist['G_loss_Validation'][iterIniValidation:iterFinValidation]))
570
+ if self.visdom:
571
+ visEpoch.plot('epoch', epoch,
572
+ ['D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation'],
573
+ self.epoch_hist)
574
+
575
+ self.train_hist['D_loss_train'] = self.train_hist['D_loss_train'][-1:]
576
+ self.train_hist['G_loss_train'] = self.train_hist['G_loss_train'][-1:]
577
+ self.train_hist['D_loss_Validation'] = self.train_hist['D_loss_Validation'][-1:]
578
+ self.train_hist['G_loss_Validation'] = self.train_hist['G_loss_Validation'][-1:]
579
+ self.train_hist['per_epoch_time'] = self.train_hist['per_epoch_time'][-1:]
580
+ self.train_hist['total_time'] = self.train_hist['total_time'][-1:]
581
+
582
+ self.details_hist['G_T_Comp_im'] = self.details_hist['G_T_Comp_im'][-1:]
583
+ self.details_hist['G_T_BCE_fake_real'] = self.details_hist['G_T_BCE_fake_real'][-1:]
584
+ self.details_hist['G_T_Cycle'] = self.details_hist['G_T_Cycle'][-1:]
585
+ self.details_hist['G_zCR'] = self.details_hist['G_zCR'][-1:]
586
+
587
+ self.details_hist['G_V_Comp_im'] = self.details_hist['G_V_Comp_im'][-1:]
588
+ self.details_hist['G_V_BCE_fake_real'] = self.details_hist['G_V_BCE_fake_real'][-1:]
589
+ self.details_hist['G_V_Cycle'] = self.details_hist['G_V_Cycle'][-1:]
590
+
591
+ self.details_hist['D_T_BCE_fake_real_R'] = self.details_hist['D_T_BCE_fake_real_R'][-1:]
592
+ self.details_hist['D_T_BCE_fake_real_F'] = self.details_hist['D_T_BCE_fake_real_F'][-1:]
593
+ self.details_hist['D_zCR'] = self.details_hist['D_zCR'][-1:]
594
+ self.details_hist['D_bCR'] = self.details_hist['D_bCR'][-1:]
595
+
596
+ self.details_hist['D_V_BCE_fake_real_R'] = self.details_hist['D_V_BCE_fake_real_R'][-1:]
597
+ self.details_hist['D_V_BCE_fake_real_F'] = self.details_hist['D_V_BCE_fake_real_F'][-1:]
598
+ ##Para poder tomar el promedio por epoch
599
+ iterIniTrain = 1
600
+ iterFinTrain = 1
601
+
602
+ iterIniValidation = 1
603
+ iterFinValidation = 1
604
+
605
+ self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
606
+
607
+ if epoch % 10 == 0:
608
+ self.save(str(epoch))
609
+ with torch.no_grad():
610
+ if self.visdom:
611
+ self.visualize_results(epoch, dataprint=self.dataprint, visual=visImages)
612
+ self.visualize_results(epoch, dataprint=self.dataprint_test, visual=visImagesTest)
613
+ else:
614
+ imageName = self.model_name + '_' + 'Train' + '_' + str(self.seed) + '_' + str(epoch)
615
+ self.visualize_results(epoch, dataprint=self.dataprint, name= imageName)
616
+ self.visualize_results(epoch, dataprint=self.dataprint_test, name= imageName)
617
+
618
+
619
+ self.train_hist['total_time'].append(time.time() - start_time)
620
+ print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
621
+ self.epoch, self.train_hist['total_time'][0]))
622
+ print("Training finish!... save training results")
623
+
624
+ self.save()
625
+ #utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
626
+ # self.epoch)
627
+ #utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
628
+
629
+ def visualize_results(self, epoch, dataprint, visual="", name= "test"):
630
+ with torch.no_grad():
631
+ self.G.eval()
632
+
633
+ #if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
634
+ # os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
635
+
636
+ # print("sample z: ",self.sample_z_,"sample y:", self.sample_y_)
637
+
638
+ ##Podria hacer un loop
639
+ # .zfill(4)
640
+ #newSample = None
641
+ #print(dataprint.shape)
642
+
643
+ #newSample = torch.tensor([])
644
+
645
+ #se que es ineficiente pero lo hago cada 10 epoch nomas
646
+ newSample = []
647
+ iter = 1
648
+ for x_im,x_dep in zip(dataprint.get('x_im'), dataprint.get('x_dep')):
649
+ if (iter > self.cantImages):
650
+ break
651
+
652
+ #x_im = (x_im + 1) / 2
653
+ #imgX = transforms.ToPILImage()(x_im)
654
+ #imgX.show()
655
+
656
+ x_im_input = x_im.repeat(2, 1, 1, 1)
657
+ x_dep_input = x_dep.repeat(2, 1, 1, 1)
658
+
659
+ sizeImage = x_im.shape[2]
660
+
661
+ sample_y_ = torch.zeros((self.class_num, 1, sizeImage, sizeImage))
662
+ for i in range(self.class_num):
663
+ if(int(i % self.class_num) == 1):
664
+ sample_y_[i] = torch.ones(( 1, sizeImage, sizeImage))
665
+
666
+ if self.gpu_mode:
667
+ sample_y_, x_im_input, x_dep_input = sample_y_.cuda(), x_im_input.cuda(), x_dep_input.cuda()
668
+
669
+ G_im, G_dep = self.G(sample_y_, x_im_input, x_dep_input)
670
+
671
+ newSample.append(x_im.squeeze(0))
672
+ newSample.append(x_dep.squeeze(0).expand(3, -1, -1))
673
+
674
+
675
+
676
+ if self.wiggle:
677
+ im_aux, im_dep_aux = G_im, G_dep
678
+ for i in range(0, 2):
679
+ index = i
680
+ for j in range(0, self.wiggleDepth):
681
+
682
+ # print(i,j)
683
+
684
+ if (j == 0 and i == 1):
685
+ # para tomar el original
686
+ im_aux, im_dep_aux = G_im, G_dep
687
+ newSample.append(G_im.cpu()[0].squeeze(0))
688
+ newSample.append(G_im.cpu()[1].squeeze(0))
689
+ elif (i == 1):
690
+ # por el problema de las iteraciones proximas
691
+ index = 0
692
+
693
+ # imagen generada
694
+
695
+
696
+ x = im_aux[index].unsqueeze(0)
697
+ x_dep = im_dep_aux[index].unsqueeze(0)
698
+
699
+ y = sample_y_[i].unsqueeze(0)
700
+
701
+ if self.gpu_mode:
702
+ y, x, x_dep = y.cuda(), x.cuda(), x_dep.cuda()
703
+
704
+ im_aux, im_dep_aux = self.G(y, x, x_dep)
705
+
706
+ newSample.append(im_aux.cpu()[0])
707
+ else:
708
+
709
+ newSample.append(G_im.cpu()[0])
710
+ newSample.append(G_im.cpu()[1])
711
+ newSample.append(G_dep.cpu()[0].expand(3, -1, -1))
712
+ newSample.append(G_dep.cpu()[1].expand(3, -1, -1))
713
+ # sadadas
714
+
715
+ iter+=1
716
+
717
+ if self.visdom:
718
+ visual.plot(epoch, newSample, int(len(newSample) /self.cantImages))
719
+ else:
720
+ utils.save_wiggle(newSample, self.cantImages, name)
721
+ ##TENGO QUE HACER QUE SAMPLES TENGAN COMO MAXIMO self.class_num * self.class_num
722
+
723
+ # utils.save_images(newSample[:, :, :, :], [image_frame_dim * cantidadIm , image_frame_dim * (self.class_num+2)],
724
+ # self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%04d' % epoch + '.png')
725
+
726
+ def show_plot_images(self, images, cols=1, titles=None):
727
+ """Display a list of images in a single figure with matplotlib.
728
+
729
+ Parameters
730
+ ---------
731
+ images: List of np.arrays compatible with plt.imshow.
732
+
733
+ cols (Default = 1): Number of columns in figure (number of rows is
734
+ set to np.ceil(n_images/float(cols))).
735
+
736
+ titles: List of titles corresponding to each image. Must have
737
+ the same length as titles.
738
+ """
739
+ # assert ((titles is None) or (len(images) == len(titles)))
740
+ n_images = len(images)
741
+ if titles is None: titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
742
+ fig = plt.figure()
743
+ for n, (image, title) in enumerate(zip(images, titles)):
744
+ a = fig.add_subplot(np.ceil(n_images / float(cols)), cols, n + 1)
745
+ # print(image)
746
+ image = (image + 1) * 255.0
747
+ # print(image)
748
+ # new_im = Image.fromarray(image)
749
+ # print(new_im)
750
+ if image.ndim == 2:
751
+ plt.gray()
752
+ # print("spi imshape ", image.shape)
753
+ plt.imshow(image)
754
+ a.set_title(title)
755
+ fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
756
+ plt.show()
757
+
758
+ def joinImages(self, data):
759
+ nData = []
760
+ for i in range(self.class_num):
761
+ nData.append(data)
762
+ nData = np.array(nData)
763
+ nData = torch.tensor(nData.tolist())
764
+ nData = nData.type(torch.FloatTensor)
765
+
766
+ return nData
767
+
768
+ def save(self, epoch=''):
769
+ save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
770
+
771
+ if not os.path.exists(save_dir):
772
+ os.makedirs(save_dir)
773
+
774
+ torch.save(self.G.state_dict(),
775
+ os.path.join(save_dir, self.model_name + '_' + self.seed + '_' + epoch + '_G.pkl'))
776
+ if not self.onlyGen:
777
+ torch.save(self.D.state_dict(),
778
+ os.path.join(save_dir, self.model_name + '_' + self.seed + '_' + epoch + '_D.pkl'))
779
+
780
+ with open(os.path.join(save_dir, self.model_name + '_history_ '+self.seed+'.pkl'), 'wb') as f:
781
+ pickle.dump(self.train_hist, f)
782
+
783
+ def load(self):
784
+ save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
785
+
786
+ self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_G.pkl')))
787
+ if not self.wiggle:
788
+ self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl')))
789
+
790
+ def wiggleEf(self):
791
+ seed, epoch = self.seed_load.split('_')
792
+ if self.visdom:
793
+ visWiggle = utils.VisdomImagePlotter(env_name='Cobo_depth_wiggle_' + seed)
794
+ self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=visWiggle)
795
+ else:
796
+ self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=None, name = self.name_wiggle)
797
+
798
+ def recreate(self):
799
+
800
+ dataloader_recreate = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='score')
801
+ with torch.no_grad():
802
+ self.G.eval()
803
+ accum = 0
804
+ for data_batch in dataloader_recreate.__iter__():
805
+
806
+ #{'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
807
+ left,left_depth,right,right_depth,direction = data_batch.values()
808
+
809
+ if self.gpu_mode:
810
+ left,left_depth,right,right_depth,direction = left.cuda(),left_depth.cuda(),right.cuda(),right_depth.cuda(),direction.cuda()
811
+
812
+ G_right, G_right_dep = self.G( direction, left, left_depth)
813
+
814
+ reverse_direction = direction * 0
815
+ G_left, G_left_dep = self.G( reverse_direction, right, right_depth)
816
+
817
+ for index in range(0,self.batch_size):
818
+ image_right = (G_right[index] + 1.0)/2.0
819
+ image_right_dep = (G_right_dep[index] + 1.0)/2.0
820
+
821
+ image_left = (G_left[index] + 1.0)/2.0
822
+ image_left_dep = (G_left_dep[index] + 1.0)/2.0
823
+
824
+
825
+
826
+ save_image(image_right, os.path.join("results","recreate_dataset","CAM1","n_{num:0{width}}.png".format(num = index+accum, width = 4)))
827
+ save_image(image_right_dep, os.path.join("results","recreate_dataset","CAM1","d_{num:0{width}}.png".format(num = index+accum, width = 4)))
828
+
829
+ save_image(image_left, os.path.join("results","recreate_dataset","CAM0","n_{num:0{width}}.png".format(num = index+accum, width = 4)))
830
+ save_image(image_left_dep, os.path.join("results","recreate_dataset","CAM0","d_{num:0{width}}.png".format(num = index+accum, width = 4)))
831
+ accum+= self.batch_size
832
+
833
+
WiggleResults/split.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import argparse
4
+
5
+ parser = argparse.ArgumentParser(description='change to useful name')
6
+ parser.add_argument('--dim', default=128, type=int, help='dimention image')
7
+ args = parser.parse_args()
8
+
9
+ path = "."
10
+ dirs = os.listdir(path)
11
+
12
+ dim = args.dim
13
+
14
+ def gif_order (data, center=True):
15
+ gif = []
16
+ base = 1
17
+
18
+ #primera mitad
19
+ i = int((len(data)-2)/2)
20
+ while(i > base ):
21
+ gif.append(data[i])
22
+ #print(i)
23
+ i -= 1
24
+
25
+
26
+ #el del medio izq
27
+ gif.append(data[int((len(data)-2)/2) + 1])
28
+ #print(int((len(data)-2)/2) + 1)
29
+
30
+ #el inicial
31
+ if center:
32
+ gif.append(data[0])
33
+ #print(0)
34
+
35
+ # el del medio der
36
+ gif.append(data[int((len(data) - 2) / 2) + 2])
37
+ #print(int((len(data) - 2) / 2) +2)
38
+ #segunda mitad
39
+ i = int((len(data)-2)/2) + 3
40
+ while (i < len(data)):
41
+ gif.append(data[i])
42
+ #print(i)
43
+ i += 1
44
+ #print("---------")
45
+
46
+ invertedgif = gif[::-1]
47
+ invertedgif = invertedgif[1:]
48
+
49
+ gif = gif[1:] + invertedgif
50
+ #print(gif)
51
+ #for image in gif:
52
+ # image.show()
53
+ #gsdfgsfgf
54
+ return gif
55
+
56
+
57
+ # This would print all the files and directories
58
+ for file in dirs:
59
+ if ".jpg" in file or ".png" in file:
60
+ rowImages = []
61
+ im = Image.open("./" + file)
62
+ width, height = im.size
63
+ im = im.convert('RGB')
64
+
65
+ #CROP (left, top, right, bottom)
66
+
67
+ pointleft = 3
68
+ pointtop = 3
69
+ i = 0
70
+ while (pointtop < height):
71
+ while (pointleft < width):
72
+ im1 = im.crop((pointleft, pointtop, dim+pointleft, dim+pointtop))
73
+ rowImages.append(im1.quantize())
74
+ #im1.show()
75
+ pointleft+= dim+4
76
+ # Ya tengo todas las imagenes podria hacer el gif aca
77
+ rowImages = gif_order(rowImages,center=False)
78
+ name = file[:-4] + "_" + str(i) + '.gif'
79
+ rowImages[0].save(name, save_all=True,format='GIF', append_images=rowImages[1:], optimize=True, duration=100, loop=0)
80
+ pointtop += dim + 4
81
+ pointleft = 3
82
+ rowImages = []
83
+ i+=1
84
+ #im2 = im.crop((width / 2, 0, width, height))
85
+ # im2.show()
86
+
87
+ #im1.save("./2" + file[:-4] + ".png")
88
+ #im2.save("./" + file[:-4] + ".png")
89
+
90
+ # Deleted
91
+ #os.remove("data/" + file)
app.py CHANGED
@@ -16,7 +16,6 @@ def calculate_depth(model_type, img):
16
 
17
  img.save(filename, "JPEG")
18
 
19
- #model_type = "DPT_Hybrid"
20
  midas = torch.hub.load("intel-isl/MiDaS", model_type)
21
 
22
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -61,18 +60,19 @@ def wiggle_effect(slider):
61
 
62
  with gr.Blocks() as demo:
63
  gr.Markdown("Start typing below and then click **Run** to see the output.")
64
- inp = []
65
-
 
66
  midas_models = ["DPT_Large","DPT_Hybrid","MiDaS_small"]
67
-
68
- inp.append(gr.inputs.Dropdown(midas_models, default="MiDaS_small", label="Depth estimation model type"))
69
-
70
  with gr.Row():
71
  inp.append(gr.Image(type="pil", label="Input"))
72
  out = gr.Image(type="file", label="depth_estimation")
73
  btn = gr.Button("Calculate depth")
74
  btn.click(fn=calculate_depth, inputs=inp, outputs=out)
75
 
 
 
76
  inp = [gr.Slider(1,15, default = 2, label='StepCycles',step= 1)]
77
  with gr.Row():
78
  out = [ gr.Image(type="file", label="Output_images"), #TODO change to gallery
16
 
17
  img.save(filename, "JPEG")
18
 
 
19
  midas = torch.hub.load("intel-isl/MiDaS", model_type)
20
 
21
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
60
 
61
  with gr.Blocks() as demo:
62
  gr.Markdown("Start typing below and then click **Run** to see the output.")
63
+
64
+
65
+ ## Depth Estimation
66
  midas_models = ["DPT_Large","DPT_Hybrid","MiDaS_small"]
67
+ inp = [gr.inputs.Dropdown(midas_models, default="MiDaS_small", label="Depth estimation model type")]
 
 
68
  with gr.Row():
69
  inp.append(gr.Image(type="pil", label="Input"))
70
  out = gr.Image(type="file", label="depth_estimation")
71
  btn = gr.Button("Calculate depth")
72
  btn.click(fn=calculate_depth, inputs=inp, outputs=out)
73
 
74
+
75
+ ## Wigglegram
76
  inp = [gr.Slider(1,15, default = 2, label='StepCycles',step= 1)]
77
  with gr.Row():
78
  out = [ gr.Image(type="file", label="Output_images"), #TODO change to gallery
architectures.py ADDED
@@ -0,0 +1,1094 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import utils, torch
3
+ from torch.autograd import Variable
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class generator(nn.Module):
8
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
9
+ # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10
+ def __init__(self, input_dim=4, output_dim=1, input_shape=3, class_num=10, height=10, width=10):
11
+ super(generator, self).__init__()
12
+ self.input_dim = input_dim
13
+ self.output_dim = output_dim
14
+ # print ("self.output_dim", self.output_dim)
15
+ self.class_num = class_num
16
+ self.input_shape = list(input_shape)
17
+ self.toPreDecov = 1024
18
+ self.toDecov = 1
19
+ self.height = height
20
+ self.width = width
21
+
22
+ self.input_shape[1] = self.input_dim # esto cambio despues por colores
23
+
24
+ # print("input shpe gen",self.input_shape)
25
+
26
+ self.conv1 = nn.Sequential(
27
+ nn.Conv2d(self.input_dim, 10, 4, 2, 1), # para mi el 2 tendria que ser 1
28
+ nn.Conv2d(10, 4, 4, 2, 1),
29
+ nn.BatchNorm2d(4),
30
+ nn.LeakyReLU(0.2),
31
+ nn.Conv2d(4, 3, 4, 2, 1),
32
+ nn.BatchNorm2d(3),
33
+ nn.LeakyReLU(0.2),
34
+ )
35
+
36
+ self.n_size = self._get_conv_output(self.input_shape)
37
+ # print ("self.n_size",self.n_size)
38
+ self.cubic = (self.n_size // 8192)
39
+ # print("self.cubic: ",self.cubic)
40
+
41
+ self.fc1 = nn.Sequential(
42
+ nn.Linear(self.n_size, self.n_size),
43
+ nn.BatchNorm1d(self.n_size),
44
+ nn.LeakyReLU(0.2),
45
+ )
46
+
47
+ self.preDeconv = nn.Sequential(
48
+ ##############RED SUPER CHICA PARA QUE ANDE TO DO PORQUE RAM Y MEMORY
49
+
50
+ # nn.Linear(self.toPreDecov + self.zdim + self.class_num, 1024),
51
+ # nn.BatchNorm1d(1024),
52
+ # nn.LeakyReLU(0.2),
53
+ # nn.Linear(1024, self.toDecov * self.height // 64 * self.width// 64),
54
+ # nn.BatchNorm1d(self.toDecov * self.height // 64 * self.width// 64),
55
+ # nn.LeakyReLU(0.2),
56
+ # nn.Linear(self.toDecov * self.height // 64 * self.width // 64 , self.toDecov * self.height // 32 * self.width // 32),
57
+ # nn.BatchNorm1d(self.toDecov * self.height // 32 * self.width // 32),
58
+ # nn.LeakyReLU(0.2),
59
+ # nn.Linear(self.toDecov * self.height // 32 * self.width // 32,
60
+ # 1 * self.height * self.width),
61
+ # nn.BatchNorm1d(1 * self.height * self.width),
62
+ # nn.LeakyReLU(0.2),
63
+
64
+ nn.Linear(self.n_size + self.class_num, 400),
65
+ nn.BatchNorm1d(400),
66
+ nn.LeakyReLU(0.2),
67
+ nn.Linear(400, 800),
68
+ nn.BatchNorm1d(800),
69
+ nn.LeakyReLU(0.2),
70
+ nn.Linear(800, self.output_dim * self.height * self.width),
71
+ nn.BatchNorm1d(self.output_dim * self.height * self.width),
72
+ nn.Tanh(), # Cambio porque hago como que termino ahi
73
+
74
+ )
75
+
76
+ """
77
+ self.deconv = nn.Sequential(
78
+ nn.ConvTranspose2d(self.toDecov, 2, 4, 2, 0),
79
+ nn.BatchNorm2d(2),
80
+ nn.ReLU(),
81
+ nn.ConvTranspose2d(2, self.output_dim, 4, 2, 1),
82
+ nn.Tanh(), #esta recomendado que la ultima sea TanH de la Generadora da valores entre -1 y 1
83
+ )
84
+ """
85
+ utils.initialize_weights(self)
86
+
87
+ def _get_conv_output(self, shape):
88
+ bs = 1
89
+ input = Variable(torch.rand(bs, *shape))
90
+ # print("inShape:",input.shape)
91
+ output_feat = self.conv1(input.squeeze())
92
+ # print ("output_feat",output_feat.shape)
93
+ n_size = output_feat.data.view(bs, -1).size(1)
94
+ # print ("n",n_size // 4)
95
+ return n_size // 4
96
+
97
+ def forward(self, clase, im):
98
+ ##Esto es lo que voy a hacer
99
+ # Cat entre la imagen y la profundidad
100
+ # print ("H",self.height,"W",self.width)
101
+ # imDep = imDep[:, None, :, :]
102
+ # im = im[:, None, :, :]
103
+ x = im
104
+
105
+ # Ref Conv de ese cat
106
+ x = self.conv1(x)
107
+ x = x.view(x.size(0), -1)
108
+ # print ("x:", x.shape)
109
+ x = self.fc1(x)
110
+ # print ("x:",x.shape)
111
+
112
+ # cat entre el ruido y la clase
113
+ y = clase
114
+ # print("Cat entre input y clase", y.shape) #podria separarlo, unir primero con clase y despues con ruido
115
+
116
+ # Red Lineal que une la Conv con el cat anterior
117
+ x = torch.cat([x, y], 1)
118
+ x = self.preDeconv(x)
119
+ # print ("antes de deconv", x.shape)
120
+ x = x.view(-1, self.output_dim, self.height, self.width)
121
+ # print("Despues View: ", x.shape)
122
+ # Red que saca produce la imagen final
123
+ # x = self.deconv(x)
124
+ # print("La salida de la generadora es: ",x.shape)
125
+
126
+ return x
127
+
128
+
129
+ class discriminator(nn.Module):
130
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
131
+ # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
132
+ def __init__(self, input_dim=1, output_dim=1, input_shape=2, class_num=10):
133
+ super(discriminator, self).__init__()
134
+ self.input_dim = input_dim * 2 # ya que le doy el origen
135
+ self.output_dim = output_dim
136
+ self.input_shape = list(input_shape)
137
+ self.class_num = class_num
138
+
139
+ self.input_shape[1] = self.input_dim # esto cambio despues por colores
140
+ # print(self.input_shape)
141
+
142
+ """""
143
+ in_channels (int): Number of channels in the input image
144
+ out_channels (int): Number of channels produced by the convolution
145
+ kernel_size (int or tuple): Size of the convolving kernel - lo que se agarra para la conv
146
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
147
+ padding (int or tuple, optional): Zero-padding added to both sides of the input.
148
+ """""
149
+
150
+ """
151
+ nn.Conv2d(self.input_dim, 64, 4, 2, 1), #para mi el 2 tendria que ser 1
152
+ nn.LeakyReLU(0.2),
153
+ nn.Conv2d(64, 32, 4, 2, 1),
154
+ nn.LeakyReLU(0.2),
155
+ nn.MaxPool2d(4, stride=2),
156
+ nn.Conv2d(32, 32, 4, 2, 1),
157
+ nn.LeakyReLU(0.2),
158
+ nn.MaxPool2d(4, stride=2),
159
+ nn.Conv2d(32, 20, 4, 2, 1),
160
+ nn.BatchNorm2d(20),
161
+ nn.LeakyReLU(0.2),
162
+ """
163
+
164
+ self.conv = nn.Sequential(
165
+
166
+ nn.Conv2d(self.input_dim, 4, 4, 2, 1), # para mi el 2 tendria que ser 1
167
+ nn.LeakyReLU(0.2),
168
+ nn.Conv2d(4, 8, 4, 2, 1),
169
+ nn.BatchNorm2d(8),
170
+ nn.LeakyReLU(0.2),
171
+ nn.Conv2d(8, 16, 4, 2, 1),
172
+ nn.BatchNorm2d(16),
173
+
174
+ )
175
+
176
+ self.n_size = self._get_conv_output(self.input_shape)
177
+
178
+ self.fc1 = nn.Sequential(
179
+ nn.Linear(self.n_size // 4, 1024),
180
+ nn.BatchNorm1d(1024),
181
+ nn.LeakyReLU(0.2),
182
+ nn.Linear(1024, 512),
183
+ nn.BatchNorm1d(512),
184
+ nn.LeakyReLU(0.2),
185
+ nn.Linear(512, 256),
186
+ nn.BatchNorm1d(256),
187
+ nn.LeakyReLU(0.2),
188
+ nn.Linear(256, 128),
189
+ nn.BatchNorm1d(128),
190
+ nn.LeakyReLU(0.2),
191
+ nn.Linear(128, 64),
192
+ nn.BatchNorm1d(64),
193
+ nn.LeakyReLU(0.2),
194
+ )
195
+ self.dc = nn.Sequential(
196
+ nn.Linear(64, self.output_dim),
197
+ nn.Sigmoid(),
198
+ )
199
+ self.cl = nn.Sequential(
200
+ nn.Linear(64, self.class_num),
201
+ nn.Sigmoid(),
202
+ )
203
+ utils.initialize_weights(self)
204
+
205
+ # generate input sample and forward to get shape
206
+
207
+ def _get_conv_output(self, shape):
208
+ bs = 1
209
+ input = Variable(torch.rand(bs, *shape))
210
+ output_feat = self.conv(input.squeeze())
211
+ n_size = output_feat.data.view(bs, -1).size(1)
212
+ return n_size
213
+
214
+ def forward(self, input, origen):
215
+ # esto va a cambiar cuando tenga color
216
+ # if (len(input.shape) <= 3):
217
+ # input = input[:, None, :, :]
218
+ # im = im[:, None, :, :]
219
+ # print("D in shape",input.shape)
220
+
221
+ # print(input.shape)
222
+ # print("this si X:", x)
223
+ # print("now shape", x.shape)
224
+ x = input
225
+ x = x.type(torch.FloatTensor)
226
+ x = x.to(device='cuda:0')
227
+
228
+ x = torch.cat((x, origen), 1)
229
+ x = self.conv(x)
230
+ x = x.view(x.size(0), -1)
231
+ x = self.fc1(x)
232
+ d = self.dc(x)
233
+ c = self.cl(x)
234
+
235
+ return d, c
236
+
237
+
238
+ #######################################################################################################################
239
+ class UnetConvBlock(nn.Module):
240
+ '''
241
+ Convolutional block of a U-Net:
242
+ Conv2d - Batch normalization - LeakyReLU
243
+ Conv2D - Batch normalization - LeakyReLU
244
+ Basic Dropout (optional)
245
+ '''
246
+
247
+ def __init__(self, in_size, out_size, dropout=0.0, stride=1, batch_norm = True):
248
+ '''
249
+ Constructor of the convolutional block
250
+ '''
251
+ super(UnetConvBlock, self).__init__()
252
+
253
+ # Convolutional layer with IN_SIZE --> OUT_SIZE
254
+ conv1 = nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, stride=1,
255
+ padding=1) # podria aplicar stride 2
256
+ # Activation unit
257
+ activ_unit1 = nn.LeakyReLU(0.2)
258
+ # Add batch normalization if necessary
259
+ if batch_norm:
260
+ self.conv1 = nn.Sequential(conv1, nn.BatchNorm2d(out_size), activ_unit1)
261
+ else:
262
+ self.conv1 = nn.Sequential(conv1, activ_unit1)
263
+
264
+ # Convolutional layer with OUT_SIZE --> OUT_SIZE
265
+ conv2 = nn.Conv2d(in_channels=out_size, out_channels=out_size, kernel_size=3, stride=stride,
266
+ padding=1) # podria aplicar stride 2
267
+ # Activation unit
268
+ activ_unit2 = nn.LeakyReLU(0.2)
269
+
270
+ # Add batch normalization
271
+ if batch_norm:
272
+ self.conv2 = nn.Sequential(conv2, nn.BatchNorm2d(out_size), activ_unit2)
273
+ else:
274
+ self.conv2 = nn.Sequential(conv2, activ_unit2)
275
+ # Dropout
276
+ if dropout > 0.0:
277
+ self.drop = nn.Dropout(dropout)
278
+ else:
279
+ self.drop = None
280
+
281
+ def forward(self, inputs):
282
+ '''
283
+ Do a forward pass
284
+ '''
285
+ outputs = self.conv1(inputs)
286
+ outputs = self.conv2(outputs)
287
+ if not (self.drop is None):
288
+ outputs = self.drop(outputs)
289
+ return outputs
290
+
291
+
292
+ class UnetDeSingleConvBlock(nn.Module):
293
+ '''
294
+ DeConvolutional block of a U-Net:
295
+ Conv2d - Batch normalization - LeakyReLU
296
+ Basic Dropout (optional)
297
+ '''
298
+
299
+ def __init__(self, in_size, out_size, dropout=0.0, stride=1, padding=1, batch_norm = True ):
300
+ '''
301
+ Constructor of the convolutional block
302
+ '''
303
+ super(UnetDeSingleConvBlock, self).__init__()
304
+
305
+ # Convolutional layer with IN_SIZE --> OUT_SIZE
306
+ conv1 = nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, stride=stride, padding=1)
307
+ # Activation unit
308
+ activ_unit1 = nn.LeakyReLU(0.2)
309
+ # Add batch normalization if necessary
310
+ if batch_norm:
311
+ self.conv1 = nn.Sequential(conv1, nn.BatchNorm2d(out_size), activ_unit1)
312
+ else:
313
+ self.conv1 = nn.Sequential(conv1, activ_unit1)
314
+
315
+ # Dropout
316
+ if dropout > 0.0:
317
+ self.drop = nn.Dropout(dropout)
318
+ else:
319
+ self.drop = None
320
+
321
+ def forward(self, inputs):
322
+ '''
323
+ Do a forward pass
324
+ '''
325
+ outputs = self.conv1(inputs)
326
+ if not (self.drop is None):
327
+ outputs = self.drop(outputs)
328
+ return outputs
329
+
330
+
331
+ class UnetDeconvBlock(nn.Module):
332
+ '''
333
+ DeConvolutional block of a U-Net:
334
+ UnetDeSingleConvBlock (skip_connection)
335
+ Cat last_layer + skip_connection
336
+ UnetDeSingleConvBlock ( Cat )
337
+ Basic Dropout (optional)
338
+ '''
339
+
340
+ def __init__(self, in_size_layer, in_size_skip_con, out_size, dropout=0.0):
341
+ '''
342
+ Constructor of the convolutional block
343
+ '''
344
+ super(UnetDeconvBlock, self).__init__()
345
+
346
+ self.conv1 = UnetDeSingleConvBlock(in_size_skip_con, in_size_skip_con, dropout)
347
+ self.conv2 = UnetDeSingleConvBlock(in_size_layer + in_size_skip_con, out_size, dropout)
348
+
349
+ # Dropout
350
+ if dropout > 0.0:
351
+ self.drop = nn.Dropout(dropout)
352
+ else:
353
+ self.drop = None
354
+
355
+ def forward(self, inputs_layer, inputs_skip):
356
+ '''
357
+ Do a forward pass
358
+ '''
359
+
360
+ outputs = self.conv1(inputs_skip)
361
+
362
+ #outputs = changeDim(outputs, inputs_layer)
363
+
364
+ outputs = torch.cat((inputs_layer, outputs), 1)
365
+ outputs = self.conv2(outputs)
366
+
367
+ return outputs
368
+
369
+
370
+ class UpBlock(nn.Module):
371
+ """Upscaling then double conv"""
372
+
373
+ def __init__(self, in_size_layer, in_size_skip_con, out_size, bilinear=True):
374
+ super(UpBlock, self).__init__()
375
+
376
+ # if bilinear, use the normal convolutions to reduce the number of channels
377
+ if bilinear:
378
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
379
+ else:
380
+ self.up = nn.ConvTranspose2d(in_size_layer // 2, in_size_layer // 2, kernel_size=2, stride=2)
381
+
382
+ self.conv = UnetDeconvBlock(in_size_layer, in_size_skip_con, out_size)
383
+
384
+ def forward(self, inputs_layer, inputs_skip):
385
+
386
+ inputs_layer = self.up(inputs_layer)
387
+
388
+ # input is CHW
389
+ #inputs_layer = changeDim(inputs_layer, inputs_skip)
390
+
391
+ return self.conv(inputs_layer, inputs_skip)
392
+
393
+
394
+ class lastBlock(nn.Module):
395
+ '''
396
+ DeConvolutional block of a U-Net:
397
+ Conv2d - Batch normalization - LeakyReLU
398
+ Basic Dropout (optional)
399
+ '''
400
+
401
+ def __init__(self, in_size, out_size, dropout=0.0):
402
+ '''
403
+ Constructor of the convolutional block
404
+ '''
405
+ super(lastBlock, self).__init__()
406
+
407
+ # Convolutional layer with IN_SIZE --> OUT_SIZE
408
+ conv1 = nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, stride=1, padding=1)
409
+ # Activation unit
410
+ activ_unit1 = nn.Tanh()
411
+ # Add batch normalization if necessary
412
+ self.conv1 = nn.Sequential(conv1, nn.BatchNorm2d(out_size), activ_unit1)
413
+
414
+ # Dropout
415
+ if dropout > 0.0:
416
+ self.drop = nn.Dropout(dropout)
417
+ else:
418
+ self.drop = None
419
+
420
+ def forward(self, inputs):
421
+ '''
422
+ Do a forward pass
423
+ '''
424
+ outputs = self.conv1(inputs)
425
+ if not (self.drop is None):
426
+ outputs = self.drop(outputs)
427
+ return outputs
428
+
429
+
430
+ ################
431
+
432
+ class generator_UNet(nn.Module):
433
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
434
+ # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
435
+ def __init__(self, input_dim=4, output_dim=1, input_shape=3, class_num=2, expand_net=3):
436
+ super(generator_UNet, self).__init__()
437
+ self.input_dim = input_dim + 1 # por la clase
438
+ self.output_dim = output_dim
439
+ # print ("self.output_dim", self.output_dim)
440
+ self.class_num = class_num
441
+ self.input_shape = list(input_shape)
442
+
443
+ self.input_shape[1] = self.input_dim # esto cambio despues por colores
444
+
445
+ self.expandNet = expand_net # 5
446
+
447
+ # Downsampling
448
+ self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1)
449
+ # self.maxpool1 = nn.MaxPool2d(kernel_size=2)
450
+ self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2)
451
+ # self.maxpool2 = nn.MaxPool2d(kernel_size=2)
452
+ self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2)
453
+ # self.maxpool3 = nn.MaxPool2d(kernel_size=2)
454
+ # Middle ground
455
+ self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2)
456
+ # UpSampling
457
+ self.up1 = UpBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), pow(2, self.expandNet + 1),
458
+ bilinear=True)
459
+ self.up2 = UpBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 1), pow(2, self.expandNet),
460
+ bilinear=True)
461
+ self.up3 = UpBlock(pow(2, self.expandNet), pow(2, self.expandNet), 8, bilinear=True)
462
+ self.last = lastBlock(8, self.output_dim)
463
+
464
+ utils.initialize_weights(self)
465
+
466
+ def _get_conv_output(self, shape):
467
+ bs = 1
468
+ input = Variable(torch.rand(bs, *shape))
469
+ # print("inShape:",input.shape)
470
+ output_feat = self.conv1(input.squeeze()) ##CAMBIAR
471
+ # print ("output_feat",output_feat.shape)
472
+ n_size = output_feat.data.view(bs, -1).size(1)
473
+ # print ("n",n_size // 4)
474
+ return n_size // 4
475
+
476
+ def forward(self, clase, im):
477
+ x = im
478
+
479
+ ##PARA TENER LA CLASE DEL CORRIMIENTO
480
+ cl = ((clase == 1))
481
+ cl = cl[:, 1]
482
+ cl = cl.type(torch.FloatTensor)
483
+ max = (clase.size())[1] - 1
484
+ cl = cl / float(max)
485
+
486
+ ##crear imagen layer de corrimiento
487
+ tam = im.size()
488
+ layerClase = torch.ones([tam[0], tam[2], tam[3]], dtype=torch.float32, device="cuda:0")
489
+ for idx, item in enumerate(layerClase):
490
+ layerClase[idx] = item * cl[idx]
491
+ layerClase = layerClase.unsqueeze(0)
492
+ layerClase = layerClase.transpose(1, 0)
493
+
494
+ ##unir layer el rgb de la imagen
495
+ x = torch.cat((x, layerClase), 1)
496
+
497
+ x1 = self.conv1(x)
498
+ x2 = self.conv2(x1) # self.maxpool1(x1))
499
+ x3 = self.conv3(x2) # self.maxpool2(x2))
500
+ x4 = self.conv4(x3) # self.maxpool3(x3))
501
+ x = self.up1(x4, x3)
502
+ x = self.up2(x, x2)
503
+ x = self.up3(x, x1)
504
+ x = changeDim(x, im)
505
+ x = self.last(x)
506
+
507
+ return x
508
+
509
+
510
+ class discriminator_UNet(nn.Module):
511
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
512
+ # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
513
+ def __init__(self, input_dim=1, output_dim=1, input_shape=[2, 2], class_num=10, expand_net = 2):
514
+ super(discriminator_UNet, self).__init__()
515
+ self.input_dim = input_dim * 2 # ya que le doy el origen
516
+ self.output_dim = output_dim
517
+ self.input_shape = list(input_shape)
518
+ self.class_num = class_num
519
+
520
+ self.input_shape[1] = self.input_dim # esto cambio despues por colores
521
+
522
+ self.expandNet = expand_net # 4
523
+
524
+ # Downsampling
525
+ self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1, dropout=0.3)
526
+ self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2, dropout=0.5)
527
+ self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2, dropout=0.4)
528
+
529
+ # Middle ground
530
+ self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2,
531
+ dropout=0.3)
532
+
533
+ self.n_size = self._get_conv_output(self.input_shape)
534
+
535
+ self.fc1 = nn.Sequential(
536
+ nn.Linear(self.n_size // 4, 1024),
537
+ nn.BatchNorm1d(1024),
538
+ nn.LeakyReLU(0.2),
539
+ )
540
+
541
+ self.dc = nn.Sequential(
542
+ nn.Linear(1024, self.output_dim),
543
+ # nn.Sigmoid(),
544
+ )
545
+ self.cl = nn.Sequential(
546
+ nn.Linear(1024, self.class_num),
547
+ nn.Softmax(dim=1), # poner el que la suma da 1
548
+ )
549
+ utils.initialize_weights(self)
550
+
551
+ # generate input sample and forward to get shape
552
+
553
+ def _get_conv_output(self, shape):
554
+ bs = 1
555
+ input = Variable(torch.rand(bs, *shape))
556
+ x = input.squeeze()
557
+ x = self.conv1(x)
558
+ x = self.conv2(x)
559
+ x = self.conv3(x)
560
+ x = self.conv4(x)
561
+ n_size = x.data.view(bs, -1).size(1)
562
+ return n_size
563
+
564
+ def forward(self, input, origen):
565
+ # esto va a cambiar cuando tenga color
566
+ # if (len(input.shape) <= 3):
567
+ # input = input[:, None, :, :]
568
+ # im = im[:, None, :, :]
569
+ # print("D in shape",input.shape)
570
+
571
+ # print(input.shape)
572
+ # print("this si X:", x)
573
+ # print("now shape", x.shape)
574
+ x = input
575
+ x = x.type(torch.FloatTensor)
576
+ x = x.to(device='cuda:0')
577
+
578
+ x = torch.cat((x, origen), 1)
579
+ x = self.conv1(x)
580
+ x = self.conv2(x)
581
+ x = self.conv3(x)
582
+ x = self.conv4(x)
583
+ x = x.view(x.size(0), -1)
584
+ x = self.fc1(x)
585
+ d = self.dc(x)
586
+ c = self.cl(x)
587
+
588
+ return d, c
589
+
590
+
591
+ def changeDim(x, y):
592
+ ''' Change dim-image from x to y '''
593
+
594
+ diffY = torch.tensor([y.size()[2] - x.size()[2]])
595
+ diffX = torch.tensor([y.size()[3] - x.size()[3]])
596
+ x = F.pad(x, [diffX // 2, diffX - diffX // 2,
597
+ diffY // 2, diffY - diffY // 2])
598
+ return x
599
+
600
+
601
+ ######################################## ACGAN ###########################################################
602
+
603
+ class depth_generator(nn.Module):
604
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
605
+ # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
606
+ def __init__(self, input_dim=4, output_dim=1, input_shape=3, class_num=10, zdim=1, height=10, width=10):
607
+ super(depth_generator, self).__init__()
608
+ self.input_dim = input_dim
609
+ self.output_dim = output_dim
610
+ self.class_num = class_num
611
+ # print ("self.output_dim", self.output_dim)
612
+ self.input_shape = list(input_shape)
613
+ self.zdim = zdim
614
+ self.toPreDecov = 1024
615
+ self.toDecov = 1
616
+ self.height = height
617
+ self.width = width
618
+
619
+ self.input_shape[1] = self.input_dim # esto cambio despues por colores
620
+
621
+ # print("input shpe gen",self.input_shape)
622
+
623
+ self.conv1 = nn.Sequential(
624
+ ##############RED SUPER CHICA PARA QUE ANDE TO DO PORQUE RAM Y MEMORY
625
+ nn.Conv2d(self.input_dim, 2, 4, 2, 1), # para mi el 2 tendria que ser 1
626
+ nn.Conv2d(2, 1, 4, 2, 1),
627
+ nn.BatchNorm2d(1),
628
+ nn.LeakyReLU(0.2),
629
+ )
630
+
631
+ self.n_size = self._get_conv_output(self.input_shape)
632
+ # print ("self.n_size",self.n_size)
633
+ self.cubic = (self.n_size // 8192)
634
+ # print("self.cubic: ",self.cubic)
635
+
636
+ self.fc1 = nn.Sequential(
637
+ nn.Linear(self.n_size, self.n_size),
638
+ nn.BatchNorm1d(self.n_size),
639
+ nn.LeakyReLU(0.2),
640
+ )
641
+
642
+ self.preDeconv = nn.Sequential(
643
+ ##############RED SUPER CHICA PARA QUE ANDE TO DO PORQUE RAM Y MEMORY
644
+
645
+ # nn.Linear(self.toPreDecov + self.zdim + self.class_num, 1024),
646
+ # nn.BatchNorm1d(1024),
647
+ # nn.LeakyReLU(0.2),
648
+ # nn.Linear(1024, self.toDecov * self.height // 64 * self.width// 64),
649
+ # nn.BatchNorm1d(self.toDecov * self.height // 64 * self.width// 64),
650
+ # nn.LeakyReLU(0.2),
651
+ # nn.Linear(self.toDecov * self.height // 64 * self.width // 64 , self.toDecov * self.height // 32 * self.width // 32),
652
+ # nn.BatchNorm1d(self.toDecov * self.height // 32 * self.width // 32),
653
+ # nn.LeakyReLU(0.2),
654
+ # nn.Linear(self.toDecov * self.height // 32 * self.width // 32,
655
+ # 1 * self.height * self.width),
656
+ # nn.BatchNorm1d(1 * self.height * self.width),
657
+ # nn.LeakyReLU(0.2),
658
+
659
+ nn.Linear(self.n_size + self.zdim + self.class_num, 50),
660
+ nn.BatchNorm1d(50),
661
+ nn.LeakyReLU(0.2),
662
+ nn.Linear(50, 200),
663
+ nn.BatchNorm1d(200),
664
+ nn.LeakyReLU(0.2),
665
+ nn.Linear(200, self.output_dim * self.height * self.width),
666
+ nn.BatchNorm1d(self.output_dim * self.height * self.width),
667
+ nn.Tanh(), # Cambio porque hago como que termino ahi
668
+
669
+ )
670
+
671
+ """
672
+ self.deconv = nn.Sequential(
673
+ nn.ConvTranspose2d(self.toDecov, 2, 4, 2, 0),
674
+ nn.BatchNorm2d(2),
675
+ nn.ReLU(),
676
+ nn.ConvTranspose2d(2, self.output_dim, 4, 2, 1),
677
+ nn.Tanh(), #esta recomendado que la ultima sea TanH de la Generadora da valores entre -1 y 1
678
+ )
679
+ """
680
+ utils.initialize_weights(self)
681
+
682
+ def _get_conv_output(self, shape):
683
+ bs = 1
684
+ input = Variable(torch.rand(bs, *shape))
685
+ # print("inShape:",input.shape)
686
+ output_feat = self.conv1(input.squeeze())
687
+ # print ("output_feat",output_feat.shape)
688
+ n_size = output_feat.data.view(bs, -1).size(1)
689
+ # print ("n",n_size // 4)
690
+ return n_size // 4
691
+
692
+ def forward(self, input, clase, im, imDep):
693
+ ##Esto es lo que voy a hacer
694
+ # Cat entre la imagen y la profundidad
695
+ print ("H", self.height, "W", self.width)
696
+ # imDep = imDep[:, None, :, :]
697
+ # im = im[:, None, :, :]
698
+ print ("imdep", imDep.shape)
699
+ print ("im", im.shape)
700
+ x = torch.cat([im, imDep], 1)
701
+
702
+ # Ref Conv de ese cat
703
+ x = self.conv1(x)
704
+ x = x.view(x.size(0), -1)
705
+ print ("x:", x.shape)
706
+ x = self.fc1(x)
707
+ # print ("x:",x.shape)
708
+
709
+ # cat entre el ruido y la clase
710
+ y = torch.cat([input, clase], 1)
711
+ print("Cat entre input y clase", y.shape) # podria separarlo, unir primero con clase y despues con ruido
712
+
713
+ # Red Lineal que une la Conv con el cat anterior
714
+ x = torch.cat([x, y], 1)
715
+ x = self.preDeconv(x)
716
+ print ("antes de deconv", x.shape)
717
+ x = x.view(-1, self.output_dim, self.height, self.width)
718
+ print("Despues View: ", x.shape)
719
+ # Red que saca produce la imagen final
720
+ # x = self.deconv(x)
721
+ print("La salida de la generadora es: ", x.shape)
722
+
723
+ return x
724
+
725
+
726
+ class depth_discriminator(nn.Module):
727
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
728
+ # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
729
+ def __init__(self, input_dim=1, output_dim=1, input_shape=2, class_num=10):
730
+ super(depth_discriminator, self).__init__()
731
+ self.input_dim = input_dim
732
+ self.output_dim = output_dim
733
+ self.input_shape = list(input_shape)
734
+ self.class_num = class_num
735
+
736
+ self.input_shape[1] = self.input_dim # esto cambio despues por colores
737
+ print(self.input_shape)
738
+
739
+ """""
740
+ in_channels (int): Number of channels in the input image
741
+ out_channels (int): Number of channels produced by the convolution
742
+ kernel_size (int or tuple): Size of the convolving kernel - lo que se agarra para la conv
743
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
744
+ padding (int or tuple, optional): Zero-padding added to both sides of the input.
745
+ """""
746
+
747
+ """
748
+ nn.Conv2d(self.input_dim, 64, 4, 2, 1), #para mi el 2 tendria que ser 1
749
+ nn.LeakyReLU(0.2),
750
+ nn.Conv2d(64, 32, 4, 2, 1),
751
+ nn.LeakyReLU(0.2),
752
+ nn.MaxPool2d(4, stride=2),
753
+ nn.Conv2d(32, 32, 4, 2, 1),
754
+ nn.LeakyReLU(0.2),
755
+ nn.MaxPool2d(4, stride=2),
756
+ nn.Conv2d(32, 20, 4, 2, 1),
757
+ nn.BatchNorm2d(20),
758
+ nn.LeakyReLU(0.2),
759
+ """
760
+
761
+ self.conv = nn.Sequential(
762
+
763
+ nn.Conv2d(self.input_dim, 4, 4, 2, 1), # para mi el 2 tendria que ser 1
764
+ nn.LeakyReLU(0.2),
765
+ nn.Conv2d(4, 8, 4, 2, 1),
766
+ nn.BatchNorm2d(8),
767
+ nn.LeakyReLU(0.2),
768
+ nn.Conv2d(8, 16, 4, 2, 1),
769
+ nn.BatchNorm2d(16),
770
+
771
+ )
772
+
773
+ self.n_size = self._get_conv_output(self.input_shape)
774
+
775
+ self.fc1 = nn.Sequential(
776
+ nn.Linear(self.n_size // 4, 1024),
777
+ nn.BatchNorm1d(1024),
778
+ nn.LeakyReLU(0.2),
779
+ nn.Linear(1024, 512),
780
+ nn.BatchNorm1d(512),
781
+ nn.LeakyReLU(0.2),
782
+ nn.Linear(512, 256),
783
+ nn.BatchNorm1d(256),
784
+ nn.LeakyReLU(0.2),
785
+ nn.Linear(256, 128),
786
+ nn.BatchNorm1d(128),
787
+ nn.LeakyReLU(0.2),
788
+ nn.Linear(128, 64),
789
+ nn.BatchNorm1d(64),
790
+ nn.LeakyReLU(0.2),
791
+ )
792
+ self.dc = nn.Sequential(
793
+ nn.Linear(64, self.output_dim),
794
+ nn.Sigmoid(),
795
+ )
796
+ self.cl = nn.Sequential(
797
+ nn.Linear(64, self.class_num),
798
+ nn.Sigmoid(),
799
+ )
800
+ utils.initialize_weights(self)
801
+
802
+ # generate input sample and forward to get shape
803
+
804
+ def _get_conv_output(self, shape):
805
+ bs = 1
806
+ input = Variable(torch.rand(bs, *shape))
807
+ output_feat = self.conv(input.squeeze())
808
+ n_size = output_feat.data.view(bs, -1).size(1)
809
+ return n_size
810
+
811
+ def forward(self, input, im):
812
+ # esto va a cambiar cuando tenga color
813
+ # if (len(input.shape) <= 3):
814
+ # input = input[:, None, :, :]
815
+ # im = im[:, None, :, :]
816
+ print("D in shape", input.shape)
817
+ print("D im shape", im.shape)
818
+ x = torch.cat([input, im], 1)
819
+ print(input.shape)
820
+ # print("this si X:", x)
821
+ # print("now shape", x.shape)
822
+ x = x.type(torch.FloatTensor)
823
+ x = x.to(device='cuda:0')
824
+ x = self.conv(x)
825
+ x = x.view(x.size(0), -1)
826
+ x = self.fc1(x)
827
+ d = self.dc(x)
828
+ c = self.cl(x)
829
+
830
+ return d, c
831
+
832
+
833
+ class depth_generator_UNet(nn.Module):
834
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
835
+ # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
836
+ def __init__(self, input_dim=4, output_dim=1, class_num=10, expand_net=3, depth=True):
837
+ super(depth_generator_UNet, self).__init__()
838
+
839
+ if depth:
840
+ self.input_dim = input_dim + 1
841
+ else:
842
+ self.input_dim = input_dim
843
+ self.output_dim = output_dim
844
+ self.class_num = class_num
845
+ # print ("self.output_dim", self.output_dim)
846
+
847
+ self.expandNet = expand_net # 5
848
+ self.depth = depth
849
+
850
+ # Downsampling
851
+ self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet))
852
+ # self.maxpool1 = nn.MaxPool2d(kernel_size=2)
853
+ self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2)
854
+ # self.maxpool2 = nn.MaxPool2d(kernel_size=2)
855
+ self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2)
856
+ # self.maxpool3 = nn.MaxPool2d(kernel_size=2)
857
+ # Middle ground
858
+ self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2)
859
+ # UpSampling
860
+ self.up1 = UpBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), pow(2, self.expandNet + 1))
861
+ self.up2 = UpBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 1), pow(2, self.expandNet))
862
+ self.up3 = UpBlock(pow(2, self.expandNet), pow(2, self.expandNet), 8)
863
+ self.last = lastBlock(8, self.output_dim)
864
+
865
+ if depth:
866
+ self.upDep1 = UpBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), pow(2, self.expandNet + 1))
867
+ self.upDep2 = UpBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 1), pow(2, self.expandNet))
868
+ self.upDep3 = UpBlock(pow(2, self.expandNet), pow(2, self.expandNet), 8)
869
+ self.lastDep = lastBlock(8, 1)
870
+
871
+
872
+
873
+ utils.initialize_weights(self)
874
+
875
+
876
+ def forward(self, clase, im, imDep):
877
+ ##Hago algo con el z?
878
+ #print (im.shape)
879
+ #print (z.shape)
880
+ #print (z)
881
+ #imz = torch.repeat_interleave(z, repeats=torch.tensor([2, 2]), dim=1)
882
+ #print (imz.shape)
883
+ #print (imz)
884
+ #sdadsadas
885
+ if self.depth:
886
+ x = torch.cat([im, imDep], 1)
887
+ x = torch.cat((x, clase), 1)
888
+ else:
889
+ x = torch.cat((im, clase), 1)
890
+ ##unir layer el rgb de la imagen
891
+
892
+
893
+ x1 = self.conv1(x)
894
+ x2 = self.conv2(x1) # self.maxpool1(x1))
895
+ x3 = self.conv3(x2) # self.maxpool2(x2))
896
+ x4 = self.conv4(x3) # self.maxpool3(x3))
897
+
898
+ x = self.up1(x4, x3)
899
+ x = self.up2(x, x2)
900
+ x = self.up3(x, x1)
901
+ #x = changeDim(x, im)
902
+ x = self.last(x)
903
+
904
+ #x = x[:, :3, :, :] #cambio teorico
905
+
906
+ if self.depth:
907
+ dep = self.upDep1(x4, x3)
908
+ dep = self.upDep2(dep, x2)
909
+ dep = self.upDep3(dep, x1)
910
+ # x = changeDim(x, im)
911
+ dep = self.lastDep(dep)
912
+ return x, dep
913
+ else:
914
+ return x,imDep
915
+
916
+
917
+ class depth_discriminator_UNet(nn.Module):
918
+ def __init__(self, input_dim=1, output_dim=1, input_shape=[8, 7, 128, 128], class_num=2, expand_net=2):
919
+ super(depth_discriminator_UNet, self).__init__()
920
+ self.input_dim = input_dim * 2 + 1
921
+
922
+ #discriminator_UNet.__init__(self, input_dim=self.input_dim, output_dim=output_dim, input_shape=input_shape,
923
+ # class_num=class_num, expand_net = expand_net)
924
+
925
+ self.output_dim = output_dim
926
+ self.input_shape = list(input_shape)
927
+ self.class_num = class_num
928
+ self.expandNet = expand_net
929
+
930
+ self.input_dim = input_dim * 2 + 1 # ya que le doy el origen + mapa de profundidad
931
+ self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1, dropout=0.3)
932
+ self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2, dropout=0.2)
933
+ self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2, dropout=0.2)
934
+ self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2,
935
+ dropout=0.3)
936
+
937
+ self.input_shape[1] = self.input_dim
938
+ self.n_size = self._get_conv_output(self.input_shape)
939
+
940
+ self.fc1 = nn.Sequential(
941
+ nn.Linear(self.n_size, 1024),
942
+ )
943
+
944
+ self.BnLr = nn.Sequential(
945
+ nn.BatchNorm1d(1024),
946
+ nn.LeakyReLU(0.2),
947
+ )
948
+
949
+ self.dc = nn.Sequential(
950
+ nn.Linear(1024, self.output_dim),
951
+ #nn.Sigmoid(),
952
+ )
953
+ self.cl = nn.Sequential(
954
+ nn.Linear(1024, self.class_num),
955
+ # nn.Softmax(dim=1), # poner el que la suma da 1
956
+ )
957
+
958
+ utils.initialize_weights(self)
959
+
960
+ def _get_conv_output(self, shape):
961
+ bs = 1
962
+ input = Variable(torch.rand(bs, *shape))
963
+ x = input.squeeze()
964
+ x = self.conv1(x)
965
+ x = self.conv2(x)
966
+ x = self.conv3(x)
967
+ x = self.conv4(x)
968
+ x = x.view(x.size(0), -1)
969
+ return x.shape[1]
970
+
971
+ def forward(self, input, origen, dep):
972
+ # esto va a cambiar cuando tenga color
973
+ # if (len(input.shape) <= 3):
974
+ # input = input[:, None, :, :]
975
+ # im = im[:, None, :, :]
976
+ # print("D in shape",input.shape)
977
+
978
+ # print(input.shape)
979
+ # print("this si X:", x)
980
+ # print("now shape", x.shape)
981
+ x = input
982
+
983
+ x = torch.cat((x, origen), 1)
984
+ x = torch.cat((x, dep), 1)
985
+ x = self.conv1(x)
986
+ x = self.conv2(x)
987
+ x = self.conv3(x)
988
+ x = self.conv4(x)
989
+ x = x.view(x.size(0), -1)
990
+ features = self.fc1(x)
991
+ x = self.BnLr(features)
992
+ d = self.dc(x)
993
+ c = self.cl(x)
994
+
995
+ return d, c, features
996
+
997
+ class depth_discriminator_noclass_UNet(nn.Module):
998
+ def __init__(self, input_dim=1, output_dim=1, input_shape=[8, 7, 128, 128], class_num=2, expand_net=2, depth=True, wgan = False):
999
+ super(depth_discriminator_noclass_UNet, self).__init__()
1000
+
1001
+ #discriminator_UNet.__init__(self, input_dim=self.input_dim, output_dim=output_dim, input_shape=input_shape,
1002
+ # class_num=class_num, expand_net = expand_net)
1003
+
1004
+ self.output_dim = output_dim
1005
+ self.input_shape = list(input_shape)
1006
+ self.class_num = class_num
1007
+ self.expandNet = expand_net
1008
+ self.depth = depth
1009
+ self.wgan = wgan
1010
+
1011
+ if depth:
1012
+ self.input_dim = input_dim * 2 + 2 # ya que le doy el origen + Dep + class
1013
+ else:
1014
+ self.input_dim = input_dim * 2 + 1 # ya que le doy el origen + class
1015
+ self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1, dropout=0.0, batch_norm = False )
1016
+ self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2, dropout=0.0, batch_norm = False )
1017
+ self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2, dropout=0.0, batch_norm = False )
1018
+ self.conv4 = UnetConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 3), stride=2, dropout=0.0, batch_norm = False )
1019
+ self.conv5 = UnetDeSingleConvBlock(pow(2, self.expandNet + 3), pow(2, self.expandNet + 2), stride=1, dropout=0.0, batch_norm = False )
1020
+
1021
+ self.lastconvs = []
1022
+ imagesize = self.input_shape[2] / 8
1023
+ while imagesize > 4:
1024
+ self.lastconvs.append(UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2, dropout=0.0, batch_norm = False ))
1025
+ imagesize = imagesize/2
1026
+ else:
1027
+ self.lastconvs.append(UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 1), stride=1, dropout=0.0, batch_norm = False ))
1028
+
1029
+ self.input_shape[1] = self.input_dim
1030
+ self.n_size = self._get_conv_output(self.input_shape)
1031
+
1032
+ for layer in self.lastconvs:
1033
+ layer = layer.cuda()
1034
+
1035
+ self.fc1 = nn.Sequential(
1036
+ nn.Linear(self.n_size, 256),
1037
+ )
1038
+
1039
+ self.BnLr = nn.Sequential(
1040
+ nn.BatchNorm1d(256),
1041
+ nn.LeakyReLU(0.2),
1042
+ )
1043
+
1044
+ self.dc = nn.Sequential(
1045
+ nn.Linear(256, self.output_dim),
1046
+ #nn.Sigmoid(),
1047
+ )
1048
+
1049
+ utils.initialize_weights(self)
1050
+
1051
+ def _get_conv_output(self, shape):
1052
+ bs = 1
1053
+ input = Variable(torch.rand(bs, *shape))
1054
+ x = input.squeeze()
1055
+ x = self.conv1(x)
1056
+ x = self.conv2(x)
1057
+ x = self.conv3(x)
1058
+ x = self.conv4(x)
1059
+ x = self.conv5(x)
1060
+ for layer in self.lastconvs:
1061
+ x = layer(x)
1062
+ x = x.view(x.size(0), -1)
1063
+ return x.shape[1]
1064
+
1065
+ def forward(self, input, origen, dep, clase):
1066
+ # esto va a cambiar cuando tenga color
1067
+ # if (len(input.shape) <= 3):
1068
+ # input = input[:, None, :, :]
1069
+ # im = im[:, None, :, :]
1070
+ # print("D in shape",input.shape)
1071
+
1072
+ # print(input.shape)
1073
+ # print("this si X:", x)
1074
+ # print("now shape", x.shape)
1075
+ x = input
1076
+ ##unir layer el rgb de la imagen
1077
+ x = torch.cat((x, clase), 1)
1078
+
1079
+ x = torch.cat((x, origen), 1)
1080
+ if self.depth:
1081
+ x = torch.cat((x, dep), 1)
1082
+ x = self.conv1(x)
1083
+ x = self.conv2(x)
1084
+ x = self.conv3(x)
1085
+ x = self.conv4(x)
1086
+ x = self.conv5(x)
1087
+ for layer in self.lastconvs:
1088
+ x = layer(x)
1089
+ feature_vector = x.view(x.size(0), -1)
1090
+ x = self.fc1(feature_vector)
1091
+ x = self.BnLr(x)
1092
+ d = self.dc(x)
1093
+
1094
+ return d, feature_vector
config.ini ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [validation]
3
+ total = 50
4
+ 0 = 2822
5
+ 1 = 3038
6
+ 2 = 3760
7
+ 3 = 3512
8
+ 4 = 3349
9
+ 5 = 2812
10
+ 6 = 3383
11
+ 7 = 3606
12
+ 8 = 3612
13
+ 9 = 3666
14
+ 10 = 2933
15
+ 11 = 3613
16
+ 12 = 2881
17
+ 13 = 3609
18
+ 14 = 3066
19
+ 15 = 3654
20
+ 16 = 2821
21
+ 17 = 2784
22
+ 18 = 3186
23
+ 19 = 3138
24
+ 20 = 3187
25
+ 21 = 3482
26
+ 22 = 2701
27
+ 23 = 3320
28
+ 24 = 3716
29
+ 25 = 3501
30
+ 26 = 3441
31
+ 27 = 3768
32
+ 28 = 3158
33
+ 29 = 2841
34
+ 30 = 3466
35
+ 31 = 3547
36
+ 32 = 2920
37
+ 33 = 3439
38
+ 34 = 2669
39
+ 35 = 3183
40
+ 36 = 2760
41
+ 37 = 3605
42
+ 38 = 2941
43
+ 39 = 3729
44
+ 40 = 2958
45
+ 41 = 3745
46
+ 42 = 3417
47
+ 43 = 3218
48
+ 44 = 3093
49
+ 45 = 3699
50
+ 46 = 3255
51
+ 47 = 3616
52
+ 48 = 3623
53
+ 49 = 3590
54
+ 50 = 3496
55
+ [test]
56
+ total = 1
57
+ [train]
58
+ total = 200
59
+ 0 = 3192
60
+ 1 = 3086
61
+ 2 = 3205
62
+ 3 = 3061
63
+ 4 = 2688
64
+ 5 = 3347
65
+ 6 = 2850
66
+ 7 = 3508
67
+ 8 = 3285
68
+ 9 = 3487
69
+ 10 = 3433
70
+ 11 = 2687
71
+ 12 = 2860
72
+ 13 = 3353
73
+ 14 = 3526
74
+ 15 = 3112
75
+ 16 = 3123
76
+ 17 = 3109
77
+ 18 = 2825
78
+ 19 = 3114
79
+ 20 = 3413
80
+ 21 = 2876
81
+ 22 = 2910
82
+ 23 = 3339
83
+ 24 = 3011
84
+ 25 = 2753
85
+ 26 = 3551
86
+ 27 = 2942
87
+ 28 = 2998
88
+ 29 = 3370
89
+ 30 = 3560
90
+ 31 = 3446
91
+ 32 = 3017
92
+ 33 = 3703
93
+ 34 = 3327
94
+ 35 = 3498
95
+ 36 = 2884
96
+ 37 = 2934
97
+ 38 = 2671
98
+ 39 = 2871
99
+ 40 = 2727
100
+ 41 = 3144
101
+ 42 = 3393
102
+ 43 = 3693
103
+ 44 = 2761
104
+ 45 = 2895
105
+ 46 = 3537
106
+ 47 = 3735
107
+ 48 = 2755
108
+ 49 = 2710
109
+ 50 = 3379
110
+ 51 = 3475
111
+ 52 = 2750
112
+ 53 = 3390
113
+ 54 = 3189
114
+ 55 = 2817
115
+ 56 = 3765
116
+ 57 = 3653
117
+ 58 = 2776
118
+ 59 = 3568
119
+ 60 = 2782
120
+ 61 = 3079
121
+ 62 = 3283
122
+ 63 = 2999
123
+ 64 = 3586
124
+ 65 = 2740
125
+ 66 = 3651
126
+ 67 = 3549
127
+ 68 = 3106
128
+ 69 = 3160
129
+ 70 = 3092
130
+ 71 = 2940
131
+ 72 = 3603
132
+ 73 = 3733
133
+ 74 = 3371
134
+ 75 = 3290
135
+ 76 = 3091
136
+ 77 = 2978
137
+ 78 = 3730
138
+ 79 = 2961
139
+ 80 = 2748
140
+ 81 = 3094
141
+ 82 = 2914
142
+ 83 = 3490
143
+ 84 = 3120
144
+ 85 = 3759
145
+ 86 = 2715
146
+ 87 = 3287
147
+ 88 = 3723
148
+ 89 = 3776
149
+ 90 = 3305
150
+ 91 = 2830
151
+ 92 = 3313
152
+ 93 = 3368
153
+ 94 = 2944
154
+ 95 = 2925
155
+ 96 = 3780
156
+ 97 = 2680
157
+ 98 = 3622
158
+ 99 = 3065
159
+ 100 = 2905
160
+ 101 = 3346
161
+ 102 = 3397
162
+ 103 = 2875
163
+ 104 = 3262
164
+ 105 = 2783
165
+ 106 = 3485
166
+ 107 = 3234
167
+ 108 = 3330
168
+ 109 = 3099
169
+ 110 = 3625
170
+ 111 = 3540
171
+ 112 = 3523
172
+ 113 = 3279
173
+ 114 = 3280
174
+ 115 = 3428
175
+ 116 = 3372
176
+ 117 = 3497
177
+ 118 = 3626
178
+ 119 = 2733
179
+ 120 = 3578
180
+ 121 = 3593
181
+ 122 = 3700
182
+ 123 = 3167
183
+ 124 = 2848
184
+ 125 = 2775
185
+ 126 = 3726
186
+ 127 = 3425
187
+ 128 = 3751
188
+ 129 = 3520
189
+ 130 = 3458
190
+ 131 = 3164
191
+ 132 = 3381
192
+ 133 = 2873
193
+ 134 = 2890
194
+ 135 = 3548
195
+ 136 = 3728
196
+ 137 = 2745
197
+ 138 = 3041
198
+ 139 = 3663
199
+ 140 = 3098
200
+ 141 = 3631
201
+ 142 = 3127
202
+ 143 = 3704
203
+ 144 = 3658
204
+ 145 = 3629
205
+ 146 = 3467
206
+ 147 = 2676
207
+ 148 = 3178
208
+ 149 = 3275
209
+ 150 = 3324
210
+ 151 = 2756
211
+ 152 = 3200
212
+ 153 = 3034
213
+ 154 = 3749
214
+ 155 = 3558
215
+ 156 = 3173
216
+ 157 = 3792
217
+ 158 = 2681
218
+ 159 = 3367
219
+ 160 = 3579
220
+ 161 = 3155
221
+ 162 = 3128
222
+ 163 = 2816
223
+ 164 = 2973
224
+ 165 = 3246
225
+ 166 = 3129
226
+ 167 = 3762
227
+ 168 = 2939
228
+ 169 = 2929
229
+ 170 = 3711
230
+ 171 = 3608
231
+ 172 = 2679
232
+ 173 = 3214
233
+ 174 = 3687
234
+ 175 = 3291
235
+ 176 = 2700
236
+ 177 = 3131
237
+ 178 = 3597
238
+ 179 = 3519
239
+ 180 = 3481
240
+ 181 = 2725
241
+ 182 = 3761
242
+ 183 = 3610
243
+ 184 = 3073
244
+ 185 = 3135
245
+ 186 = 2891
246
+ 187 = 3769
247
+ 188 = 3557
248
+ 189 = 2967
249
+ 190 = 2697
250
+ 191 = 2861
251
+ 192 = 2956
252
+ 193 = 3052
253
+ 194 = 2995
254
+ 195 = 3054
255
+ 196 = 3588
256
+ 197 = 2960
257
+ 198 = 2952
258
+ 199 = 2766
259
+ 200 = 2917
dataloader.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from torchvision import datasets, transforms
3
+ from torch.utils.data import Dataset
4
+ import torch
5
+ from configparser import ConfigParser
6
+ import matplotlib.pyplot as plt
7
+ import os
8
+ import torch as th
9
+ from PIL import Image
10
+ import numpy as np
11
+ import random
12
+ from PIL import ImageMath
13
+ import random
14
+
15
+ def dataloader(dataset, input_size, batch_size,dim,split='train', trans=False):
16
+ #transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(),
17
+ # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
18
+ if dataset == 'mnist':
19
+ data_loader = DataLoader(
20
+ datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
21
+ batch_size=batch_size, shuffle=True)
22
+ elif dataset == 'fashion-mnist':
23
+ data_loader = DataLoader(
24
+ datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
25
+ batch_size=batch_size, shuffle=True)
26
+ elif dataset == 'cifar10':
27
+ data_loader = DataLoader(
28
+ datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
29
+ batch_size=batch_size, shuffle=True)
30
+ elif dataset == 'svhn':
31
+ data_loader = DataLoader(
32
+ datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
33
+ batch_size=batch_size, shuffle=True)
34
+ elif dataset == 'stl10':
35
+ data_loader = DataLoader(
36
+ datasets.STL10('data/stl10', split=split, download=True, transform=transform),
37
+ batch_size=batch_size, shuffle=True)
38
+ elif dataset == 'lsun-bed':
39
+ data_loader = DataLoader(
40
+ datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
41
+ batch_size=batch_size, shuffle=True)
42
+ elif dataset == '4cam':
43
+ if split == 'score':
44
+ cams = ScoreDataset(root_dir=os.getcwd() + '/Images/Score-Test', dim=dim, name=split, cant_images=300) #hardcode is bad but quick
45
+ return DataLoader(cams, batch_size=batch_size, shuffle=False, num_workers=0)
46
+ if split != 'test':
47
+ cams = ImagesDataset(root_dir=os.getcwd() + '/Images/ActualDataset', dim=dim, name=split, transform=trans)
48
+ return DataLoader(cams, batch_size=batch_size, shuffle=True, num_workers=0)
49
+ else:
50
+ cams = TestingDataset(root_dir=os.getcwd() + '/Images/Input-Test', dim=dim, name=split)
51
+ return DataLoader(cams, batch_size=batch_size, shuffle=False, num_workers=0)
52
+
53
+ return data_loader
54
+
55
+
56
+ class ImagesDataset(Dataset):
57
+ """My dataset."""
58
+
59
+ def __init__(self, root_dir, dim, name, transform):
60
+ """
61
+ Args:
62
+ root_dir (string): Directory with all the images.
63
+ transform (callable, optional): Optional transform to be applied
64
+ on a sample.
65
+ """
66
+ self.root_dir = root_dir
67
+ self.nCameras = 2
68
+ self.imageDim = dim
69
+ self.name = name
70
+ self.parser = ConfigParser()
71
+ self.parser.read('config.ini')
72
+ self.transform = transform
73
+
74
+ def __len__(self):
75
+
76
+ return self.parser.getint(self.name, 'total')
77
+ #oneCameRoot = self.root_dir + '\CAM1'
78
+ #return int(len([name for name in os.listdir(oneCameRoot) if os.path.isfile(os.path.join(oneCameRoot, name))])/2) #por el depth
79
+
80
+
81
+ def __getitem__(self, idx):
82
+ if th.is_tensor(idx):
83
+ idx = idx.tolist()
84
+ idx = self.parser.get(self.name, str(idx))
85
+ if self.transform:
86
+ brighness = random.uniform(0.7, 1.2)
87
+ saturation = random.uniform(0, 2)
88
+ contrast = random.uniform(0.4, 2)
89
+ gamma = random.uniform(0.7, 1.3)
90
+ hue = random.uniform(-0.3, 0.3) # 0.01
91
+
92
+ oneCameRoot = self.root_dir + '/CAM0'
93
+
94
+ # foto normal
95
+ img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
96
+ img = Image.open(img_name).convert('RGB') # .convert('L')
97
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
98
+ img = img.resize((self.imageDim, self.imageDim))
99
+ if self.transform:
100
+ img = transforms.functional.adjust_gamma(img, gamma)
101
+ img = transforms.functional.adjust_brightness(img, brighness)
102
+ img = transforms.functional.adjust_contrast(img, contrast)
103
+ img = transforms.functional.adjust_saturation(img, saturation)
104
+ img = transforms.functional.adjust_hue(img, hue)
105
+ x1 = transforms.ToTensor()(img)
106
+ x1 = (x1 * 2) - 1
107
+
108
+ # foto produndidad
109
+ img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
110
+ img = Image.open(img_name).convert('I')
111
+ img = convert_I_to_L(img)
112
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
113
+ img = img.resize((self.imageDim, self.imageDim))
114
+ x1_dep = transforms.ToTensor()(img)
115
+ x1_dep = (x1_dep * 2) - 1
116
+
117
+ oneCameRoot = self.root_dir + '/CAM1'
118
+
119
+ # foto normal
120
+ img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
121
+ img = Image.open(img_name).convert('RGB') # .convert('L')
122
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
123
+ img = img.resize((self.imageDim, self.imageDim))
124
+ if self.transform:
125
+ img = transforms.functional.adjust_gamma(img, gamma)
126
+ img = transforms.functional.adjust_brightness(img, brighness)
127
+ img = transforms.functional.adjust_contrast(img, contrast)
128
+ img = transforms.functional.adjust_saturation(img, saturation)
129
+ img = transforms.functional.adjust_hue(img, hue)
130
+ x2 = transforms.ToTensor()(img)
131
+ x2 = (x2 * 2) - 1
132
+
133
+ # foto produndidad
134
+ img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
135
+ img = Image.open(img_name).convert('I')
136
+ img = convert_I_to_L(img)
137
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
138
+ img = img.resize((self.imageDim, self.imageDim))
139
+ x2_dep = transforms.ToTensor()(img)
140
+ x2_dep = (x2_dep * 2) - 1
141
+
142
+
143
+ #random izq o derecha
144
+ if (bool(random.getrandbits(1))):
145
+ sample = {'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
146
+ else:
147
+ sample = {'x_im': x2, 'x_dep': x2_dep, 'y_im': x1, 'y_dep': x1_dep, 'y_': torch.zeros(1, self.imageDim, self.imageDim)}
148
+
149
+ return sample
150
+
151
+ def __iter__(self):
152
+
153
+ for i in range(this.__len__()):
154
+ list.append(this.__getitem__(i))
155
+ return iter(list)
156
+
157
+ class TestingDataset(Dataset):
158
+ """My dataset."""
159
+
160
+ def __init__(self, root_dir, dim, name):
161
+ """
162
+ Args:
163
+ root_dir (string): Directory with all the images.
164
+ transform (callable, optional): Optional transform to be applied
165
+ on a sample.
166
+ """
167
+ self.root_dir = root_dir
168
+ self.imageDim = dim
169
+ self.name = name
170
+ files = os.listdir(self.root_dir)
171
+ self.files = [ele for ele in files if not ele.endswith('_d.png')]
172
+
173
+ def __len__(self):
174
+
175
+ #return self.parser.getint(self.name, 'total')
176
+ #oneCameRoot = self.root_dir + '\CAM1'
177
+ #return int(len([name for name in os.listdir(self.root_dir) if os.path.isfile(os.path.join(self.root_dir, name))])/2) #por el depth
178
+ return len(self.files)
179
+
180
+
181
+ def __getitem__(self, idx):
182
+ if th.is_tensor(idx):
183
+ idx = idx.tolist()
184
+
185
+ # foto normal
186
+ img_name = os.path.join(self.root_dir, self.files[idx])
187
+ img = Image.open(img_name).convert('RGB') # .convert('L')
188
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
189
+ img = img.resize((self.imageDim, self.imageDim))
190
+ x1 = transforms.ToTensor()(img)
191
+ x1 = (x1 * 2) - 1
192
+
193
+
194
+ # foto produndidad
195
+ img_name = os.path.join(self.root_dir , self.files[idx][:-4] + "_d.png")
196
+ img = Image.open(img_name).convert('I')
197
+ img = convert_I_to_L(img)
198
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
199
+ img = img.resize((self.imageDim, self.imageDim))
200
+ x1_dep = transforms.ToTensor()(img)
201
+ x1_dep = (x1_dep * 2) - 1
202
+
203
+ sample = {'x_im': x1, 'x_dep': x1_dep}
204
+
205
+ return sample
206
+
207
+ def __iter__(self):
208
+
209
+ for i in range(this.__len__()):
210
+ list.append(this.__getitem__(i))
211
+ return iter(list)
212
+
213
+
214
+ def show_image(t_data, grey=False):
215
+
216
+ #from numpy
217
+ t_data2 = t_data.transpose(1, 2, 0)
218
+ t_data2 = t_data2 * 255.0
219
+ t_data2 = t_data2.astype(np.uint8)
220
+ if (not grey):
221
+ outIm = Image.fromarray(t_data2, mode='RGB')
222
+ else:
223
+ t_data2 = np.squeeze(t_data2, axis=2)
224
+ outIm = Image.fromarray(t_data2, mode='L')
225
+ outIm.show()
226
+
227
+ def convert_I_to_L(img):
228
+ array = np.uint8(np.array(img) / 256) #el numero esta bien, sino genera espacios en negro en la imagen
229
+ return Image.fromarray(array)
230
+
231
+ class ScoreDataset(Dataset):
232
+ """My dataset."""
233
+
234
+ def __init__(self, root_dir, dim, name, cant_images):
235
+ """
236
+ Args:
237
+ root_dir (string): Directory with all the images.
238
+ transform (callable, optional): Optional transform to be applied
239
+ on a sample.
240
+ """
241
+ self.root_dir = root_dir
242
+ self.nCameras = 2
243
+ self.imageDim = dim
244
+ self.name = name
245
+ self.size = cant_images
246
+
247
+ def __len__(self):
248
+
249
+ return self.size
250
+
251
+
252
+ def __getitem__(self, idx):
253
+
254
+ oneCameRoot = self.root_dir + '/CAM0'
255
+
256
+ idx = "{:04d}".format(idx)
257
+ # foto normal
258
+ img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
259
+ img = Image.open(img_name).convert('RGB') # .convert('L')
260
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
261
+ img = img.resize((self.imageDim, self.imageDim))
262
+ x1 = transforms.ToTensor()(img)
263
+ x1 = (x1 * 2) - 1
264
+
265
+ # foto produndidad
266
+ img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
267
+ img = Image.open(img_name).convert('I')
268
+ img = convert_I_to_L(img)
269
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
270
+ img = img.resize((self.imageDim, self.imageDim))
271
+ x1_dep = transforms.ToTensor()(img)
272
+ x1_dep = (x1_dep * 2) - 1
273
+
274
+ oneCameRoot = self.root_dir + '/CAM1'
275
+
276
+ # foto normal
277
+ img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
278
+ img = Image.open(img_name).convert('RGB') # .convert('L')
279
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
280
+ img = img.resize((self.imageDim, self.imageDim))
281
+ x2 = transforms.ToTensor()(img)
282
+ x2 = (x2 * 2) - 1
283
+
284
+ # foto produndidad
285
+ img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
286
+ img = Image.open(img_name).convert('I')
287
+ img = convert_I_to_L(img)
288
+ if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
289
+ img = img.resize((self.imageDim, self.imageDim))
290
+ x2_dep = transforms.ToTensor()(img)
291
+ x2_dep = (x2_dep * 2) - 1
292
+
293
+
294
+ sample = {'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
295
+ return sample
296
+
297
+ def __iter__(self):
298
+
299
+ for i in range(self.__len__()):
300
+ list.append(self.__getitem__(i))
301
+ return iter(list)
epochData.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baf9bf7acbc95f817b9f79d9be24fe553e8beeacda79854ebcfe9fc2707df120
3
+ size 210
main.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from WiggleGAN import WiggleGAN
5
+ #from MyACGAN import MyACGAN
6
+ #from MyGAN import MyGAN
7
+
8
+ """parsing and configuration"""
9
+
10
+
11
+ def parse_args():
12
+ desc = "Pytorch implementation of GAN collections"
13
+ parser = argparse.ArgumentParser(description=desc)
14
+
15
+ parser.add_argument('--gan_type', type=str, default='WiggleGAN',
16
+ choices=['MyACGAN', 'MyGAN', 'WiggleGAN'],
17
+ help='The type of GAN')
18
+ parser.add_argument('--dataset', type=str, default='4cam',
19
+ choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed', '4cam'],
20
+ help='The name of dataset')
21
+ parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')
22
+ parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
23
+ parser.add_argument('--batch_size', type=int, default=16, help='The size of batch')
24
+ parser.add_argument('--input_size', type=int, default=10, help='The size of input image')
25
+ parser.add_argument('--save_dir', type=str, default='models',
26
+ help='Directory name to save the model')
27
+ parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
28
+ parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
29
+ parser.add_argument('--lrG', type=float, default=0.0002)
30
+ parser.add_argument('--lrD', type=float, default=0.001)
31
+ parser.add_argument('--beta1', type=float, default=0.5)
32
+ parser.add_argument('--beta2', type=float, default=0.999)
33
+ parser.add_argument('--gpu_mode', type=str2bool, default=True)
34
+ parser.add_argument('--benchmark_mode', type=str2bool, default=True)
35
+ parser.add_argument('--cameras', type=int, default=2)
36
+ parser.add_argument('--imageDim', type=int, default=128)
37
+ parser.add_argument('--epochV', type=int, default=0)
38
+ parser.add_argument('--cIm', type=int, default=4)
39
+ parser.add_argument('--seedLoad', type=str, default="-0000")
40
+ parser.add_argument('--zGF', type=float, default=0.2)
41
+ parser.add_argument('--zDF', type=float, default=0.2)
42
+ parser.add_argument('--bF', type=float, default=0.2)
43
+ parser.add_argument('--expandGen', type=int, default=3)
44
+ parser.add_argument('--expandDis', type=int, default=3)
45
+ parser.add_argument('--wiggleDepth', type=int, default=-1)
46
+ parser.add_argument('--visdom', type=str2bool, default=True)
47
+ parser.add_argument('--lambdaL1', type=int, default=100)
48
+ parser.add_argument('--clipping', type=float, default=-1)
49
+ parser.add_argument('--depth', type=str2bool, default=True)
50
+ parser.add_argument('--recreate', type=str2bool, default=False)
51
+ parser.add_argument('--name_wiggle', type=str, default='wiggle-result')
52
+
53
+ return check_args(parser.parse_args())
54
+
55
+
56
+ """checking arguments"""
57
+
58
+ def str2bool(v):
59
+ if isinstance(v, bool):
60
+ return v
61
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
62
+ return True
63
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
64
+ return False
65
+ else:
66
+ raise argparse.ArgumentTypeError('Boolean value expected.')
67
+
68
+
69
+ def check_args(args):
70
+ # --save_dir
71
+ if not os.path.exists(args.save_dir):
72
+ os.makedirs(args.save_dir)
73
+
74
+ # --result_dir
75
+ if not os.path.exists(args.result_dir):
76
+ os.makedirs(args.result_dir)
77
+
78
+ # --result_dir
79
+ if not os.path.exists(args.log_dir):
80
+ os.makedirs(args.log_dir)
81
+
82
+ # --epoch
83
+ try:
84
+ assert args.epoch >= 1
85
+ except:
86
+ print('number of epochs must be larger than or equal to one')
87
+
88
+ # --batch_size
89
+ try:
90
+ assert args.batch_size >= 1
91
+ except:
92
+ print('batch size must be larger than or equal to one')
93
+
94
+ return args
95
+
96
+
97
+ """main"""
98
+
99
+
100
+ def main():
101
+ # parse arguments
102
+ args = parse_args()
103
+ if args is None:
104
+ exit()
105
+
106
+ if args.benchmark_mode:
107
+ torch.backends.cudnn.benchmark = True
108
+
109
+ # declare instance for GAN
110
+ if args.gan_type == 'WiggleGAN':
111
+ gan = WiggleGAN(args)
112
+ #elif args.gan_type == 'MyACGAN':
113
+ # gan = MyACGAN(args)
114
+ #elif args.gan_type == 'MyGAN':
115
+ # gan = MyGAN(args)
116
+ else:
117
+ raise Exception("[!] There is no option for " + args.gan_type)
118
+
119
+ # launch the graph in a session
120
+ if (args.wiggleDepth < 0 and not args.recreate):
121
+ print(" [*] Training Starting!")
122
+ gan.train()
123
+ print(" [*] Training finished!")
124
+ else:
125
+ if not args.recreate:
126
+ print(" [*] Wiggle Started!")
127
+ gan.wiggleEf()
128
+ print(" [*] Wiggle finished!")
129
+ else:
130
+ print(" [*] Dataset recreation Started")
131
+ gan.recreate()
132
+ print(" [*] Dataset recreation finished")
133
+
134
+
135
+ if __name__ == '__main__':
136
+ main()
models/4cam/WiggleGAN/WiggleGAN_31219_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4b39604e99319045e9070632a7aa31cd5adbd0220126515093856f97af622ff
3
+ size 1252850
models/4cam/WiggleGAN/WiggleGAN_66942_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da994f51205701f9754dc1688cffd12b72f593f37c61833ec4b7c8860e152236
3
+ size 1252850
models/4cam/WiggleGAN/WiggleGAN_70466_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:310b22bf4f5375174b23347b85b64c9de7934cafec6a61b3d647bfb7f24b5ae7
3
+ size 1252850
models/4cam/WiggleGAN/WiggleGAN_70944_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5734a5e102c75e4afde944f2898171fb34373c002b651ca84901ed9f55ae385d
3
+ size 1252850
models/4cam/WiggleGAN/WiggleGAN_74962_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d06a0da4295b6b6c5277f3cf987327a60818460780fb3aec42e514cbc3f71c71
3
+ size 1252850
models/4cam/WiggleGAN/WiggleGAN_82122_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:170c8e095c66665ef87f199e5308a39d90fe2f5d0f2dfa5d8c789675657e0423
3
+ size 1252850
models/4cam/WiggleGAN/WiggleGAN_92332_110_G.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a9cbec7ad0978008bcda05a96865b71016663278ed18c935b25875f7b08a979
3
+ size 1252850
pyvenv.cfg ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ home = C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python37_64
2
+ include-system-site-packages = false
3
+ version = 3.7.8
requirements.txt CHANGED
@@ -1,4 +1,27 @@
1
  timm
2
- Pillow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  torch
4
- opencv-python
 
1
  timm
2
+ opencv-python
3
+ certifi==2019.11.28
4
+ chardet==3.0.4
5
+ cycler==0.10.0
6
+ idna==2.8
7
+ imageio==2.5.0
8
+ jsonpatch==1.24
9
+ jsonpointer==2.0
10
+ kiwisolver==1.1.0
11
+ matplotlib==3.1.1
12
+ numpy==1.17.2
13
+ Pillow==6.1.0
14
+ pyparsing==2.4.2
15
+ python-dateutil==2.8.0
16
+ PyYAML==5.1.2
17
+ pyzmq==18.1.1
18
+ requests==2.22.0
19
+ scipy==1.1.0
20
+ six==1.12.0
21
+ urllib3==1.25.7
22
+ visdom==0.1.8.9
23
+ websocket-client==0.56.0
24
+ tornado==6.0.3
25
  torch
26
+ torchfile==0.1.0
27
+ torchvision==0.2.1
utils.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, gzip, torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import scipy.misc
5
+ import imageio
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ from torchvision import datasets, transforms
9
+ import visdom
10
+ import random
11
+
12
+ def save_wiggle(images, rows=1, name="test"):
13
+
14
+
15
+ width = images[0].shape[1]
16
+ height = images[0].shape[2]
17
+ columns = int(len(images)/rows)
18
+ rows = int(rows)
19
+ margin = 4
20
+
21
+ total_width = (width + margin) * columns
22
+ total_height = (height + margin) * rows
23
+
24
+ new_im = Image.new('RGB', (total_width, total_height))
25
+
26
+ transToPil = transforms.ToPILImage()
27
+
28
+ x_offset = 3
29
+ y_offset = 3
30
+ for y in range(rows):
31
+ for x in range(columns):
32
+ im = images[x+y*columns]
33
+ im = transToPil((im+1)/2)
34
+ new_im.paste(im, (x_offset, y_offset))
35
+ x_offset += width + margin
36
+ x_offset = 3
37
+ y_offset += height + margin
38
+
39
+ new_im.save('./WiggleResults/' + name + '.jpg')
40
+
41
+ def load_mnist(dataset):
42
+ data_dir = os.path.join("./data", dataset)
43
+
44
+ def extract_data(filename, num_data, head_size, data_size):
45
+ with gzip.open(filename) as bytestream:
46
+ bytestream.read(head_size)
47
+ buf = bytestream.read(data_size * num_data)
48
+ data = np.frombuffer(buf, dtype=np.uint8).astype(np.float)
49
+ return data
50
+
51
+ data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
52
+ trX = data.reshape((60000, 28, 28, 1))
53
+
54
+ data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
55
+ trY = data.reshape((60000))
56
+
57
+ data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
58
+ teX = data.reshape((10000, 28, 28, 1))
59
+
60
+ data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
61
+ teY = data.reshape((10000))
62
+
63
+ trY = np.asarray(trY).astype(np.int)
64
+ teY = np.asarray(teY)
65
+
66
+ X = np.concatenate((trX, teX), axis=0)
67
+ y = np.concatenate((trY, teY), axis=0).astype(np.int)
68
+
69
+ seed = 547
70
+ np.random.seed(seed)
71
+ np.random.shuffle(X)
72
+ np.random.seed(seed)
73
+ np.random.shuffle(y)
74
+
75
+ y_vec = np.zeros((len(y), 10), dtype=np.float)
76
+ for i, label in enumerate(y):
77
+ y_vec[i, y[i]] = 1
78
+
79
+ X = X.transpose(0, 3, 1, 2) / 255.
80
+ # y_vec = y_vec.transpose(0, 3, 1, 2)
81
+
82
+ X = torch.from_numpy(X).type(torch.FloatTensor)
83
+ y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
84
+ return X, y_vec
85
+
86
+ def load_celebA(dir, transform, batch_size, shuffle):
87
+ # transform = transforms.Compose([
88
+ # transforms.CenterCrop(160),
89
+ # transform.Scale(64)
90
+ # transforms.ToTensor(),
91
+ # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
92
+ # ])
93
+
94
+ # data_dir = 'data/celebA' # this path depends on your computer
95
+ dset = datasets.ImageFolder(dir, transform)
96
+ data_loader = torch.utils.data.DataLoader(dset, batch_size, shuffle)
97
+
98
+ return data_loader
99
+
100
+
101
+ def print_network(net):
102
+ num_params = 0
103
+ for param in net.parameters():
104
+ num_params += param.numel()
105
+ print(net)
106
+ print('Total number of parameters: %d' % num_params)
107
+
108
+ def save_images(images, size, image_path):
109
+ return imsave(images, size, image_path)
110
+
111
+ def imsave(images, size, path):
112
+ image = np.squeeze(merge(images, size))
113
+ return scipy.misc.imsave(path, image)
114
+
115
+ def merge(images, size):
116
+ #print ("shape", images.shape)
117
+ h, w = images.shape[1], images.shape[2]
118
+ if (images.shape[3] in (3,4)):
119
+ c = images.shape[3]
120
+ img = np.zeros((h * size[0], w * size[1], c))
121
+ for idx, image in enumerate(images):
122
+ i = idx % size[1]
123
+ j = idx // size[1]
124
+ img[j * h:j * h + h, i * w:i * w + w, :] = image
125
+ return img
126
+ elif images.shape[3]== 1:
127
+ img = np.zeros((h * size[0], w * size[1]))
128
+ for idx, image in enumerate(images):
129
+ #print("indez ",idx)
130
+ i = idx % size[1]
131
+ j = idx // size[1]
132
+ img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
133
+ return img
134
+ else:
135
+ raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')
136
+
137
+ def generate_animation(path, num):
138
+ images = []
139
+ for e in range(num):
140
+ img_name = path + '_epoch%04d' % (e+1) + '.png'
141
+ images.append(imageio.imread(img_name))
142
+ imageio.mimsave(path + '_generate_animation.gif', images, fps=5)
143
+
144
+ def loss_plot(hist, path = 'Train_hist.png', model_name = ''):
145
+ x1 = range(len(hist['D_loss_train']))
146
+ x2 = range(len(hist['G_loss_train']))
147
+
148
+ y1 = hist['D_loss_train']
149
+ y2 = hist['G_loss_train']
150
+
151
+ if (x1 != x2):
152
+ y1 = [0.0] * (len(y2) - len(y1)) + y1
153
+ x1 = x2
154
+
155
+ plt.plot(x1, y1, label='D_loss_train')
156
+
157
+ plt.plot(x2, y2, label='G_loss_train')
158
+
159
+ plt.xlabel('Iter')
160
+ plt.ylabel('Loss')
161
+
162
+ plt.legend(loc=4)
163
+ plt.grid(True)
164
+ plt.tight_layout()
165
+
166
+ path = os.path.join(path, model_name + '_loss.png')
167
+
168
+ plt.savefig(path)
169
+
170
+ plt.close()
171
+
172
+ def initialize_weights(net):
173
+ for m in net.modules():
174
+ if isinstance(m, nn.Conv2d):
175
+ m.weight.data.normal_(0, 0.02)
176
+ m.bias.data.zero_()
177
+ elif isinstance(m, nn.ConvTranspose2d):
178
+ m.weight.data.normal_(0, 0.02)
179
+ m.bias.data.zero_()
180
+ elif isinstance(m, nn.Linear):
181
+ m.weight.data.normal_(0, 0.02)
182
+ m.bias.data.zero_()
183
+
184
+ class VisdomLinePlotter(object):
185
+ """Plots to Visdom"""
186
+ def __init__(self, env_name='main'):
187
+ self.viz = visdom.Visdom()
188
+ self.env = env_name
189
+ self.ini = False
190
+ self.count = 1
191
+ def plot(self, var_name,names, split_name, hist):
192
+
193
+
194
+
195
+ x = []
196
+ y = []
197
+ for i, name in enumerate(names):
198
+ x.append(self.count)
199
+ y.append(hist[name])
200
+ self.count+=1
201
+ #x1 = (len(hist['D_loss_' +split_name]))
202
+ #x2 = (len(hist['G_loss_' +split_name]))
203
+
204
+ #y1 = hist['D_loss_'+split_name]
205
+ #y2 = hist['G_loss_'+split_name]
206
+
207
+
208
+ np.array(x)
209
+
210
+
211
+ for i,n in enumerate(names):
212
+ x[i] = np.arange(1, x[i]+1)
213
+
214
+ if not self.ini:
215
+ for i, name in enumerate(names):
216
+ if i == 0:
217
+ self.win = self.viz.line(X=x[i], Y=np.array(y[i]), env=self.env,name = name,opts=dict(
218
+ title=var_name + '_'+split_name, showlegend = True
219
+ ))
220
+ else:
221
+ self.viz.line(X=x[i], Y=np.array(y[i]), env=self.env,win=self.win, name=name, update='append')
222
+ self.ini = True
223
+ else:
224
+ x[0] = np.array([x[0][-2], x[0][-1]])
225
+
226
+ for i,n in enumerate(names):
227
+ y[i] = np.array([y[i][-2], y[i][-1]])
228
+ self.viz.line(X=x[0], Y=np.array(y[i]), env=self.env, win=self.win, name=n, update='append')
229
+
230
+
231
+ class VisdomLineTwoPlotter(VisdomLinePlotter):
232
+
233
+ def plot(self, var_name, epoch,names, hist):
234
+
235
+ x1 = epoch
236
+ y1 = hist[names[0]]
237
+ y2 = hist[names[1]]
238
+ y3 = hist[names[2]]
239
+ y4 = hist[names[3]]
240
+
241
+
242
+ #y1 = hist['D_loss_' + split_name]
243
+ #y2 = hist['G_loss_' + split_name]
244
+ #y3 = hist['D_loss_' + split_name2]
245
+ #y4 = hist['G_loss_' + split_name2]
246
+
247
+
248
+ #x1 = np.arange(1, x1+1)
249
+
250
+ if not self.ini:
251
+ self.win = self.viz.line(X=np.array([x1]), Y=np.array(y1), env=self.env,name = names[0],opts=dict(
252
+ title=var_name,
253
+ showlegend = True,
254
+ linecolor = np.array([[0, 0, 255]])
255
+ ))
256
+ self.viz.line(X=np.array([x1]), Y=np.array(y2), env=self.env,win=self.win, name=names[1],
257
+ update='append', opts=dict(
258
+ linecolor=np.array([[255, 153, 51]])
259
+ ))
260
+ self.viz.line(X=np.array([x1]), Y=np.array(y3), env=self.env, win=self.win, name=names[2],
261
+ update='append', opts=dict(
262
+ linecolor=np.array([[0, 51, 153]])
263
+ ))
264
+ self.viz.line(X=np.array([x1]), Y=np.array(y4), env=self.env, win=self.win, name=names[3],
265
+ update='append', opts=dict(
266
+ linecolor=np.array([[204, 51, 0]])
267
+ ))
268
+ self.ini = True
269
+ else:
270
+
271
+ y4 = np.array([y4[-2], y4[-1]])
272
+ y3 = np.array([y3[-2], y3[-1]])
273
+ y2 = np.array([y2[-2], y2[-1]])
274
+ y1 = np.array([y1[-2], y1[-1]])
275
+ x1 = np.array([x1 - 1, x1])
276
+ self.viz.line(X=x1, Y=np.array(y1), env=self.env, win=self.win, name=names[0], update='append')
277
+ self.viz.line(X=x1, Y=np.array(y2), env=self.env, win=self.win, name=names[1], update='append')
278
+ self.viz.line(X=x1, Y=np.array(y3), env=self.env, win=self.win, name=names[2],
279
+ update='append')
280
+ self.viz.line(X=x1, Y=np.array(y4), env=self.env, win=self.win, name=names[3],
281
+ update='append')
282
+
283
+ class VisdomImagePlotter(object):
284
+ """Plots to Visdom"""
285
+ def __init__(self, env_name='main'):
286
+ self.viz = visdom.Visdom()
287
+ self.env = env_name
288
+ def plot(self, epoch,images,rows):
289
+
290
+ list_images = []
291
+ for image in images:
292
+ #transforms.ToPILImage()(image)
293
+ image = (image + 1)/2
294
+ image = image.detach().numpy() * 255
295
+ list_images.append(image)
296
+ self.viz.images(
297
+ list_images,
298
+ padding=2,
299
+ nrow =rows,
300
+ opts=dict(title="epoch: " + str(epoch)),
301
+ env=self.env
302
+ )
303
+
304
+
305
+ def augmentData(x,y, randomness = 1, percent_noise = 0.1):
306
+ """
307
+ :param x: image X
308
+ :param y: image Y
309
+ :param randomness: Value of randomness (between 1 and 0)
310
+ :return: data x,y augmented
311
+ """
312
+
313
+
314
+ sampleX = torch.tensor([])
315
+ sampleY = torch.tensor([])
316
+
317
+ for aumX, aumY in zip(x,y):
318
+
319
+ # Preparing to get image # transforms.ToPILImage()(pil_to_tensor.squeeze_(0))
320
+ #percent_noise = percent_noise
321
+ #noise = torch.randn(aumX.shape)
322
+
323
+ #aumX = noise * percent_noise + aumX * (1 - percent_noise)
324
+ #aumY = noise * percent_noise + aumY * (1 - percent_noise)
325
+
326
+ aumX = (aumX + 1) / 2
327
+ aumY = (aumY + 1) / 2
328
+
329
+ imgX = transforms.ToPILImage()(aumX)
330
+ imgY = transforms.ToPILImage()(aumY)
331
+
332
+ # Values for augmentation #
333
+ brighness = random.uniform(0.7, 1.2)* randomness + (1-randomness)
334
+ saturation = random.uniform(0, 2)* randomness + (1-randomness)
335
+ contrast = random.uniform(0.4, 2)* randomness + (1-randomness)
336
+ gamma = random.uniform(0.7, 1.3)* randomness + (1-randomness)
337
+ hue = random.uniform(-0.3, 0.3)* randomness #0.01
338
+
339
+ imgX = transforms.functional.adjust_gamma(imgX, gamma)
340
+ imgX = transforms.functional.adjust_brightness(imgX, brighness)
341
+ imgX = transforms.functional.adjust_contrast(imgX, contrast)
342
+ imgX = transforms.functional.adjust_saturation(imgX, saturation)
343
+ imgX = transforms.functional.adjust_hue(imgX, hue)
344
+ #imgX.show()
345
+
346
+ imgY = transforms.functional.adjust_gamma(imgY, gamma)
347
+ imgY = transforms.functional.adjust_brightness(imgY, brighness)
348
+ imgY = transforms.functional.adjust_contrast(imgY, contrast)
349
+ imgY = transforms.functional.adjust_saturation(imgY, saturation)
350
+ imgY = transforms.functional.adjust_hue(imgY, hue)
351
+ #imgY.show()
352
+
353
+ sx = transforms.ToTensor()(imgX)
354
+ sx = (sx * 2)-1
355
+
356
+ sy = transforms.ToTensor()(imgY)
357
+ sy = (sy * 2)-1
358
+
359
+ sampleX = torch.cat((sampleX, sx.unsqueeze_(0)), 0)
360
+ sampleY = torch.cat((sampleY, sy.unsqueeze_(0)), 0)
361
+ return sampleX,sampleY
362
+
363
+ def RGBtoL (x):
364
+
365
+ return x[:,0,:,:].unsqueeze(0).transpose(0,1)
366
+
367
+ def LtoRGB (x):
368
+
369
+ return x.repeat(1, 3, 1, 1)