hylee committed on
Commit d73173f
1 Parent(s): 8005fdc
Files changed (48)
  1. APDrawingGAN2/data/__init__.py +75 -0
  2. APDrawingGAN2/data/aligned_dataset.py +288 -0
  3. APDrawingGAN2/data/base_data_loader.py +10 -0
  4. APDrawingGAN2/data/base_dataset.py +103 -0
  5. APDrawingGAN2/data/image_folder.py +68 -0
  6. APDrawingGAN2/data/single_dataset.py +176 -0
  7. APDrawingGAN2/docs/tips.md +8 -0
  8. APDrawingGAN2/models/__init__.py +39 -0
  9. APDrawingGAN2/models/apdrawingpp_style_model.py +692 -0
  10. APDrawingGAN2/models/base_model.py +545 -0
  11. APDrawingGAN2/models/networks.py +1194 -0
  12. APDrawingGAN2/models/test_model.py +214 -0
  13. APDrawingGAN2/options/__init__.py +0 -0
  14. APDrawingGAN2/options/base_options.py +192 -0
  15. APDrawingGAN2/options/test_options.py +23 -0
  16. APDrawingGAN2/options/train_options.py +62 -0
  17. APDrawingGAN2/preprocess/combine_A_and_B.py +48 -0
  18. APDrawingGAN2/preprocess/example/img_1701.jpg +0 -0
  19. APDrawingGAN2/preprocess/example/img_1701_aligned.png +0 -0
  20. APDrawingGAN2/preprocess/example/img_1701_aligned.txt +5 -0
  21. APDrawingGAN2/preprocess/example/img_1701_aligned_68lm.txt +68 -0
  22. APDrawingGAN2/preprocess/example/img_1701_aligned_bgmask.png +0 -0
  23. APDrawingGAN2/preprocess/example/img_1701_aligned_eyelmask.png +0 -0
  24. APDrawingGAN2/preprocess/example/img_1701_aligned_eyermask.png +0 -0
  25. APDrawingGAN2/preprocess/example/img_1701_aligned_facemask.png +0 -0
  26. APDrawingGAN2/preprocess/example/img_1701_aligned_mouthmask.png +0 -0
  27. APDrawingGAN2/preprocess/example/img_1701_aligned_nosemask.png +0 -0
  28. APDrawingGAN2/preprocess/example/img_1701_facial5point.mat +0 -0
  29. APDrawingGAN2/preprocess/face_align_512.m +55 -0
  30. APDrawingGAN2/preprocess/get_partmask.py +152 -0
  31. APDrawingGAN2/preprocess/readme.md +71 -0
  32. APDrawingGAN2/readme.md +105 -0
  33. APDrawingGAN2/requirements.txt +10 -0
  34. APDrawingGAN2/script/test.sh +2 -0
  35. APDrawingGAN2/script/test_single.sh +2 -0
  36. APDrawingGAN2/script/train.sh +3 -0
  37. APDrawingGAN2/test.py +69 -0
  38. APDrawingGAN2/train.py +67 -0
  39. APDrawingGAN2/util/__init__.py +0 -0
  40. APDrawingGAN2/util/get_data.py +115 -0
  41. APDrawingGAN2/util/html.py +68 -0
  42. APDrawingGAN2/util/image_pool.py +32 -0
  43. APDrawingGAN2/util/util.py +60 -0
  44. APDrawingGAN2/util/visualizer.py +171 -0
  45. README.md +1 -0
  46. app.py +210 -0
  47. packages.txt +2 -0
  48. requirements.txt +8 -0
APDrawingGAN2/data/__init__.py ADDED
@@ -0,0 +1,75 @@
+ import importlib
+ import torch.utils.data
+ from data.base_data_loader import BaseDataLoader
+ from data.base_dataset import BaseDataset
+
+
+ def find_dataset_using_name(dataset_name):
+     # Given the option --dataset_mode [datasetname],
+     # the file "data/datasetname_dataset.py"
+     # will be imported.
+     dataset_filename = "data." + dataset_name + "_dataset"
+     datasetlib = importlib.import_module(dataset_filename)
+
+     # In the file, the class called DatasetNameDataset() will
+     # be instantiated. It has to be a subclass of BaseDataset,
+     # and it is case-insensitive.
+     dataset = None
+     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+     for name, cls in datasetlib.__dict__.items():
+         if name.lower() == target_dataset_name.lower() \
+            and issubclass(cls, BaseDataset):
+             dataset = cls
+
+     if dataset is None:
+         print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+         exit(0)
+
+     return dataset
+
+
+ def get_option_setter(dataset_name):
+     dataset_class = find_dataset_using_name(dataset_name)
+     return dataset_class.modify_commandline_options
+
+
+ def create_dataset(opt):
+     dataset = find_dataset_using_name(opt.dataset_mode)
+     instance = dataset()
+     instance.initialize(opt)
+     print("dataset [%s] was created" % (instance.name()))
+     return instance
+
+
+ def CreateDataLoader(opt):
+     data_loader = CustomDatasetDataLoader()
+     data_loader.initialize(opt)
+     return data_loader
+
+
+ # Wrapper class of Dataset class that performs
+ # multi-threaded data loading
+ class CustomDatasetDataLoader(BaseDataLoader):
+     def name(self):
+         return 'CustomDatasetDataLoader'
+
+     def initialize(self, opt):
+         BaseDataLoader.initialize(self, opt)
+         self.dataset = create_dataset(opt)
+         self.dataloader = torch.utils.data.DataLoader(
+             self.dataset,
+             batch_size=opt.batch_size,
+             shuffle=not opt.serial_batches,  # in training, serial_batches by default is false, shuffle=true
+             num_workers=int(opt.num_threads))
+
+     def load_data(self):
+         return self
+
+     def __len__(self):
+         return min(len(self.dataset), self.opt.max_dataset_size)
+
+     def __iter__(self):
+         for i, data in enumerate(self.dataloader):
+             if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                 break
+             yield data
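A minimal usage sketch of the loader defined above (not part of the committed file; `TrainOptions().parse()` is assumed to exist in `options/train_options.py`, as in the pix2pix codebase this repo derives from):

```python
from options.train_options import TrainOptions   # assumed API, as in the parent pix2pix codebase
from data import CreateDataLoader

opt = TrainOptions().parse()                 # reads --dataset_mode, --batch_size, ...
data_loader = CreateDataLoader(opt)          # imports data/<dataset_mode>_dataset.py
dataset = data_loader.load_data()            # returns the loader itself; it is iterable
print('#training images = %d' % len(data_loader))
for i, data in enumerate(dataset):
    pass                                     # each `data` is the dict built by the dataset's __getitem__
```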
APDrawingGAN2/data/aligned_dataset.py ADDED
@@ -0,0 +1,288 @@
+ import os.path
+ import random
+ import torchvision.transforms as transforms
+ import torch
+ from data.base_dataset import BaseDataset
+ from data.image_folder import make_dataset
+ from PIL import Image
+ import numpy as np
+ import cv2
+ import csv
+
+ def getfeats(featpath):
+     trans_points = np.empty([5,2],dtype=np.int64)
+     with open(featpath, 'r') as csvfile:
+         reader = csv.reader(csvfile, delimiter=' ')
+         for ind,row in enumerate(reader):
+             trans_points[ind,:] = row
+     return trans_points
+
+ def tocv2(ts):
+     img = (ts.numpy()/2+0.5)*255
+     img = img.astype('uint8')
+     img = np.transpose(img,(1,2,0))
+     img = img[:,:,::-1]#rgb->bgr
+     return img
+
+ def dt(img):
+     if(img.shape[2]==3):
+         img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
+     #convert to BW
+     ret1,thresh1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
+     ret2,thresh2 = cv2.threshold(img,127,255,cv2.THRESH_BINARY_INV)
+     dt1 = cv2.distanceTransform(thresh1,cv2.DIST_L2,5)
+     dt2 = cv2.distanceTransform(thresh2,cv2.DIST_L2,5)
+     dt1 = dt1/dt1.max()#->[0,1]
+     dt2 = dt2/dt2.max()
+     return dt1, dt2
+
+ def getSoft(size,xb,yb,boundwidth=5.0):
+     xarray = np.tile(np.arange(0,size[1]),(size[0],1))
+     yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose()
+     cxdists = []
+     cydists = []
+     for i in range(len(xb)):
+         xba = np.tile(xb[i],(size[1],1)).transpose()
+         yba = np.tile(yb[i],(size[0],1))
+         cxdists.append(np.abs(xarray-xba))
+         cydists.append(np.abs(yarray-yba))
+     xdist = np.minimum.reduce(cxdists)
+     ydist = np.minimum.reduce(cydists)
+     manhdist = np.minimum.reduce([xdist,ydist])
+     im = (manhdist+1) / (boundwidth+1) * 1.0
+     im[im>=1.0] = 1.0
+     return im
+
+ class AlignedDataset(BaseDataset):
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         return parser
+
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         imglist = 'datasets/apdrawing_list/%s/%s.txt' % (opt.phase, opt.dataroot)
+         if os.path.exists(imglist):
+             lines = open(imglist, 'r').read().splitlines()
+             lines = sorted(lines)
+             self.AB_paths = [line.split()[0] for line in lines]
+             if len(lines[0].split()) == 2:
+                 self.B_paths = [line.split()[1] for line in lines]
+         else:
+             self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+             self.AB_paths = sorted(make_dataset(self.dir_AB))
+         assert(opt.resize_or_crop == 'resize_and_crop')
+
+     def __getitem__(self, index):
+         AB_path = self.AB_paths[index]
+         AB = Image.open(AB_path).convert('RGB')
+         w, h = AB.size
+         if w/h == 2:
+             w2 = int(w / 2)
+             A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+             B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+         else: # if w/h != 2, need B_paths
+             A = AB.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+             B = Image.open(self.B_paths[index]).convert('RGB')
+             B = B.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+         A = transforms.ToTensor()(A)
+         B = transforms.ToTensor()(B)
+         w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
+         h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
+
+         A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]#C,H,W
+         B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
+
+         A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
+         B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
+
+         if self.opt.which_direction == 'BtoA':
+             input_nc = self.opt.output_nc
+             output_nc = self.opt.input_nc
+         else:
+             input_nc = self.opt.input_nc
+             output_nc = self.opt.output_nc
+
+         flipped = False
+         if (not self.opt.no_flip) and random.random() < 0.5:
+             flipped = True
+             idx = [i for i in range(A.size(2) - 1, -1, -1)]
+             idx = torch.LongTensor(idx)
+             A = A.index_select(2, idx)
+             B = B.index_select(2, idx)
+
+         if input_nc == 1:  # RGB to gray
+             tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+             A = tmp.unsqueeze(0)
+
+         if output_nc == 1:  # RGB to gray
+             tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+             B = tmp.unsqueeze(0)
+
+         item = {'A': A, 'B': B,
+                 'A_paths': AB_path, 'B_paths': AB_path}
+
+         if self.opt.use_local:
+             regions = ['eyel','eyer','nose','mouth']
+             basen = os.path.basename(AB_path)[:-4]+'.txt'
+             if self.opt.region_enm in [0,1]:
+                 featdir = self.opt.lm_dir
+                 featpath = os.path.join(featdir,basen)
+                 feats = getfeats(featpath)
+                 if flipped:
+                     for i in range(5):
+                         feats[i,0] = self.opt.fineSize - feats[i,0] - 1
+                     tmp = [feats[0,0],feats[0,1]]
+                     feats[0,:] = [feats[1,0],feats[1,1]]
+                     feats[1,:] = tmp
+                 mouth_x = int((feats[3,0]+feats[4,0])/2.0)
+                 mouth_y = int((feats[3,1]+feats[4,1])/2.0)
+                 ratio = self.opt.fineSize / 256
+                 EYE_H = self.opt.EYE_H * ratio
+                 EYE_W = self.opt.EYE_W * ratio
+                 NOSE_H = self.opt.NOSE_H * ratio
+                 NOSE_W = self.opt.NOSE_W * ratio
+                 MOUTH_H = self.opt.MOUTH_H * ratio
+                 MOUTH_W = self.opt.MOUTH_W * ratio
+                 center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]])
+                 item['center'] = center
+                 rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)]
+                 rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)]
+                 if self.opt.soft_border:
+                     soft_border_mask4 = []
+                     for i in range(4):
+                         xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)]
+                         yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)]
+                         soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb)
+                         soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0))
+                         item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i]
+                 for i in range(4):
+                     item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
+                     item[regions[i]+'_B'] = B[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
+                     if self.opt.soft_border:
+                         item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1)
+                         item[regions[i]+'_B'] = item[regions[i]+'_B'] * soft_border_mask4[i]
+             if self.opt.compactmask:
+                 cmasks0 = []
+                 cmasks = []
+                 for i in range(4):
+                     if flipped and i in [0,1]:
+                         cmaskpath = os.path.join(self.opt.cmask_dir,regions[1-i],basen[:-4]+'.png')
+                     else:
+                         cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png')
+                     im_cmask = Image.open(cmaskpath)
+                     cmask0 = transforms.ToTensor()(im_cmask)
+                     if flipped:
+                         cmask0 = cmask0.index_select(2, idx)
+                     if output_nc == 1 and cmask0.shape[0] == 3:
+                         tmp = cmask0[0, ...] * 0.299 + cmask0[1, ...] * 0.587 + cmask0[2, ...] * 0.114
+                         cmask0 = tmp.unsqueeze(0)
+                     cmask0 = (cmask0 >= 0.5).float()
+                     cmasks0.append(cmask0)
+                     cmask = cmask0.clone()
+                     if self.opt.region_enm in [0,1]:
+                         cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
+                     elif self.opt.region_enm in [2]: # need to multiply cmask
+                         item[regions[i]+'_A'] = (A/2+0.5) * cmask * 2 - 1
+                         item[regions[i]+'_B'] = (B/2+0.5) * cmask * 2 - 1
+                     cmasks.append(cmask)
+                 item['cmaskel'] = cmasks[0]
+                 item['cmasker'] = cmasks[1]
+                 item['cmask'] = cmasks[2]
+                 item['cmaskmo'] = cmasks[3]
+             if self.opt.hair_local:
+                 mask = torch.ones(B.shape)
+                 if self.opt.region_enm == 0:
+                     for i in range(4):
+                         mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 0
+                     if self.opt.soft_border:
+                         imgsize = self.opt.fineSize
+                         maskn = mask[0].numpy()
+                         masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])]
+                         masks[0][1:] = maskn[:-1]
+                         masks[1][:-1] = maskn[1:]
+                         masks[2][:,1:] = maskn[:,:-1]
+                         masks[3][:,:-1] = maskn[:,1:]
+                         masks2 = [maskn-e for e in masks]
+                         bound = np.minimum.reduce(masks2)
+                         bound = -bound
+                         xb = []
+                         yb = []
+                         for i in range(4):
+                             xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1]
+                             ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1]
+                             for j in range(2):
+                                 maskx = bound[:,xbi[j]]
+                                 masky = bound[ybi[j],:]
+                                 tmp_a = torch.from_numpy(maskx)*xbi[j].double()
+                                 tmp_b = torch.from_numpy(1-maskx)
+                                 xb += [tmp_b*10000 + tmp_a]
+
+                                 tmp_a = torch.from_numpy(masky)*ybi[j].double()
+                                 tmp_b = torch.from_numpy(1-masky)
+                                 yb += [tmp_b*10000 + tmp_a]
+                         soft = 1-getSoft([imgsize,imgsize],xb,yb)
+                         soft = torch.Tensor(soft).unsqueeze(0)
+                         mask = (torch.ones(mask.shape)-mask)*soft + mask
+                 elif self.opt.region_enm == 1:
+                     for i in range(4):
+                         cmask0 = cmasks0[i]
+                         rec = torch.zeros(B.shape)
+                         rec[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 1
+                         mask = mask * (torch.ones(B.shape) - cmask0 * rec)
+                 elif self.opt.region_enm == 2:
+                     for i in range(4):
+                         cmask0 = cmasks0[i]
+                         mask = mask * (torch.ones(B.shape) - cmask0)
+                 hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1
+                 hair_B = (B/2+0.5) * mask * 2 - 1
+                 item['hair_A'] = hair_A
+                 item['hair_B'] = hair_B
+                 item['mask'] = mask # mask out eyes, nose, mouth
+             if self.opt.bg_local:
+                 bgdir = self.opt.bg_dir
+                 bgpath = os.path.join(bgdir,basen[:-4]+'.png')
+                 im_bg = Image.open(bgpath)
+                 mask2 = transforms.ToTensor()(im_bg) # mask out background
+                 if flipped:
+                     mask2 = mask2.index_select(2, idx)
+                 mask2 = (mask2 >= 0.5).float()
+                 hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1
+                 hair_B = (B/2+0.5) * mask * mask2 * 2 - 1
+                 bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1
+                 bg_B = (B/2+0.5) * (torch.ones(mask2.shape)-mask2) * 2 - 1
+                 item['hair_A'] = hair_A
+                 item['hair_B'] = hair_B
+                 item['bg_A'] = bg_A
+                 item['bg_B'] = bg_B
+                 item['mask'] = mask
+                 item['mask2'] = mask2
+
+         if (self.opt.isTrain and self.opt.chamfer_loss):
+             if self.opt.which_direction == 'AtoB':
+                 img = tocv2(B)
+             else:
+                 img = tocv2(A)
+             dt1, dt2 = dt(img)
+             dt1 = torch.from_numpy(dt1)
+             dt2 = torch.from_numpy(dt2)
+             dt1 = dt1.unsqueeze(0)
+             dt2 = dt2.unsqueeze(0)
+             item['dt1gt'] = dt1
+             item['dt2gt'] = dt2
+
+         if self.opt.isTrain and self.opt.emphasis_conti_face:
+             face_mask_path = os.path.join(self.opt.facemask_dir,basen[:-4]+'.png')
+             face_mask = Image.open(face_mask_path)
+             face_mask = transforms.ToTensor()(face_mask) # [0,1]
+             if flipped:
+                 face_mask = face_mask.index_select(2, idx)
+             item['face_mask'] = face_mask
+
+         return item
+
+     def __len__(self):
+         return len(self.AB_paths)
+
+     def name(self):
+         return 'AlignedDataset'
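`getSoft` above produces the soft-border weights applied to the facial-region patches: a pixel's weight is (d+1)/(boundwidth+1), where d is its distance to the nearest listed boundary line, capped at 1. A small standalone check (assumed invocation, mirroring how `AlignedDataset` builds `xb`/`yb` for a rectangular region):

```python
import numpy as np
from data.aligned_dataset import getSoft

h, w = 8, 8
xb = [np.zeros(h), np.ones(h) * (w - 1)]   # left and right columns of the patch
yb = [np.zeros(w), np.ones(w) * (h - 1)]   # top and bottom rows of the patch
soft = getSoft([h, w], xb, yb, boundwidth=3.0)
print(soft[0])   # border row: 0.25 everywhere
print(soft[4])   # interior row: 0.25 0.5 0.75 1.0 1.0 0.75 0.5 0.25
```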
APDrawingGAN2/data/base_data_loader.py ADDED
@@ -0,0 +1,10 @@
+ class BaseDataLoader():
+     def __init__(self):
+         pass
+
+     def initialize(self, opt):
+         self.opt = opt
+         pass
+
+     def load_data():
+         return None
APDrawingGAN2/data/base_dataset.py ADDED
@@ -0,0 +1,103 @@
+ import torch.utils.data as data
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+
+ class BaseDataset(data.Dataset):
+     def __init__(self):
+         super(BaseDataset, self).__init__()
+
+     def name(self):
+         return 'BaseDataset'
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         return parser
+
+     def initialize(self, opt):
+         pass
+
+     def __len__(self):
+         return 0
+
+
+ def get_transform(opt):
+     transform_list = []
+     if opt.resize_or_crop == 'resize_and_crop':
+         osize = [opt.loadSize, opt.fineSize]
+         transform_list.append(transforms.Resize(osize, Image.BICUBIC))
+         transform_list.append(transforms.RandomCrop(opt.fineSize))
+     elif opt.resize_or_crop == 'crop':
+         transform_list.append(transforms.RandomCrop(opt.fineSize))
+     elif opt.resize_or_crop == 'scale_width':
+         transform_list.append(transforms.Lambda(
+             lambda img: __scale_width(img, opt.fineSize)))
+     elif opt.resize_or_crop == 'scale_width_and_crop':
+         transform_list.append(transforms.Lambda(
+             lambda img: __scale_width(img, opt.loadSize)))
+         transform_list.append(transforms.RandomCrop(opt.fineSize))
+     elif opt.resize_or_crop == 'none':
+         transform_list.append(transforms.Lambda(
+             lambda img: __adjust(img)))
+     else:
+         raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
+
+     if opt.isTrain and not opt.no_flip:
+         transform_list.append(transforms.RandomHorizontalFlip())
+
+     transform_list += [transforms.ToTensor(),
+                        transforms.Normalize((0.5, 0.5, 0.5),
+                                             (0.5, 0.5, 0.5))]
+     return transforms.Compose(transform_list)
+
+ # just modify the width and height to be multiple of 4
+ def __adjust(img):
+     ow, oh = img.size
+
+     # the size needs to be a multiple of this number,
+     # because going through generator network may change img size
+     # and eventually cause size mismatch error
+     mult = 4
+     if ow % mult == 0 and oh % mult == 0:
+         return img
+     w = (ow - 1) // mult
+     w = (w + 1) * mult
+     h = (oh - 1) // mult
+     h = (h + 1) * mult
+
+     if ow != w or oh != h:
+         __print_size_warning(ow, oh, w, h)
+
+     return img.resize((w, h), Image.BICUBIC)
+
+
+ def __scale_width(img, target_width):
+     ow, oh = img.size
+
+     # the size needs to be a multiple of this number,
+     # because going through generator network may change img size
+     # and eventually cause size mismatch error
+     mult = 4
+     assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
+     if (ow == target_width and oh % mult == 0):
+         return img
+     w = target_width
+     target_height = int(target_width * oh / ow)
+     m = (target_height - 1) // mult
+     h = (m + 1) * mult
+
+     if target_height != h:
+         __print_size_warning(target_width, target_height, w, h)
+
+     return img.resize((w, h), Image.BICUBIC)
+
+
+ def __print_size_warning(ow, oh, w, h):
+     if not hasattr(__print_size_warning, 'has_printed'):
+         print("The image size needs to be a multiple of 4. "
+               "The loaded image size was (%d, %d), so it was adjusted to "
+               "(%d, %d). This adjustment will be done to all images "
+               "whose sizes are not multiples of 4" % (ow, oh, w, h))
+         __print_size_warning.has_printed = True
+
+
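The `none` branch of `get_transform` relies on `__adjust`, which rounds each side up to the next multiple of 4 so the image keeps a compatible size through the generator's down/up-sampling. A quick standalone check of that rounding rule (the 500×333 size is just a made-up example):

```python
# same arithmetic as __adjust above
mult = 4
ow, oh = 500, 333                   # hypothetical input size
w = ((ow - 1) // mult + 1) * mult   # 500 stays 500 (already a multiple of 4)
h = ((oh - 1) // mult + 1) * mult   # 333 is rounded up to 336
print(w, h)                         # -> 500 336
```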
APDrawingGAN2/data/image_folder.py ADDED
@@ -0,0 +1,68 @@
+ ###############################################################################
+ # Code from
+ # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+ # Modified the original code so that it also loads images from the current
+ # directory as well as the subdirectories
+ ###############################################################################
+
+ import torch.utils.data as data
+
+ from PIL import Image
+ import os
+ import os.path
+
+ IMG_EXTENSIONS = [
+     '.jpg', '.JPG', '.jpeg', '.JPEG',
+     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ ]
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def make_dataset(dir):
+     images = []
+     assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+     for root, _, fnames in sorted(os.walk(dir)):
+         for fname in fnames:
+             if is_image_file(fname):
+                 path = os.path.join(root, fname)
+                 images.append(path)
+
+     return images
+
+
+ def default_loader(path):
+     return Image.open(path).convert('RGB')
+
+
+ class ImageFolder(data.Dataset):
+
+     def __init__(self, root, transform=None, return_paths=False,
+                  loader=default_loader):
+         imgs = make_dataset(root)
+         if len(imgs) == 0:
+             raise(RuntimeError("Found 0 images in: " + root + "\n"
+                                "Supported image extensions are: " +
+                                ",".join(IMG_EXTENSIONS)))
+
+         self.root = root
+         self.imgs = imgs
+         self.transform = transform
+         self.return_paths = return_paths
+         self.loader = loader
+
+     def __getitem__(self, index):
+         path = self.imgs[index]
+         img = self.loader(path)
+         if self.transform is not None:
+             img = self.transform(img)
+         if self.return_paths:
+             return img, path
+         else:
+             return img
+
+     def __len__(self):
+         return len(self.imgs)
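A minimal usage sketch of `ImageFolder` (not part of the commit; the directory path is hypothetical):

```python
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

dataset = ImageFolder('datasets/test_single',          # any folder containing jpg/png images
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]    # a CxHxW tensor in [0,1] plus the path it was loaded from
print(len(dataset), tuple(img.shape), path)
```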
APDrawingGAN2/data/single_dataset.py ADDED
@@ -0,0 +1,176 @@
+ import os.path
+ from data.base_dataset import BaseDataset, get_transform
+ from data.image_folder import make_dataset
+ from PIL import Image
+ import numpy as np
+ import csv
+ import torch
+ import torchvision.transforms as transforms
+
+ def getfeats(featpath):
+     trans_points = np.empty([5,2],dtype=np.int64)
+     with open(featpath, 'r') as csvfile:
+         reader = csv.reader(csvfile, delimiter=' ')
+         for ind,row in enumerate(reader):
+             trans_points[ind,:] = row
+     return trans_points
+
+ def getSoft(size,xb,yb,boundwidth=5.0):
+     xarray = np.tile(np.arange(0,size[1]),(size[0],1))
+     yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose()
+     cxdists = []
+     cydists = []
+     for i in range(len(xb)):
+         xba = np.tile(xb[i],(size[1],1)).transpose()
+         yba = np.tile(yb[i],(size[0],1))
+         cxdists.append(np.abs(xarray-xba))
+         cydists.append(np.abs(yarray-yba))
+     xdist = np.minimum.reduce(cxdists)
+     ydist = np.minimum.reduce(cydists)
+     manhdist = np.minimum.reduce([xdist,ydist])
+     im = (manhdist+1) / (boundwidth+1) * 1.0
+     im[im>=1.0] = 1.0
+     return im
+
+ class SingleDataset(BaseDataset):
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         return parser
+
+     def initialize(self, opt):
+         self.opt = opt
+         self.root = opt.dataroot
+         self.dir_A = os.path.join(opt.dataroot)
+         imglist = 'datasets/apdrawing_list/%s/%s.txt' % (opt.phase, opt.dataroot)
+         if os.path.exists(imglist):
+             lines = open(imglist, 'r').read().splitlines()
+             self.A_paths = sorted(lines)
+         else:
+             self.A_paths = make_dataset(self.dir_A)
+             self.A_paths = sorted(self.A_paths)
+         self.transform = get_transform(opt)  # get_transform respects --no_flip; AlignedDataset does not use it and applies its transforms manually
+
+     def __getitem__(self, index):
+         A_path = self.A_paths[index]
+         A_img = Image.open(A_path).convert('RGB')
+         A = self.transform(A_img)
+         if self.opt.which_direction == 'BtoA':
+             input_nc = self.opt.output_nc
+             output_nc = self.opt.input_nc
+         else:
+             input_nc = self.opt.input_nc
+             output_nc = self.opt.output_nc
+
+         if input_nc == 1:  # RGB to gray
+             tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+             A = tmp.unsqueeze(0)
+
+         item = {'A': A, 'A_paths': A_path}
+
+         if self.opt.use_local:
+             regions = ['eyel','eyer','nose','mouth']
+             basen = os.path.basename(A_path)[:-4]+'.txt'
+             featdir = self.opt.lm_dir
+             featpath = os.path.join(featdir,basen)
+             feats = getfeats(featpath)
+             mouth_x = int((feats[3,0]+feats[4,0])/2.0)
+             mouth_y = int((feats[3,1]+feats[4,1])/2.0)
+             ratio = self.opt.fineSize / 256
+             EYE_H = self.opt.EYE_H * ratio
+             EYE_W = self.opt.EYE_W * ratio
+             NOSE_H = self.opt.NOSE_H * ratio
+             NOSE_W = self.opt.NOSE_W * ratio
+             MOUTH_H = self.opt.MOUTH_H * ratio
+             MOUTH_W = self.opt.MOUTH_W * ratio
+             center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]])
+             item['center'] = center
+             rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)]
+             rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)]
+             if self.opt.soft_border:
+                 soft_border_mask4 = []
+                 for i in range(4):
+                     xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)]
+                     yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)]
+                     soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb)
+                     soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0))
+                     item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i]
+             for i in range(4):
+                 item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
+                 if self.opt.soft_border:
+                     item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1)
+             if self.opt.compactmask:
+                 cmasks0 = []
+                 cmasks = []
+                 for i in range(4):
+                     cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png')
+                     im_cmask = Image.open(cmaskpath)
+                     cmask0 = transforms.ToTensor()(im_cmask)
+                     if output_nc == 1 and cmask0.shape[0] == 3:
+                         tmp = cmask0[0, ...] * 0.299 + cmask0[1, ...] * 0.587 + cmask0[2, ...] * 0.114
+                         cmask0 = tmp.unsqueeze(0)
+                     cmask0 = (cmask0 >= 0.5).float()
+                     cmasks0.append(cmask0)
+                     cmask = cmask0.clone()
+                     cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
+                     cmasks.append(cmask)
+                 item['cmaskel'] = cmasks[0]
+                 item['cmasker'] = cmasks[1]
+                 item['cmask'] = cmasks[2]
+                 item['cmaskmo'] = cmasks[3]
+             if self.opt.hair_local:
+                 output_nc = self.opt.output_nc
+                 mask = torch.ones([output_nc,A.shape[1],A.shape[2]])
+                 for i in range(4):
+                     mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 0
+                 if self.opt.soft_border:
+                     imgsize = self.opt.fineSize
+                     maskn = mask[0].numpy()
+                     masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])]
+                     masks[0][1:] = maskn[:-1]
+                     masks[1][:-1] = maskn[1:]
+                     masks[2][:,1:] = maskn[:,:-1]
+                     masks[3][:,:-1] = maskn[:,1:]
+                     masks2 = [maskn-e for e in masks]
+                     bound = np.minimum.reduce(masks2)
+                     bound = -bound
+                     xb = []
+                     yb = []
+                     for i in range(4):
+                         xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1]
+                         ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1]
+                         for j in range(2):
+                             maskx = bound[:,xbi[j]]
+                             masky = bound[ybi[j],:]
+                             tmp_a = torch.from_numpy(maskx)*xbi[j].double()
+                             tmp_b = torch.from_numpy(1-maskx)
+                             xb += [tmp_b*10000 + tmp_a]
+
+                             tmp_a = torch.from_numpy(masky)*ybi[j].double()
+                             tmp_b = torch.from_numpy(1-masky)
+                             yb += [tmp_b*10000 + tmp_a]
+                     soft = 1-getSoft([imgsize,imgsize],xb,yb)
+                     soft = torch.Tensor(soft).unsqueeze(0)
+                     mask = (torch.ones(mask.shape)-mask)*soft + mask
+                 hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1
+                 item['hair_A'] = hair_A
+                 item['mask'] = mask
+             if self.opt.bg_local:
+                 bgdir = self.opt.bg_dir
+                 bgpath = os.path.join(bgdir,basen[:-4]+'.png')
+                 im_bg = Image.open(bgpath)
+                 mask2 = transforms.ToTensor()(im_bg) # mask out background
+                 mask2 = (mask2 >= 0.5).float()
+                 hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1
+                 bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1
+                 item['hair_A'] = hair_A
+                 item['bg_A'] = bg_A
+                 item['mask'] = mask
+                 item['mask2'] = mask2
+
+         return item
+
+     def __len__(self):
+         return len(self.A_paths)
+
+     def name(self):
+         return 'SingleImageDataset'
APDrawingGAN2/docs/tips.md ADDED
@@ -0,0 +1,8 @@
+ ## Training/test Tips
+ - Flags: see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. The default values of these options are sometimes adjusted in the model files.
+
+ - CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batch_size 32`) to benefit from multiple GPUs.
+
+ - Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.
+
+ - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count.
APDrawingGAN2/models/__init__.py ADDED
@@ -0,0 +1,39 @@
+ import importlib
+ from models.base_model import BaseModel
+
+
+ def find_model_using_name(model_name):
+     # Given the option --model [modelname],
+     # the file "models/modelname_model.py"
+     # will be imported.
+     model_filename = "models." + model_name + "_model"
+     modellib = importlib.import_module(model_filename)
+
+     # In the file, the class called ModelNameModel() will
+     # be instantiated. It has to be a subclass of BaseModel,
+     # and it is case-insensitive.
+     model = None
+     target_model_name = model_name.replace('_', '') + 'model'
+     for name, cls in modellib.__dict__.items():
+         if name.lower() == target_model_name.lower() \
+            and issubclass(cls, BaseModel):
+             model = cls
+
+     if model is None:
+         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+         exit(0)
+
+     return model
+
+
+ def get_option_setter(model_name):
+     model_class = find_model_using_name(model_name)
+     return model_class.modify_commandline_options
+
+
+ def create_model(opt):
+     model = find_model_using_name(opt.model)
+     instance = model()
+     instance.initialize(opt)
+     print("model [%s] was created" % (instance.name()))
+     return instance
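Both this module and `data/__init__.py` rely on the same naming convention: the value of `--model` (or `--dataset_mode`) selects the module `models/<name>_model.py` (or `data/<name>_dataset.py`), and the class whose lower-cased name equals the option value with underscores removed is the one instantiated. A small standalone illustration using the `apdrawingpp_style` model added in this commit:

```python
name = 'apdrawingpp_style'                  # value passed as --model
module = 'models.' + name + '_model'        # -> models.apdrawingpp_style_model
target = name.replace('_', '') + 'model'    # -> 'apdrawingppstylemodel'
# 'APDrawingPPStyleModel'.lower() equals target, so that class is the one
# create_model(opt) will instantiate.
print(module, target)
```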
APDrawingGAN2/models/apdrawingpp_style_model.py ADDED
@@ -0,0 +1,692 @@
1
+ import torch
2
+ from util.image_pool import ImagePool
3
+ from .base_model import BaseModel
4
+ from . import networks
5
+ import os
6
+ import math
7
+
8
+ W = 11
9
+ aa = int(math.floor(512./W))
10
+ res = 512 - W*aa
11
+
12
+
13
+ def padpart(A,part,centers,opt,device):
14
+ IMAGE_SIZE = opt.fineSize
15
+ bs,nc,_,_ = A.shape
16
+ ratio = IMAGE_SIZE / 256
17
+ NOSE_W = opt.NOSE_W * ratio
18
+ NOSE_H = opt.NOSE_H * ratio
19
+ EYE_W = opt.EYE_W * ratio
20
+ EYE_H = opt.EYE_H * ratio
21
+ MOUTH_W = opt.MOUTH_W * ratio
22
+ MOUTH_H = opt.MOUTH_H * ratio
23
+ A_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(device)
24
+ padvalue = -1 # black
25
+ for i in range(bs):
26
+ center = centers[i]
27
+ if part == 'nose':
28
+ A_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(A[i])
29
+ elif part == 'eyel':
30
+ A_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(A[i])
31
+ elif part == 'eyer':
32
+ A_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(A[i])
33
+ elif part == 'mouth':
34
+ A_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(A[i])
35
+ return A_p
36
+
37
+ import numpy as np
38
+ def nonlinearDt(dt,type='atan',xmax=torch.Tensor([10.0])):#dt in [0,1], first multiply xmax(>1), then remap to [0,1]
39
+ if type == 'atan':
40
+ nldt = torch.atan(dt*xmax) / torch.atan(xmax)
41
+ elif type == 'sigmoid':
42
+ nldt = (torch.sigmoid(dt*xmax)-0.5) / (torch.sigmoid(xmax)-0.5)
43
+ elif type == 'tanh':
44
+ nldt = torch.tanh(dt*xmax) / torch.tanh(xmax)
45
+ elif type == 'pow':
46
+ nldt = torch.pow(dt*xmax,2) / torch.pow(xmax,2)
47
+ elif type == 'exp':
48
+ if xmax.item()>1:
49
+ xmax = xmax / 3
50
+ nldt = (torch.exp(dt*xmax)-1) / (torch.exp(xmax)-1)
51
+ #print("remap dt:", type, xmax.item())
52
+ return nldt
53
+
54
+ class APDrawingPPStyleModel(BaseModel):
55
+ def name(self):
56
+ return 'APDrawingPPStyleModel'
57
+
58
+ @staticmethod
59
+ def modify_commandline_options(parser, is_train=True):
60
+
61
+ # changing the default values to match the pix2pix paper
62
+ # (https://phillipi.github.io/pix2pix/)
63
+ parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False
64
+ parser.set_defaults(dataset_mode='aligned')
65
+ parser.set_defaults(auxiliary_root='auxiliaryeye2o')
66
+ parser.set_defaults(use_local=True, hair_local=True, bg_local=True)
67
+ parser.set_defaults(discriminator_local=True, gan_loss_strategy=2)
68
+ parser.set_defaults(chamfer_loss=True, dt_nonlinear='exp', lambda_chamfer=0.35, lambda_chamfer2=0.35)
69
+ parser.set_defaults(nose_ae=True, others_ae=True, compactmask=True, MOUTH_H=56)
70
+ parser.set_defaults(soft_border=1, batch_size=1, save_epoch_freq=25)
71
+ parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier')
72
+ parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator')
73
+ parser.add_argument('--regarch', type=int, default=4, help='architecture for netRegressor')
74
+ if is_train:
75
+ parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
76
+ parser.add_argument('--lambda_local', type=float, default=25.0, help='weight for Local loss')
77
+ parser.set_defaults(netG_dt='unet_512')
78
+ parser.set_defaults(netG_line='unet_512')
79
+
80
+ return parser
81
+
82
+ def initialize(self, opt):
83
+ BaseModel.initialize(self, opt)
84
+ self.isTrain = opt.isTrain
85
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses
86
+ self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
87
+ if self.isTrain and self.opt.no_l1_loss:
88
+ self.loss_names = ['G_GAN', 'D_real', 'D_fake']
89
+ if self.isTrain and self.opt.use_local and not self.opt.no_G_local_loss:
90
+ self.loss_names.append('G_local')
91
+ self.loss_names.append('G_hair_local')
92
+ self.loss_names.append('G_bg_local')
93
+ if self.isTrain and self.opt.discriminator_local:
94
+ self.loss_names.append('D_real_local')
95
+ self.loss_names.append('D_fake_local')
96
+ self.loss_names.append('G_GAN_local')
97
+ if self.isTrain and self.opt.chamfer_loss:
98
+ self.loss_names.append('G_chamfer')
99
+ self.loss_names.append('G_chamfer2')
100
+ if self.isTrain and self.opt.continuity_loss:
101
+ self.loss_names.append('G_continuity')
102
+ self.loss_names.append('G')
103
+ print('loss_names', self.loss_names)
104
+ # specify the images you want to save/display. The program will call base_model.get_current_visuals
105
+ self.visual_names = ['real_A', 'fake_B', 'real_B']
106
+ if self.opt.use_local:
107
+ self.visual_names += ['fake_B0', 'fake_B1']
108
+ self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair']
109
+ self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg']
110
+ if self.opt.region_enm in [0,1]:
111
+ if self.opt.nose_ae:
112
+ self.visual_names += ['fake_B_nose_v','fake_B_nose_v1','fake_B_nose_v2','cmask1no']
113
+ if self.opt.others_ae:
114
+ self.visual_names += ['fake_B_eyel_v','fake_B_eyel_v1','fake_B_eyel_v2','cmask1el']
115
+ self.visual_names += ['fake_B_eyer_v','fake_B_eyer_v1','fake_B_eyer_v2','cmask1er']
116
+ self.visual_names += ['fake_B_mouth_v','fake_B_mouth_v1','fake_B_mouth_v2','cmask1mo']
117
+ elif self.opt.region_enm in [2]:
118
+ self.visual_names += ['fake_B_nose','fake_B_eyel','fake_B_eyer','fake_B_mouth']
119
+ if self.isTrain and self.opt.chamfer_loss:
120
+ self.visual_names += ['dt1', 'dt2']
121
+ self.visual_names += ['dt1gt', 'dt2gt']
122
+ if self.isTrain and self.opt.soft_border:
123
+ self.visual_names += ['mask']
124
+ if not self.isTrain and self.opt.save2:
125
+ self.visual_names = ['real_A', 'fake_B']
126
+ print('visuals', self.visual_names)
127
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
128
+ self.auxiliary_model_names = []
129
+ if self.isTrain:
130
+ self.model_names = ['G', 'D']
131
+ if self.opt.discriminator_local:
132
+ self.model_names += ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
133
+ # auxiliary nets for loss calculation
134
+ if self.opt.chamfer_loss:
135
+ self.auxiliary_model_names += ['DT1', 'DT2']
136
+ self.auxiliary_model_names += ['Line1', 'Line2']
137
+ if self.opt.continuity_loss:
138
+ self.auxiliary_model_names += ['Regressor']
139
+ else: # during test time, only load Gs
140
+ self.model_names = ['G']
141
+ if self.opt.test_continuity_loss:
142
+ self.auxiliary_model_names += ['Regressor']
143
+ if self.opt.use_local:
144
+ self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine']
145
+ self.auxiliary_model_names += ['CLm','CLh']
146
+ # auxiliary nets for local output refinement
147
+ if self.opt.nose_ae:
148
+ self.auxiliary_model_names += ['AE']
149
+ if self.opt.others_ae:
150
+ self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack']
151
+ print('model_names', self.model_names)
152
+ print('auxiliary_model_names', self.auxiliary_model_names)
153
+ # load/define networks
154
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
155
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
156
+ opt.nnG)
157
+ print('netG', opt.netG)
158
+
159
+ if self.isTrain:
160
+ use_sigmoid = opt.no_lsgan
161
+ self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
162
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
163
+ print('netD', opt.netD, opt.n_layers_D)
164
+ if self.opt.discriminator_local:
165
+ self.netDLEyel = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
166
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
167
+ self.netDLEyer = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
168
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
169
+ self.netDLNose = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
170
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
171
+ self.netDLMouth = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
172
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
173
+ self.netDLHair = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
174
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
175
+ self.netDLBG = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
176
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
177
+
178
+
179
+ if self.opt.use_local:
180
+ netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
181
+ netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
182
+ netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
183
+ self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
184
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
185
+ self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
186
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
187
+ self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
188
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
189
+ self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
190
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
191
+ self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
192
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
193
+ extra_channel=3)
194
+ self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
195
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
196
+ # by default combiner_type is combiner, which uses resnet
197
+ print('combiner_type', self.opt.combiner_type)
198
+ self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm,
199
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2)
200
+ # auxiliary classifiers for mouth and hair
201
+ ratio = self.opt.fineSize / 256
202
+ self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
203
+ self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
204
+ self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
205
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
206
+ nnG = 3, ae_h = self.MOUTH_H, ae_w = self.MOUTH_W)
207
+ self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
208
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
209
+ nnG = opt.nnG_hairc, ae_h = opt.fineSize, ae_w = opt.fineSize)
210
+
211
+
212
+ if self.isTrain:
213
+ self.fake_AB_pool = ImagePool(opt.pool_size)
214
+ # define loss functions
215
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
216
+ self.criterionL1 = torch.nn.L1Loss()
217
+
218
+ # initialize optimizers
219
+ self.optimizers = []
220
+ if not self.opt.use_local:
221
+ print('G_params 1 components')
222
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
223
+ lr=opt.lr, betas=(opt.beta1, 0.999))
224
+ else:
225
+ G_params = list(self.netG.parameters()) + list(self.netGLEyel.parameters()) + list(self.netGLEyer.parameters()) + list(self.netGLNose.parameters()) + list(self.netGLMouth.parameters()) + list(self.netGCombine.parameters()) + list(self.netGLHair.parameters()) + list(self.netGLBG.parameters())
226
+ print('G_params 8 components')
227
+ self.optimizer_G = torch.optim.Adam(G_params,
228
+ lr=opt.lr, betas=(opt.beta1, 0.999))
229
+
230
+ if not self.opt.discriminator_local:
231
+ print('D_params 1 components')
232
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
233
+ lr=opt.lr, betas=(opt.beta1, 0.999))
234
+ else:#self.opt.discriminator_local == True
235
+ D_params = list(self.netD.parameters()) + list(self.netDLEyel.parameters()) +list(self.netDLEyer.parameters()) + list(self.netDLNose.parameters()) + list(self.netDLMouth.parameters()) + list(self.netDLHair.parameters()) + list(self.netDLBG.parameters())
236
+ print('D_params 7 components')
237
+ self.optimizer_D = torch.optim.Adam(D_params,
238
+ lr=opt.lr, betas=(opt.beta1, 0.999))
239
+ self.optimizers.append(self.optimizer_G)
240
+ self.optimizers.append(self.optimizer_D)
241
+
242
+ # ==================================auxiliary nets (loaded, parameters fixed)=============================
243
+ if self.opt.use_local and self.opt.nose_ae:
244
+ ratio = self.opt.fineSize / 256
245
+ NOSE_H = self.opt.NOSE_H * ratio
246
+ NOSE_W = self.opt.NOSE_W * ratio
247
+ self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
248
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
249
+ latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W)
250
+ self.set_requires_grad(self.netAE, False)
251
+ if self.opt.use_local and self.opt.others_ae:
252
+ ratio = self.opt.fineSize / 256
253
+ EYE_H = self.opt.EYE_H * ratio
254
+ EYE_W = self.opt.EYE_W * ratio
255
+ MOUTH_H = self.opt.MOUTH_H * ratio
256
+ MOUTH_W = self.opt.MOUTH_W * ratio
257
+ self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
258
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
259
+ latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
260
+ self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
261
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
262
+ latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
263
+ self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
264
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
265
+ latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
266
+ self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
267
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
268
+ latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
269
+ self.set_requires_grad(self.netAEel, False)
270
+ self.set_requires_grad(self.netAEer, False)
271
+ self.set_requires_grad(self.netAEmowhite, False)
272
+ self.set_requires_grad(self.netAEmoblack, False)
273
+
274
+
275
+ if self.isTrain and self.opt.continuity_loss:
276
+ self.nc = 1
277
+ self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
278
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p,
279
+ nnG = opt.regarch)
280
+ self.set_requires_grad(self.netRegressor, False)
281
+
282
+ if self.isTrain and self.opt.chamfer_loss:
283
+ self.nc = 1
284
+ self.netDT1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
285
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
286
+ self.netDT2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
287
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
288
+ self.set_requires_grad(self.netDT1, False)
289
+ self.set_requires_grad(self.netDT2, False)
290
+ self.netLine1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
291
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
292
+ self.netLine2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
293
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
294
+ self.set_requires_grad(self.netLine1, False)
295
+ self.set_requires_grad(self.netLine2, False)
296
+
297
+ # ==================================for test (nets loaded, parameters fixed)=============================
298
+ if not self.isTrain and self.opt.test_continuity_loss:
299
+ self.nc = 1
300
+ self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
301
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
302
+ nnG = opt.regarch)
303
+ self.set_requires_grad(self.netRegressor, False)
304
+
305
+
306
+ def set_input(self, input):
307
+ AtoB = self.opt.which_direction == 'AtoB'
308
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
309
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
310
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
311
+ self.batch_size = len(self.image_paths)
312
+ if self.opt.use_local:
313
+ self.real_A_eyel = input['eyel_A'].to(self.device)
314
+ self.real_A_eyer = input['eyer_A'].to(self.device)
315
+ self.real_A_nose = input['nose_A'].to(self.device)
316
+ self.real_A_mouth = input['mouth_A'].to(self.device)
317
+ self.real_B_eyel = input['eyel_B'].to(self.device)
318
+ self.real_B_eyer = input['eyer_B'].to(self.device)
319
+ self.real_B_nose = input['nose_B'].to(self.device)
320
+ self.real_B_mouth = input['mouth_B'].to(self.device)
321
+ if self.opt.region_enm in [0,1]:
322
+ self.center = input['center']
323
+ if self.opt.soft_border:
324
+ self.softel = input['soft_eyel_mask'].to(self.device)
325
+ self.softer = input['soft_eyer_mask'].to(self.device)
326
+ self.softno = input['soft_nose_mask'].to(self.device)
327
+ self.softmo = input['soft_mouth_mask'].to(self.device)
328
+ if self.opt.compactmask:
329
+ self.cmask = input['cmask'].to(self.device)
330
+ self.cmask1 = self.cmask*2-1#[0,1]->[-1,1]
331
+ self.cmaskel = input['cmaskel'].to(self.device)
332
+ self.cmask1el = self.cmaskel*2-1
333
+ self.cmasker = input['cmasker'].to(self.device)
334
+ self.cmask1er = self.cmasker*2-1
335
+ self.cmaskmo = input['cmaskmo'].to(self.device)
336
+ self.cmask1mo = self.cmaskmo*2-1
337
+ self.real_A_hair = input['hair_A'].to(self.device)
338
+ self.real_B_hair = input['hair_B'].to(self.device)
339
+ self.mask = input['mask'].to(self.device) # mask for non-eyes,nose,mouth
340
+ self.mask2 = input['mask2'].to(self.device) # mask for non-bg
341
+ self.real_A_bg = input['bg_A'].to(self.device)
342
+ self.real_B_bg = input['bg_B'].to(self.device)
343
+ if (self.isTrain and self.opt.chamfer_loss):
344
+ self.dt1gt = input['dt1gt'].to(self.device)
345
+ self.dt2gt = input['dt2gt'].to(self.device)
346
+ if self.isTrain and self.opt.emphasis_conti_face:
347
+ self.face_mask = input['face_mask'].cuda(self.gpu_ids_p[0])
348
+
349
+ def getonehot(self,outputs,classes):
350
+ [maxv,index] = torch.max(outputs,1)
351
+ y = torch.unsqueeze(index,1)
352
+ onehot = torch.FloatTensor(self.batch_size,classes).to(self.device)
353
+ onehot.zero_()
354
+ onehot.scatter_(1,y,1)
355
+ return onehot
356
+
357
+ def forward(self):
358
+ if not self.opt.use_local:
359
+ self.fake_B = self.netG(self.real_A)
360
+ else:
361
+ self.fake_B0 = self.netG(self.real_A)
362
+ # EYES, MOUTH
363
+ outputs1 = self.netCLm(self.real_A_mouth)
364
+ onehot1 = self.getonehot(outputs1,2)
365
+
366
+ if not self.opt.others_ae:
367
+ fake_B_eyel = self.netGLEyel(self.real_A_eyel)
368
+ fake_B_eyer = self.netGLEyer(self.real_A_eyer)
369
+ fake_B_mouth = self.netGLMouth(self.real_A_mouth)
370
+ else: # use AE that only constains compact region, need cmask!
371
+ self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
372
+ self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
373
+ self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
374
+ self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1)
375
+ self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1)
376
+ # USE 2 AEs
377
+ self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,self.MOUTH_W).to(self.device)
378
+ for i in range(self.batch_size):
379
+ if onehot1[i][0] == 1:
380
+ self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
381
+ #print('AEmowhite')
382
+ elif onehot1[i][1] == 1:
383
+ self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
384
+ #print('AEmoblack')
385
+ fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel)
386
+ fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker)
387
+ fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo)
388
+ # NOSE
389
+ if not self.opt.nose_ae:
390
+ fake_B_nose = self.netGLNose(self.real_A_nose)
391
+ else: # use AE that only constains compact region, need cmask!
392
+ self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
393
+ self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1)
394
+ fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask)
395
+
396
+ # for visuals and later local loss
397
+ if self.opt.region_enm in [0,1]:
398
+ self.fake_B_nose = fake_B_nose
399
+ self.fake_B_eyel = fake_B_eyel
400
+ self.fake_B_eyer = fake_B_eyer
401
+ self.fake_B_mouth = fake_B_mouth
402
+ # for soft border of 4 rectangle facial feature
403
+ if self.opt.region_enm == 0 and self.opt.soft_border:
404
+ self.fake_B_nose = self.masked(fake_B_nose, self.softno)
405
+ self.fake_B_eyel = self.masked(fake_B_eyel, self.softel)
406
+ self.fake_B_eyer = self.masked(fake_B_eyer, self.softer)
407
+ self.fake_B_mouth = self.masked(fake_B_mouth, self.softmo)
408
+ elif self.opt.region_enm in [2]: # need to multiply cmask
409
+ self.fake_B_nose = self.masked(fake_B_nose,self.cmask)
410
+ self.fake_B_eyel = self.masked(fake_B_eyel,self.cmaskel)
411
+ self.fake_B_eyer = self.masked(fake_B_eyer,self.cmasker)
412
+ self.fake_B_mouth = self.masked(fake_B_mouth,self.cmaskmo)
413
+
414
+ # HAIR, BG AND PARTCOMBINE
415
+ outputs2 = self.netCLh(self.real_A_hair)
416
+ onehot2 = self.getonehot(outputs2,3)
417
+
418
+ if not self.isTrain:
419
+ opt = self.opt
420
+ if opt.imagefolder == 'images':
421
+ file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'styleonehot.txt')
422
+ else:
423
+ file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), opt.imagefolder, 'styleonehot.txt')
424
+ message = '%s [%d %d] [%d %d %d]' % (self.image_paths[0], onehot1[0][0], onehot1[0][1],
425
+ onehot2[0][0], onehot2[0][1], onehot2[0][2])
426
+ with open(file_name, 'a+') as s_file:
427
+ s_file.write(message)
428
+ s_file.write('\n')
429
+
430
+ fake_B_hair = self.netGLHair(self.real_A_hair,onehot2)
431
+ fake_B_bg = self.netGLBG(self.real_A_bg)
432
+ self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2)
433
+ self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2))
434
+ if not self.opt.compactmask:
435
+ self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op)
436
+ else:
437
+ self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op,self.opt.region_enm,self.cmaskel,self.cmasker,self.cmask,self.cmaskmo)
438
+
439
+ self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1))
440
+
441
+ # for AE visuals
442
+ if self.opt.region_enm in [0,1]:
443
+ if self.opt.nose_ae:
444
+ self.fake_B_nose_v = padpart(self.fake_B_nose, 'nose', self.center, self.opt, self.device)
445
+ self.fake_B_nose_v1 = padpart(self.fake_B_nose1, 'nose', self.center, self.opt, self.device)
446
+ self.fake_B_nose_v2 = padpart(self.fake_B_nose2, 'nose', self.center, self.opt, self.device)
447
+ self.cmask1no = padpart(self.cmask1, 'nose', self.center, self.opt, self.device)
448
+ if self.opt.others_ae:
449
+ self.fake_B_eyel_v = padpart(self.fake_B_eyel, 'eyel', self.center, self.opt, self.device)
450
+ self.fake_B_eyel_v1 = padpart(self.fake_B_eyel1, 'eyel', self.center, self.opt, self.device)
451
+ self.fake_B_eyel_v2 = padpart(self.fake_B_eyel2, 'eyel', self.center, self.opt, self.device)
452
+ self.cmask1el = padpart(self.cmask1el, 'eyel', self.center, self.opt, self.device)
453
+ self.fake_B_eyer_v = padpart(self.fake_B_eyer, 'eyer', self.center, self.opt, self.device)
454
+ self.fake_B_eyer_v1 = padpart(self.fake_B_eyer1, 'eyer', self.center, self.opt, self.device)
455
+ self.fake_B_eyer_v2 = padpart(self.fake_B_eyer2, 'eyer', self.center, self.opt, self.device)
456
+ self.cmask1er = padpart(self.cmask1er, 'eyer', self.center, self.opt, self.device)
457
+ self.fake_B_mouth_v = padpart(self.fake_B_mouth, 'mouth', self.center, self.opt, self.device)
458
+ self.fake_B_mouth_v1 = padpart(self.fake_B_mouth1, 'mouth', self.center, self.opt, self.device)
459
+ self.fake_B_mouth_v2 = padpart(self.fake_B_mouth2, 'mouth', self.center, self.opt, self.device)
460
+ self.cmask1mo = padpart(self.cmask1mo, 'mouth', self.center, self.opt, self.device)
461
+
462
+ if not self.isTrain and self.opt.test_continuity_loss:
463
+ self.ContinuityForTest(real=1)
464
+
465
+
466
+ def backward_D(self):
467
+ # Fake
468
+ # stop backprop to the generator by detaching fake_B
469
+ fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
470
+ #print('fake_AB', fake_AB.shape) # (1,4,512,512)
471
+ pred_fake = self.netD(fake_AB.detach()) # detached so D's loss does not affect G's gradients
472
+ self.loss_D_fake = self.criterionGAN(pred_fake, False)
473
+ if self.opt.discriminator_local:
474
+ fake_AB_parts = self.getLocalParts(fake_AB)
475
+ local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
476
+ self.loss_D_fake_local = 0
477
+ for i in range(len(fake_AB_parts)):
478
+ net = getattr(self, 'net' + local_names[i])
479
+ pred_fake_tmp = net(fake_AB_parts[i].detach())
480
+ addw = self.getaddw(local_names[i])
481
+ self.loss_D_fake_local = self.loss_D_fake_local + self.criterionGAN(pred_fake_tmp, False) * addw
482
+ self.loss_D_fake = self.loss_D_fake + self.loss_D_fake_local
483
+
484
+ # Real
485
+ real_AB = torch.cat((self.real_A, self.real_B), 1)
486
+ pred_real = self.netD(real_AB)
487
+ self.loss_D_real = self.criterionGAN(pred_real, True)
488
+ if self.opt.discriminator_local:
489
+ real_AB_parts = self.getLocalParts(real_AB)
490
+ local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
491
+ self.loss_D_real_local = 0
492
+ for i in range(len(real_AB_parts)):
493
+ net = getattr(self, 'net' + local_names[i])
494
+ pred_real_tmp = net(real_AB_parts[i])
495
+ addw = self.getaddw(local_names[i])
496
+ self.loss_D_real_local = self.loss_D_real_local + self.criterionGAN(pred_real_tmp, True) * addw
497
+ self.loss_D_real = self.loss_D_real + self.loss_D_real_local
498
+
499
+ # Combined loss
500
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
501
+
502
+ self.loss_D.backward()
503
+
504
+ def backward_G(self):
505
+ # First, G(A) should fake the discriminator
506
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1)
507
+ pred_fake = self.netD(fake_AB) # (1,4,512,512)->(1,1,30,30)
508
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True)
509
+ if self.opt.discriminator_local:
510
+ fake_AB_parts = self.getLocalParts(fake_AB)
511
+ local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
512
+ self.loss_G_GAN_local = 0 # G_GAN_local is then added into G_GAN
513
+ for i in range(len(fake_AB_parts)):
514
+ net = getattr(self, 'net' + local_names[i])
515
+ pred_fake_tmp = net(fake_AB_parts[i])
516
+ addw = self.getaddw(local_names[i])
517
+ self.loss_G_GAN_local = self.loss_G_GAN_local + self.criterionGAN(pred_fake_tmp, True) * addw
518
+ if self.opt.gan_loss_strategy == 1:
519
+ self.loss_G_GAN = (self.loss_G_GAN + self.loss_G_GAN_local) / (len(fake_AB_parts) + 1)
520
+ elif self.opt.gan_loss_strategy == 2:
521
+ self.loss_G_GAN_local = self.loss_G_GAN_local * 0.25
522
+ self.loss_G_GAN = self.loss_G_GAN + self.loss_G_GAN_local
523
+
524
+ # Second, G(A) = B
525
+ if not self.opt.no_l1_loss:
526
+ self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
527
+
528
+ if self.opt.use_local and not self.opt.no_G_local_loss:
529
+ local_names = ['eyel','eyer','nose','mouth']
530
+ self.loss_G_local = 0
531
+ for i in range(len(local_names)):
532
+ fakeblocal = getattr(self, 'fake_B_' + local_names[i])
533
+ realblocal = getattr(self, 'real_B_' + local_names[i])
534
+ addw = self.getaddw(local_names[i])
535
+ self.loss_G_local = self.loss_G_local + self.criterionL1(fakeblocal,realblocal) * self.opt.lambda_local * addw
536
+ self.loss_G_hair_local = self.criterionL1(self.fake_B_hair, self.real_B_hair) * self.opt.lambda_local * self.opt.addw_hair
537
+ self.loss_G_bg_local = self.criterionL1(self.fake_B_bg, self.real_B_bg) * self.opt.lambda_local * self.opt.addw_bg
538
+
539
+ # Third, chamfer matching (assumes chamfer_2way and chamfer_only_line are true)
540
+ if self.opt.chamfer_loss:
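+ # convert drawings to grayscale with ITU-R BT.601 luma weights before the distance-transform losses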
541
+ if self.fake_B.shape[1] == 3:
542
+ tmp = self.fake_B[:,0,...]*0.299+self.fake_B[:,1,...]*0.587+self.fake_B[:,2,...]*0.114
543
+ fake_B_gray = tmp.unsqueeze(1)
544
+ else:
545
+ fake_B_gray = self.fake_B
546
+ if self.real_B.shape[1] == 3:
547
+ tmp = self.real_B[:,0,...]*0.299+self.real_B[:,1,...]*0.587+self.real_B[:,2,...]*0.114
548
+ real_B_gray = tmp.unsqueeze(1)
549
+ else:
550
+ real_B_gray = self.real_B
551
+
552
+ gpu_p = self.opt.gpu_ids_p[0]
553
+ gpu = self.opt.gpu_ids[0]
554
+ if gpu_p != gpu:
555
+ fake_B_gray = fake_B_gray.cuda(gpu_p)
556
+ real_B_gray = real_B_gray.cuda(gpu_p)
557
+
558
+ # d_CM(a_i,G(p_i))
559
+ self.dt1 = self.netDT1(fake_B_gray)
560
+ self.dt2 = self.netDT2(fake_B_gray)
561
+ dt1 = self.dt1/2.0+0.5#[-1,1]->[0,1]
562
+ dt2 = self.dt2/2.0+0.5
563
+ if self.opt.dt_nonlinear != '':
564
+ dt_xmax = torch.Tensor([self.opt.dt_xmax]).cuda(gpu_p)
565
+ dt1 = nonlinearDt(dt1, self.opt.dt_nonlinear, dt_xmax)
566
+ dt2 = nonlinearDt(dt2, self.opt.dt_nonlinear, dt_xmax)
567
+ #print('dt1dt2',torch.min(dt1).item(),torch.max(dt1).item(),torch.min(dt2).item(),torch.max(dt2).item())
568
+
569
+ bs = real_B_gray.shape[0]
570
+ real_B_gray_line1 = self.netLine1(real_B_gray)
571
+ real_B_gray_line2 = self.netLine2(real_B_gray)
572
+ self.loss_G_chamfer = (dt1[(real_B_gray<0)&(real_B_gray_line1<0)].sum() + dt2[(real_B_gray>=0)&(real_B_gray_line2>=0)].sum()) / bs * self.opt.lambda_chamfer
573
+ if gpu_p != gpu:
574
+ self.loss_G_chamfer = self.loss_G_chamfer.cuda(gpu)
575
+
576
+ # d_CM(G(p_i),a_i)
577
+ if gpu_p != gpu:
578
+ dt1gt = self.dt1gt.cuda(gpu_p)
579
+ dt2gt = self.dt2gt.cuda(gpu_p)
580
+ else:
581
+ dt1gt = self.dt1gt
582
+ dt2gt = self.dt2gt
583
+ if self.opt.dt_nonlinear != '':
584
+ dt1gt = nonlinearDt(dt1gt, self.opt.dt_nonlinear, dt_xmax)
585
+ dt2gt = nonlinearDt(dt2gt, self.opt.dt_nonlinear, dt_xmax)
586
+ #print('dt1gtdt2gt',torch.min(dt1gt).item(),torch.max(dt1gt).item(),torch.min(dt2gt).item(),torch.max(dt2gt).item())
587
+ self.dt1gt = (self.dt1gt-0.5)*2
588
+ self.dt2gt = (self.dt2gt-0.5)*2
589
+
590
+ fake_B_gray_line1 = self.netLine1(fake_B_gray)
591
+ fake_B_gray_line2 = self.netLine2(fake_B_gray)
592
+ self.loss_G_chamfer2 = (dt1gt[(fake_B_gray<0)&(fake_B_gray_line1<0)].sum() + dt2gt[(fake_B_gray>=0)&(fake_B_gray_line2>=0)].sum()) / bs * self.opt.lambda_chamfer2
593
+ if gpu_p != gpu:
594
+ self.loss_G_chamfer2 = self.loss_G_chamfer2.cuda(gpu)
595
+
596
+ # Fourth, line continuity loss, constrained on synthesized drawing
597
+ if self.opt.continuity_loss:
598
+ # Patch-based
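+ # the pretrained netRegressor scores line continuity per patch; the loss pushes scores toward 1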
599
+ self.get_patches()
600
+ self.outputs = self.netRegressor(self.fake_B_patches)
601
+ if not self.opt.emphasis_conti_face:
602
+ self.loss_G_continuity = (1.0-torch.mean(self.outputs)).cuda(gpu) * self.opt.lambda_continuity
603
+ else:
604
+ self.loss_G_continuity = torch.mean((1.0-self.outputs)*self.conti_weights).cuda(gpu) * self.opt.lambda_continuity
605
+
606
+
607
+
608
+ self.loss_G = self.loss_G_GAN
609
+ if 'G_L1' in self.loss_names:
610
+ self.loss_G = self.loss_G + self.loss_G_L1
611
+ if 'G_local' in self.loss_names:
612
+ self.loss_G = self.loss_G + self.loss_G_local
613
+ if 'G_hair_local' in self.loss_names:
614
+ self.loss_G = self.loss_G + self.loss_G_hair_local
615
+ if 'G_bg_local' in self.loss_names:
616
+ self.loss_G = self.loss_G + self.loss_G_bg_local
617
+ if 'G_chamfer' in self.loss_names:
618
+ self.loss_G = self.loss_G + self.loss_G_chamfer
619
+ if 'G_chamfer2' in self.loss_names:
620
+ self.loss_G = self.loss_G + self.loss_G_chamfer2
621
+ if 'G_continuity' in self.loss_names:
622
+ self.loss_G = self.loss_G + self.loss_G_continuity
623
+
624
+ self.loss_G.backward()
625
+
626
+ def optimize_parameters(self):
627
+ self.forward()
628
+ # update D
629
+ self.set_requires_grad(self.netD, True)
630
+
631
+ if self.opt.discriminator_local:
632
+ self.set_requires_grad(self.netDLEyel, True)
633
+ self.set_requires_grad(self.netDLEyer, True)
634
+ self.set_requires_grad(self.netDLNose, True)
635
+ self.set_requires_grad(self.netDLMouth, True)
636
+ self.set_requires_grad(self.netDLHair, True)
637
+ self.set_requires_grad(self.netDLBG, True)
638
+ self.optimizer_D.zero_grad()
639
+ self.backward_D()
640
+ self.optimizer_D.step()
641
+
642
+ # update G
643
+ self.set_requires_grad(self.netD, False)
644
+ if self.opt.discriminator_local:
645
+ self.set_requires_grad(self.netDLEyel, False)
646
+ self.set_requires_grad(self.netDLEyer, False)
647
+ self.set_requires_grad(self.netDLNose, False)
648
+ self.set_requires_grad(self.netDLMouth, False)
649
+ self.set_requires_grad(self.netDLHair, False)
650
+ self.set_requires_grad(self.netDLBG, False)
651
+ self.optimizer_G.zero_grad()
652
+ self.backward_G()
653
+ self.optimizer_G.step()
654
+
655
+ def get_patches(self):
656
+ gpu_p = self.opt.gpu_ids_p[0]
657
+ gpu = self.opt.gpu_ids[0]
658
+ if gpu_p != gpu:
659
+ self.fake_B = self.fake_B.cuda(gpu_p)
660
+ # [1,1,512,512]->[bs,1,11,11]
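+ # crop the drawing into an aa x aa grid of WxW patches at a random offset, skipping patches that are entirely white or entirely black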
661
+ patches = []
662
+ if self.isTrain and self.opt.emphasis_conti_face:
663
+ weights = []
664
+ W2 = int(W/2)
665
+ t = np.random.randint(res,size=2)
666
+ for i in range(aa):
667
+ for j in range(aa):
668
+ p = self.fake_B[:,:,t[0]+i*W:t[0]+(i+1)*W,t[1]+j*W:t[1]+(j+1)*W]
669
+ whitenum = torch.sum(p>=0.0)
670
+ #if whitenum < 5 or whitenum > W*W-5:
671
+ if whitenum < 1 or whitenum > W*W-1:
672
+ continue
673
+ patches.append(p)
674
+ if self.isTrain and self.opt.emphasis_conti_face:
675
+ weights.append(self.face_mask[:,:,t[0]+i*W+W2,t[1]+j*W+W2])
676
+ self.fake_B_patches = torch.cat(patches, dim=0)
677
+ if self.isTrain and self.opt.emphasis_conti_face:
678
+ self.conti_weights = torch.cat(weights, dim=0)+1 #0->1,1->2
679
+
680
+ def get_patches_real(self):
681
+ # [1,1,512,512]->[bs,1,11,11]
682
+ patches = []
683
+ t = np.random.randint(res,size=2)
684
+ for i in range(aa):
685
+ for j in range(aa):
686
+ p = self.real_B[:,:,t[0]+i*W:t[0]+(i+1)*W,t[1]+j*W:t[1]+(j+1)*W]
687
+ whitenum = torch.sum(p>=0.0)
688
+ #if whitenum < 5 or whitenum > W*W-5:
689
+ if whitenum < 1 or whitenum > W*W-1:
690
+ continue
691
+ patches.append(p)
692
+ self.real_B_patches = torch.cat(patches, dim=0)
APDrawingGAN2/models/base_model.py ADDED
@@ -0,0 +1,545 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from . import networks
5
+
6
+
7
+ class BaseModel():
8
+
9
+ # modify parser to add command line options,
10
+ # and also change the default values if needed
11
+ @staticmethod
12
+ def modify_commandline_options(parser, is_train):
13
+ return parser
14
+
15
+ def name(self):
16
+ return 'BaseModel'
17
+
18
+ def initialize(self, opt):
19
+ self.opt = opt
20
+ self.gpu_ids = opt.gpu_ids
21
+ self.gpu_ids_p = opt.gpu_ids_p
22
+ self.isTrain = opt.isTrain
23
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
24
+ self.device_p = torch.device('cuda:{}'.format(self.gpu_ids_p[0])) if self.gpu_ids else torch.device('cpu')
25
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
26
+ self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root)
27
+ if opt.resize_or_crop != 'scale_width':
28
+ torch.backends.cudnn.benchmark = True
29
+ self.loss_names = []
30
+ self.model_names = []
31
+ self.visual_names = []
32
+ self.image_paths = []
33
+
34
+ def set_input(self, input):
35
+ self.input = input
36
+
37
+ def forward(self):
38
+ pass
39
+
40
+ # load and print networks; create schedulers
41
+ def setup(self, opt, parser=None):
42
+ if self.isTrain:
43
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
44
+
45
+ if not self.isTrain or opt.continue_train:
46
+ self.load_networks(opt.which_epoch)
47
+ if len(self.auxiliary_model_names) > 0:
48
+ self.load_auxiliary_networks()
49
+ self.print_networks(opt.verbose)
50
+
51
+ # make models eval mode during test time
52
+ def eval(self):
53
+ for name in self.model_names:
54
+ if isinstance(name, str):
55
+ net = getattr(self, 'net' + name)
56
+ net.eval()
57
+
58
+ # used in test time, wrapping `forward` in no_grad() so we don't save
59
+ # intermediate steps for backprop
60
+ def test(self):
61
+ with torch.no_grad():
62
+ self.forward()
63
+
64
+ # get image paths
65
+ def get_image_paths(self):
66
+ return self.image_paths
67
+
68
+ def optimize_parameters(self):
69
+ pass
70
+
71
+ # update learning rate (called once every epoch)
72
+ def update_learning_rate(self):
73
+ for scheduler in self.schedulers:
74
+ scheduler.step()
75
+ lr = self.optimizers[0].param_groups[0]['lr']
76
+ print('learning rate = %.7f' % lr)
77
+
78
+ # return visualization images. train.py will display these images, and save the images to a html
79
+ def get_current_visuals(self):
80
+ visual_ret = OrderedDict()
81
+ for name in self.visual_names:
82
+ if isinstance(name, str):
83
+ visual_ret[name] = getattr(self, name)
84
+ return visual_ret
85
+
86
+ # return training losses/errors. train.py will print out these errors as debugging information
87
+ def get_current_losses(self):
88
+ errors_ret = OrderedDict()
89
+ for name in self.loss_names:
90
+ if isinstance(name, str):
91
+ # float(...) works for both scalar tensor and float number
92
+ errors_ret[name] = float(getattr(self, 'loss_' + name))
93
+ return errors_ret
94
+
95
+ # save models to the disk
96
+ def save_networks(self, which_epoch):
97
+ for name in self.model_names:
98
+ if isinstance(name, str):
99
+ save_filename = '%s_net_%s.pth' % (which_epoch, name)
100
+ save_path = os.path.join(self.save_dir, save_filename)
101
+ net = getattr(self, 'net' + name)
102
+
103
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
104
+ torch.save(net.module.cpu().state_dict(), save_path)
105
+ net.cuda(self.gpu_ids[0])
106
+ else:
107
+ torch.save(net.cpu().state_dict(), save_path)
108
+
109
+ def save_networks2(self, which_epoch):
110
+ gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
111
+ dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
112
+ dict_gen = {}
113
+ dict_dis = {}
114
+ for name in self.model_names:
115
+ if isinstance(name, str):
116
+ net = getattr(self, 'net' + name)
117
+
118
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
119
+ state_dict = net.module.cpu().state_dict()
120
+ net.cuda(self.gpu_ids[0])
121
+ else:
122
+ state_dict = net.cpu().state_dict()
123
+
124
+ if name[0] == 'G':
125
+ dict_gen[name] = state_dict
126
+ elif name[0] == 'D':
127
+ dict_dis[name] = state_dict
128
+ else:
129
+ save_filename = '%s_net_%s.pth' % (which_epoch, name)
130
+ save_path = os.path.join(self.save_dir, save_filename)
131
+ torch.save(state_dict, save_path)
132
+ if dict_gen:
133
+ torch.save(dict_gen, gen_name)
134
+ if dict_dis:
135
+ torch.save(dict_dis, dis_name)
136
+
137
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
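+ # recursively drop InstanceNorm running_mean/running_var (when None) and num_batches_tracked entries so old checkpoints load in newer PyTorch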
138
+ key = keys[i]
139
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
140
+ if module.__class__.__name__.startswith('InstanceNorm') and \
141
+ (key == 'running_mean' or key == 'running_var'):
142
+ if getattr(module, key) is None:
143
+ state_dict.pop('.'.join(keys))
144
+ if module.__class__.__name__.startswith('InstanceNorm') and \
145
+ (key == 'num_batches_tracked'):
146
+ state_dict.pop('.'.join(keys))
147
+ else:
148
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
149
+
150
+ # load models from the disk
151
+ def load_networks(self, which_epoch):
152
+ gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
153
+ if os.path.exists(gen_name):
154
+ self.load_networks2(which_epoch)
155
+ return
156
+ for name in self.model_names:
157
+ if isinstance(name, str):
158
+ load_filename = '%s_net_%s.pth' % (which_epoch, name)
159
+ load_path = os.path.join(self.save_dir, load_filename)
160
+ net = getattr(self, 'net' + name)
161
+ if isinstance(net, torch.nn.DataParallel):
162
+ net = net.module
163
+ print('loading the model from %s' % load_path)
164
+ # if you are using PyTorch newer than 0.4 (e.g., built from
165
+ # GitHub source), you can remove str() on self.device
166
+ state_dict = torch.load(load_path, map_location=str(self.device))
167
+ if hasattr(state_dict, '_metadata'):
168
+ del state_dict._metadata
169
+
170
+ # patch InstanceNorm checkpoints prior to 0.4
171
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
172
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
173
+ net.load_state_dict(state_dict)
174
+
175
+ def load_networks2(self, which_epoch):
176
+ gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
177
+ gen_state_dict = torch.load(gen_name, map_location=str(self.device))
178
+ if self.isTrain and self.opt.model != 'apdrawing_style_nogan':
179
+ dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
180
+ dis_state_dict = torch.load(dis_name, map_location=str(self.device))
181
+ for name in self.model_names:
182
+ if isinstance(name, str):
183
+ net = getattr(self, 'net' + name)
184
+ if isinstance(net, torch.nn.DataParallel):
185
+ net = net.module
186
+ if name[0] == 'G':
187
+ print('loading the model %s from %s' % (name,gen_name))
188
+ state_dict = gen_state_dict[name]
189
+ elif name[0] == 'D':
190
+ print('loading the model %s from %s' % (name,dis_name))
191
+ state_dict = dis_state_dict[name]
192
+
193
+ if hasattr(state_dict, '_metadata'):
194
+ del state_dict._metadata
195
+ # patch InstanceNorm checkpoints prior to 0.4
196
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
197
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
198
+ net.load_state_dict(state_dict)
199
+
200
+ # load auxiliary net models from the disk
201
+ def load_auxiliary_networks(self):
202
+ for name in self.auxiliary_model_names:
203
+ if isinstance(name, str):
204
+ if 'AE' in name and self.opt.ae_small:
205
+ load_filename = '%s_net_%s_small.pth' % ('latest', name)
206
+ elif 'Regressor' in name:
207
+ load_filename = '%s_net_%s%d.pth' % ('latest', name, self.opt.regarch)
208
+ else:
209
+ load_filename = '%s_net_%s.pth' % ('latest', name)
210
+ load_path = os.path.join(self.auxiliary_dir, load_filename)
211
+ net = getattr(self, 'net' + name)
212
+ if isinstance(net, torch.nn.DataParallel):
213
+ net = net.module
214
+ print('loading the model from %s' % load_path)
215
+ # if you are using PyTorch newer than 0.4 (e.g., built from
216
+ # GitHub source), you can remove str() on self.device
217
+ if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair', 'Regressorface']:
218
+ state_dict = torch.load(load_path, map_location=str(self.device_p))
219
+ else:
220
+ state_dict = torch.load(load_path, map_location=str(self.device))
221
+ if hasattr(state_dict, '_metadata'):
222
+ del state_dict._metadata
223
+
224
+ # patch InstanceNorm checkpoints prior to 0.4
225
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
226
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
227
+ net.load_state_dict(state_dict)
228
+
229
+ # print network information
230
+ def print_networks(self, verbose):
231
+ print('---------- Networks initialized -------------')
232
+ for name in self.model_names:
233
+ if isinstance(name, str):
234
+ net = getattr(self, 'net' + name)
235
+ num_params = 0
236
+ for param in net.parameters():
237
+ num_params += param.numel()
238
+ if verbose:
239
+ print(net)
240
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
241
+ print('-----------------------------------------------')
242
+
243
+ # set requires_grad=False to avoid unnecessary computation
244
+ def set_requires_grad(self, nets, requires_grad=False):
245
+ if not isinstance(nets, list):
246
+ nets = [nets]
247
+ for net in nets:
248
+ if net is not None:
249
+ for param in net.parameters():
250
+ param.requires_grad = requires_grad
251
+
252
+ # =============================================================================================================
253
+ def inverse_mask(self, mask):
254
+ return torch.ones(mask.shape).to(self.device)-mask
255
+
256
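+ # masked: map A from [-1,1] to [0,1], zero out everything outside the mask, then map back to [-1,1]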
+ def masked(self, A,mask):
257
+ return (A/2+0.5)*mask*2-1
258
+
259
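+ # add_with_mask: composite A inside the mask with B outside it (computed in [0,1]), result back in [-1,1]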
+ def add_with_mask(self, A,B,mask):
260
+ return ((A/2+0.5)*mask+(B/2+0.5)*(torch.ones(mask.shape).to(self.device)-mask))*2-1
261
+
262
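+ # addone_with_mask: keep A inside the mask and fill the outside with white, result in [-1,1]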
+ def addone_with_mask(self, A,mask):
263
+ return ((A/2+0.5)*mask+(torch.ones(mask.shape).to(self.device)-mask))*2-1
264
+
265
+ def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op = 1, region_enm = 0, cmaskel = None, cmasker = None, cmaskno = None, cmaskmo = None):
266
+ '''
267
+ x y
268
+ 100.571 123.429
269
+ 155.429 123.429
270
+ 128.000 155.886
271
+ 103.314 185.417
272
+ 152.686 185.417
273
+ these are the mean locations of the 5 facial landmarks (for 256x256 images)
274
+ ConstantPad2d padding order: left, right, top, bottom
275
+ '''
276
+ if comb_op == 0:
277
+ # use max pooling, pad black for eyes etc
278
+ padvalue = -1
279
+ if region_enm in [1,2]:
280
+ eyel = eyel * cmaskel
281
+ eyer = eyer * cmasker
282
+ nose = nose * cmaskno
283
+ mouth = mouth * cmaskmo
284
+ else:
285
+ # use min pooling, pad white for eyes etc
286
+ padvalue = 1
287
+ if region_enm in [1,2]:
288
+ eyel = self.addone_with_mask(eyel, cmaskel)
289
+ eyer = self.addone_with_mask(eyer, cmasker)
290
+ nose = self.addone_with_mask(nose, cmaskno)
291
+ mouth = self.addone_with_mask(mouth, cmaskmo)
292
+ if region_enm in [0,1]: # need to pad
293
+ IMAGE_SIZE = self.opt.fineSize
294
+ ratio = IMAGE_SIZE / 256
295
+ EYE_W = self.opt.EYE_W * ratio
296
+ EYE_H = self.opt.EYE_H * ratio
297
+ NOSE_W = self.opt.NOSE_W * ratio
298
+ NOSE_H = self.opt.NOSE_H * ratio
299
+ MOUTH_W = self.opt.MOUTH_W * ratio
300
+ MOUTH_H = self.opt.MOUTH_H * ratio
301
+ bs,nc,_,_ = eyel.shape
302
+ eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
303
+ eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
304
+ nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
305
+ mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
306
+ for i in range(bs):
307
+ if not average_pos:
308
+ center = self.center[i]#x,y
309
+ else:# if average_pos = True
310
+ center = torch.tensor([[101,123-4],[155,123-4],[128,156-NOSE_H/2+16],[128,185]])
311
+ eyel_p[i] = torch.nn.ConstantPad2d((int(center[0,0] - EYE_W / 2 - 1), int(IMAGE_SIZE - (center[0,0]+EYE_W/2-1)), int(center[0,1] - EYE_H / 2 - 1),int(IMAGE_SIZE - (center[0,1]+EYE_H/2 - 1))),-1)(eyel[i])
312
+ eyer_p[i] = torch.nn.ConstantPad2d((int(center[1,0] - EYE_W / 2 - 1), int(IMAGE_SIZE - (center[1,0]+EYE_W/2-1)), int(center[1,1] - EYE_H / 2 - 1), int(IMAGE_SIZE - (center[1,1]+EYE_H/2 - 1))),-1)(eyer[i])
313
+ nose_p[i] = torch.nn.ConstantPad2d((int(center[2,0] - NOSE_W / 2 - 1), int(IMAGE_SIZE - (center[2,0]+NOSE_W/2-1)), int(center[2,1] - NOSE_H / 2 - 1), int(IMAGE_SIZE - (center[2,1]+NOSE_H/2 - 1))),-1)(nose[i])
314
+ mouth_p[i] = torch.nn.ConstantPad2d((int(center[3,0] - MOUTH_W / 2 - 1), int(IMAGE_SIZE - (center[3,0]+MOUTH_W/2-1)), int(center[3,1] - MOUTH_H / 2 - 1), int(IMAGE_SIZE - (center[3,1]+MOUTH_H/2 - 1))),-1)(mouth[i])
315
+ elif region_enm in [2]:
316
+ eyel_p = eyel
317
+ eyer_p = eyer
318
+ nose_p = nose
319
+ mouth_p = mouth
320
+ if comb_op == 0:
321
+ # use max pooling
322
+ eyes = torch.max(eyel_p, eyer_p)
323
+ eye_nose = torch.max(eyes, nose_p)
324
+ result = torch.max(eye_nose, mouth_p)
325
+ else:
326
+ # use min pooling
327
+ eyes = torch.min(eyel_p, eyer_p)
328
+ eye_nose = torch.min(eyes, nose_p)
329
+ result = torch.min(eye_nose, mouth_p)
330
+ return result
331
+
332
+ def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op = 1, region_enm = 0, cmaskel = None, cmasker = None, cmaskno = None, cmaskmo = None):
333
+ if comb_op == 0:
334
+ # use max pooling, pad black for eyes etc
335
+ padvalue = -1
336
+ hair = self.masked(hair, mask)
337
+ if region_enm in [1,2]:
338
+ eyel = eyel * cmaskel
339
+ eyer = eyer * cmasker
340
+ nose = nose * cmaskno
341
+ mouth = mouth * cmaskmo
342
+ else:
343
+ # use min pooling, pad white for eyes etc
344
+ padvalue = 1
345
+ hair = self.addone_with_mask(hair, mask)
346
+ if region_enm in [1,2]:
347
+ eyel = self.addone_with_mask(eyel, cmaskel)
348
+ eyer = self.addone_with_mask(eyer, cmasker)
349
+ nose = self.addone_with_mask(nose, cmaskno)
350
+ mouth = self.addone_with_mask(mouth, cmaskmo)
351
+ if region_enm in [0,1]: # need to pad
352
+ IMAGE_SIZE = self.opt.fineSize
353
+ ratio = IMAGE_SIZE / 256
354
+ EYE_W = self.opt.EYE_W * ratio
355
+ EYE_H = self.opt.EYE_H * ratio
356
+ NOSE_W = self.opt.NOSE_W * ratio
357
+ NOSE_H = self.opt.NOSE_H * ratio
358
+ MOUTH_W = self.opt.MOUTH_W * ratio
359
+ MOUTH_H = self.opt.MOUTH_H * ratio
360
+ bs,nc,_,_ = eyel.shape
361
+ eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
362
+ eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
363
+ nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
364
+ mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
365
+ for i in range(bs):
366
+ center = self.center[i]#x,y
367
+ eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(eyel[i])
368
+ eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(eyer[i])
369
+ nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(nose[i])
370
+ mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(mouth[i])
371
+ elif region_enm in [2]:
372
+ eyel_p = eyel
373
+ eyer_p = eyer
374
+ nose_p = nose
375
+ mouth_p = mouth
376
+ if comb_op == 0:
377
+ # use max pooling
378
+ eyes = torch.max(eyel_p, eyer_p)
379
+ eye_nose = torch.max(eyes, nose_p)
380
+ eye_nose_mouth = torch.max(eye_nose, mouth_p)
381
+ result = torch.max(hair,eye_nose_mouth)
382
+ else:
383
+ # use min pooling
384
+ eyes = torch.min(eyel_p, eyer_p)
385
+ eye_nose = torch.min(eyes, nose_p)
386
+ eye_nose_mouth = torch.min(eye_nose, mouth_p)
387
+ result = torch.min(hair,eye_nose_mouth)
388
+ return result
389
+
390
+ def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op = 1, region_enm = 0, cmaskel = None, cmasker = None, cmaskno = None, cmaskmo = None):
391
+ if comb_op == 0:
392
+ # use max pooling, pad black for eyes etc
393
+ padvalue = -1
394
+ hair = self.masked(hair, maskh)
395
+ bg = self.masked(bg, maskb)
396
+ if region_enm in [1,2]:
397
+ eyel = eyel * cmaskel
398
+ eyer = eyer * cmasker
399
+ nose = nose * cmaskno
400
+ mouth = mouth * cmaskmo
401
+ else:
402
+ # use min pooling, pad white for eyes etc
403
+ padvalue = 1
404
+ hair = self.addone_with_mask(hair, maskh)
405
+ bg = self.addone_with_mask(bg, maskb)
406
+ if region_enm in [1,2]:
407
+ eyel = self.addone_with_mask(eyel, cmaskel)
408
+ eyer = self.addone_with_mask(eyer, cmasker)
409
+ nose = self.addone_with_mask(nose, cmaskno)
410
+ mouth = self.addone_with_mask(mouth, cmaskmo)
411
+ if region_enm in [0,1]: # need to pad to full size
412
+ IMAGE_SIZE = self.opt.fineSize
413
+ ratio = IMAGE_SIZE / 256
414
+ EYE_W = self.opt.EYE_W * ratio
415
+ EYE_H = self.opt.EYE_H * ratio
416
+ NOSE_W = self.opt.NOSE_W * ratio
417
+ NOSE_H = self.opt.NOSE_H * ratio
418
+ MOUTH_W = self.opt.MOUTH_W * ratio
419
+ MOUTH_H = self.opt.MOUTH_H * ratio
420
+ bs,nc,_,_ = eyel.shape
421
+ eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
422
+ eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
423
+ nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
424
+ mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
425
+ for i in range(bs):
426
+ center = self.center[i]#x,y
427
+ eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(eyel[i])
428
+ eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(eyer[i])
429
+ nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(nose[i])
430
+ mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(mouth[i])
431
+ elif region_enm in [2]:
432
+ eyel_p = eyel
433
+ eyer_p = eyer
434
+ nose_p = nose
435
+ mouth_p = mouth
436
+ if comb_op == 0:
437
+ eyes = torch.max(eyel_p, eyer_p)
438
+ eye_nose = torch.max(eyes, nose_p)
439
+ eye_nose_mouth = torch.max(eye_nose, mouth_p)
440
+ eye_nose_mouth_hair = torch.max(hair,eye_nose_mouth)
441
+ result = torch.max(bg,eye_nose_mouth_hair)
442
+ else:
443
+ eyes = torch.min(eyel_p, eyer_p)
444
+ eye_nose = torch.min(eyes, nose_p)
445
+ eye_nose_mouth = torch.min(eye_nose, mouth_p)
446
+ eye_nose_mouth_hair = torch.min(hair,eye_nose_mouth)
447
+ result = torch.min(bg,eye_nose_mouth_hair)
448
+ return result
449
+
450
+ def partCombiner3(self, face, hair, maskf, maskh, comb_op = 1):
451
+ if comb_op == 0:
452
+ # use max pooling, pad black etc
453
+ padvalue = -1
454
+ face = self.masked(face, maskf)
455
+ hair = self.masked(hair, maskh)
456
+ else:
457
+ # use min pooling, pad white etc
458
+ padvalue = 1
459
+ face = self.addone_with_mask(face, maskf)
460
+ hair = self.addone_with_mask(hair, maskh)
461
+ if comb_op == 0:
462
+ result = torch.max(face,hair)
463
+ else:
464
+ result = torch.min(face,hair)
465
+ return result
466
+
467
+
468
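+ # standalone tensor/image conversion helpers ([-1,1] tensor <-> OpenCV BGR uint8); they assume numpy and torchvision.transforms are available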
+ def tocv2(ts):
469
+ img = (ts.numpy()/2+0.5)*255
470
+ img = img.astype('uint8')
471
+ img = np.transpose(img,(1,2,0))
472
+ img = img[:,:,::-1]#rgb->bgr
473
+ return img
474
+
475
+ def totor(img):
476
+ img = img[:,:,::-1]
477
+ tor = transforms.ToTensor()(img)
478
+ tor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tor)
479
+ return tor
480
+
481
+
482
+ def ContinuityForTest(self, real = 0):
483
+ # Patch-based
484
+ self.get_patches()
485
+ self.outputs = self.netRegressor(self.fake_B_patches)
486
+ line_continuity = torch.mean(self.outputs)
487
+ opt = self.opt
488
+ file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt')
489
+ message = '%s %.04f' % (self.image_paths[0], line_continuity)
490
+ with open(file_name, 'a+') as c_file:
491
+ c_file.write(message)
492
+ c_file.write('\n')
493
+ if real == 1:
494
+ self.get_patches_real()
495
+ self.outputs2 = self.netRegressor(self.real_B_patches)
496
+ line_continuity2 = torch.mean(self.outputs2)
497
+ file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity-r.txt')
498
+ message = '%s %.04f' % (self.image_paths[0], line_continuity2)
499
+ with open(file_name, 'a+') as c_file:
500
+ c_file.write(message)
501
+ c_file.write('\n')
502
+
503
+ def getLocalParts(self,fakeAB):
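+ # split the concatenated input/drawing tensor into per-part patches (rectangular crops for region_enm 0/1, mask multiplication for region_enm 2) for the local discriminators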
504
+ bs,nc,_,_ = fakeAB.shape #dtype torch.float32
505
+ ncr = int(nc / self.opt.output_nc)
506
+ if self.opt.region_enm in [0,1]:
507
+ ratio = self.opt.fineSize / 256
508
+ EYE_H = self.opt.EYE_H * ratio
509
+ EYE_W = self.opt.EYE_W * ratio
510
+ NOSE_H = self.opt.NOSE_H * ratio
511
+ NOSE_W = self.opt.NOSE_W * ratio
512
+ MOUTH_H = self.opt.MOUTH_H * ratio
513
+ MOUTH_W = self.opt.MOUTH_W * ratio
514
+ eyel = torch.ones((bs,nc,int(EYE_H),int(EYE_W))).to(self.device)
515
+ eyer = torch.ones((bs,nc,int(EYE_H),int(EYE_W))).to(self.device)
516
+ nose = torch.ones((bs,nc,int(NOSE_H),int(NOSE_W))).to(self.device)
517
+ mouth = torch.ones((bs,nc,int(MOUTH_H),int(MOUTH_W))).to(self.device)
518
+ for i in range(bs):
519
+ center = self.center[i]
520
+ eyel[i] = fakeAB[i,:,center[0,1]-EYE_H/2:center[0,1]+EYE_H/2,center[0,0]-EYE_W/2:center[0,0]+EYE_W/2]
521
+ eyer[i] = fakeAB[i,:,center[1,1]-EYE_H/2:center[1,1]+EYE_H/2,center[1,0]-EYE_W/2:center[1,0]+EYE_W/2]
522
+ nose[i] = fakeAB[i,:,center[2,1]-NOSE_H/2:center[2,1]+NOSE_H/2,center[2,0]-NOSE_W/2:center[2,0]+NOSE_W/2]
523
+ mouth[i] = fakeAB[i,:,center[3,1]-MOUTH_H/2:center[3,1]+MOUTH_H/2,center[3,0]-MOUTH_W/2:center[3,0]+MOUTH_W/2]
524
+ elif self.opt.region_enm in [2]:
525
+ eyel = (fakeAB/2+0.5) * self.cmaskel.repeat(1,ncr,1,1) * 2 - 1
526
+ eyer = (fakeAB/2+0.5) * self.cmasker.repeat(1,ncr,1,1) * 2 - 1
527
+ nose = (fakeAB/2+0.5) * self.cmask.repeat(1,ncr,1,1) * 2 - 1
528
+ mouth = (fakeAB/2+0.5) * self.cmaskmo.repeat(1,ncr,1,1) * 2 - 1
529
+ hair = (fakeAB/2+0.5) * self.mask.repeat(1,ncr,1,1) * self.mask2.repeat(1,ncr,1,1) * 2 - 1
530
+ bg = (fakeAB/2+0.5) * (torch.ones(fakeAB.shape).to(self.device)-self.mask2.repeat(1,ncr,1,1)) * 2 - 1
531
+ return eyel, eyer, nose, mouth, hair, bg
532
+
533
+ def getaddw(self,local_name):
534
+ addw = 1
535
+ if local_name in ['DLEyel','DLEyer','eyel','eyer','DLFace','face']:
536
+ addw = self.opt.addw_eye
537
+ elif local_name in ['DLNose', 'nose']:
538
+ addw = self.opt.addw_nose
539
+ elif local_name in ['DLMouth', 'mouth']:
540
+ addw = self.opt.addw_mouth
541
+ elif local_name in ['DLHair', 'hair']:
542
+ addw = self.opt.addw_hair
543
+ elif local_name in ['DLBG', 'bg']:
544
+ addw = self.opt.addw_bg
545
+ return addw
APDrawingGAN2/models/networks.py ADDED
@@ -0,0 +1,1194 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+
7
+ ###############################################################################
8
+ # Helper Functions
9
+ ###############################################################################
10
+
11
+
12
+ def get_norm_layer(norm_type='instance'):
13
+ if norm_type == 'batch':
14
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
15
+ elif norm_type == 'instance':
16
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
17
+ elif norm_type == 'none':
18
+ norm_layer = None
19
+ else:
20
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
21
+ return norm_layer
22
+
23
+
24
+ def get_scheduler(optimizer, opt):
25
+ if opt.lr_policy == 'lambda':
26
+ def lambda_rule(epoch):
27
+ lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
28
+ return lr_l
29
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
30
+ elif opt.lr_policy == 'step':
31
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
32
+ elif opt.lr_policy == 'plateau':
33
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
34
+ elif opt.lr_policy == 'cosine':
35
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
36
+ else:
37
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
38
+ return scheduler
39
+
40
+
41
+ def init_weights(net, init_type='normal', gain=0.02):
42
+ def init_func(m):
43
+ classname = m.__class__.__name__
44
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
45
+ if init_type == 'normal':
46
+ init.normal_(m.weight.data, 0.0, gain)
47
+ elif init_type == 'xavier':
48
+ init.xavier_normal_(m.weight.data, gain=gain)
49
+ elif init_type == 'kaiming':
50
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
51
+ elif init_type == 'orthogonal':
52
+ init.orthogonal_(m.weight.data, gain=gain)
53
+ else:
54
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
55
+ if hasattr(m, 'bias') and m.bias is not None:
56
+ init.constant_(m.bias.data, 0.0)
57
+ elif classname.find('BatchNorm2d') != -1:
58
+ init.normal_(m.weight.data, 1.0, gain)
59
+ init.constant_(m.bias.data, 0.0)
60
+
61
+ print('initialize network with %s' % init_type)
62
+ net.apply(init_func)
63
+
64
+
65
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
66
+ if len(gpu_ids) > 0:
67
+ assert(torch.cuda.is_available())
68
+ net.to(gpu_ids[0])
69
+ net = torch.nn.DataParallel(net, gpu_ids)
70
+ init_weights(net, init_type, gain=init_gain)
71
+ return net
72
+
73
+
74
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], nnG=9, multiple=2, latent_dim=1024, ae_h=96, ae_w=96, extra_channel=2, nres=1):
75
+ net = None
76
+ norm_layer = get_norm_layer(norm_type=norm)
77
+
78
+ if netG == 'autoencoder':
79
+ net = AutoEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
80
+ elif netG == 'autoencoderfc':
81
+ net = AutoEncoderWithFC(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
82
+ multiple=multiple, latent_dim=latent_dim, h=ae_h, w=ae_w)
83
+ elif netG == 'autoencoderfc2':
84
+ net = AutoEncoderWithFC2(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
85
+ multiple=multiple, latent_dim=latent_dim, h=ae_h, w=ae_w)
86
+ elif netG == 'vae':
87
+ net = VAE(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
88
+ multiple=multiple, latent_dim=latent_dim, h=ae_h, w=ae_w)
89
+ elif netG == 'classifier':
90
+ net = Classifier(input_nc, output_nc, ngf, num_downs=nnG, norm_layer=norm_layer, use_dropout=use_dropout, h=ae_h, w=ae_w)
91
+ elif netG == 'regressor':
92
+ net = Regressor(input_nc, ngf, norm_layer=norm_layer, arch=nnG)
93
+ elif netG == 'resnet_9blocks':#default for cyclegan
94
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
95
+ elif netG == 'resnet_6blocks':
96
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
97
+ elif netG == 'resnet_nblocks':
98
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=nnG)
99
+ elif netG == 'resnet_style2_9blocks':
100
+ net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=0, extra_channel=extra_channel)
101
+ elif netG == 'resnet_style2_6blocks':
102
+ net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, model0_res=0, extra_channel=extra_channel)
103
+ elif netG == 'resnet_style2_nblocks':
104
+ net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=nnG, model0_res=0, extra_channel=extra_channel)
105
+ elif netG == 'unet_128':
106
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
107
+ elif netG == 'unet_256':#default for pix2pix
108
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
109
+ elif netG == 'unet_512':
110
+ net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
111
+ elif netG == 'unet_ndown':
112
+ net = UnetGenerator(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
113
+ elif netG == 'unetres_ndown':
114
+ net = UnetResGenerator(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout, nres=nres)
115
+ elif netG == 'partunet':
116
+ net = PartUnet(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
117
+ elif netG == 'partunet2':
118
+ net = PartUnet2(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
119
+ elif netG == 'partunetres':
120
+ net = PartUnetRes(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout,nres=nres)
121
+ elif netG == 'partunet2res':
122
+ net = PartUnet2Res(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout,nres=nres)
123
+ elif netG == 'partunet2style':
124
+ net = PartUnet2Style(input_nc, output_nc, nnG, ngf, extra_channel=extra_channel, norm_layer=norm_layer, use_dropout=use_dropout)
125
+ elif netG == 'partunet2resstyle':
126
+ net = PartUnet2ResStyle(input_nc, output_nc, nnG, ngf, extra_channel=extra_channel, norm_layer=norm_layer, use_dropout=use_dropout,nres=nres)
127
+ elif netG == 'combiner':
128
+ net = Combiner(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=2)
129
+ elif netG == 'combiner2':
130
+ net = Combiner2(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
131
+ else:
132
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
133
+ return init_net(net, init_type, init_gain, gpu_ids)
134
+
135
+
136
+ def define_D(input_nc, ndf, netD,
137
+ n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
138
+ net = None
139
+ norm_layer = get_norm_layer(norm_type=norm)
140
+
141
+ if netD == 'basic':
142
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
143
+ elif netD == 'n_layers':
144
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
145
+ elif netD == 'pixel':
146
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
147
+ else:
148
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
149
+ return init_net(net, init_type, init_gain, gpu_ids)
150
+
151
+
152
+ ##############################################################################
153
+ # Classes
154
+ ##############################################################################
155
+
156
+
157
+ # Defines the GAN loss which uses either LSGAN or the regular GAN.
158
+ # When LSGAN is used, it is basically same as MSELoss,
159
+ # but it abstracts away the need to create the target label tensor
160
+ # that has the same size as the input
161
+ class GANLoss(nn.Module):
162
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
163
+ super(GANLoss, self).__init__()
164
+ self.register_buffer('real_label', torch.tensor(target_real_label))
165
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
166
+ if use_lsgan:
167
+ self.loss = nn.MSELoss()
168
+ else:#no_lsgan
169
+ self.loss = nn.BCELoss()
170
+
171
+ def get_target_tensor(self, input, target_is_real):
172
+ if target_is_real:
173
+ target_tensor = self.real_label
174
+ else:
175
+ target_tensor = self.fake_label
176
+ return target_tensor.expand_as(input)
177
+
178
+ def __call__(self, input, target_is_real):
179
+ target_tensor = self.get_target_tensor(input, target_is_real)
180
+ return self.loss(input, target_tensor)
181
+
182
+
183
+ class AutoEncoderMNIST(nn.Module):
184
+ def __init__(self):
185
+ super(AutoEncoderMNIST, self).__init__()
186
+ self.encoder = nn.Sequential(
187
+ nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10
188
+ nn.ReLU(True),
189
+ nn.MaxPool2d(2, stride=2), # b, 16, 5, 5
190
+ nn.Conv2d(16, 8, 3, stride=2, padding=1), # b, 8, 3, 3
191
+ nn.ReLU(True),
192
+ nn.MaxPool2d(2, stride=1) # b, 8, 2, 2
193
+ )
194
+ self.decoder = nn.Sequential(
195
+ nn.ConvTranspose2d(8, 16, 3, stride=2), # b, 16, 5, 5
196
+ nn.ReLU(True),
197
+ nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b, 8, 15, 15
198
+ nn.ReLU(True),
199
+ nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b, 1, 28, 28
200
+ nn.Tanh()
201
+ )
202
+
203
+ def forward(self, x):
204
+ x = self.encoder(x)
205
+ x = self.decoder(x)
206
+ return x
207
+
208
+ class AutoEncoder(nn.Module):
209
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type='reflect'):
210
+ super(AutoEncoder, self).__init__()
211
+ self.input_nc = input_nc
212
+ self.output_nc = output_nc
213
+ self.ngf = ngf
214
+ if type(norm_layer) == functools.partial:
215
+ use_bias = norm_layer.func == nn.InstanceNorm2d
216
+ else:
217
+ use_bias = norm_layer == nn.InstanceNorm2d
218
+
219
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
220
+ n_downsampling = 3
221
+ for i in range(n_downsampling):
222
+ mult = 2**i
223
+ model += [nn.LeakyReLU(0.2),
224
+ nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=4,
225
+ stride=2, padding=1, bias=use_bias),
226
+ norm_layer(ngf * mult * 2)]
227
+ self.encoder = nn.Sequential(*model)
228
+
229
+ model2 = []
230
+ for i in range(n_downsampling):
231
+ mult = 2**(n_downsampling - i)
232
+ model2 += [nn.ReLU(),
233
+ nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
234
+ kernel_size=4, stride=2,
235
+ padding=1, bias=use_bias),
236
+ norm_layer(int(ngf * mult / 2))]
237
+ model2 += [nn.ReLU()]
238
+ model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
239
+ model2 += [nn.Tanh()]
240
+ self.decoder = nn.Sequential(*model2)
241
+
242
+ def forward(self, x):
243
+ ax = self.encoder(x) # b, 512, 6, 6
244
+ y = self.decoder(ax)
245
+ return y, ax
246
+
247
+ class AutoEncoderWithFC(nn.Module):
248
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, multiple=2,latent_dim=1024, h=96, w=96):
249
+ super(AutoEncoderWithFC, self).__init__()
250
+ self.input_nc = input_nc
251
+ self.output_nc = output_nc
252
+ self.ngf = ngf
253
+ if type(norm_layer) == functools.partial:
254
+ use_bias = norm_layer.func == nn.InstanceNorm2d
255
+ else:
256
+ use_bias = norm_layer == nn.InstanceNorm2d
257
+
258
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
259
+ n_downsampling = 3
260
+ #multiple = 2
261
+ for i in range(n_downsampling):
262
+ mult = multiple**i
263
+ model += [nn.LeakyReLU(0.2),
264
+ nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
265
+ stride=2, padding=1, bias=use_bias),
266
+ norm_layer(int(ngf * mult * multiple))]
267
+ self.encoder = nn.Sequential(*model)
268
+ self.fc1 = nn.Linear(int(ngf*(multiple**n_downsampling)*h/16*w/16),latent_dim)
269
+ self.relu = nn.ReLU(latent_dim)
270
+ self.fc2 = nn.Linear(latent_dim,int(ngf*(multiple**n_downsampling)*h/16*w/16))
271
+ self.rh = int(h/16)
272
+ self.rw = int(w/16)
273
+ model2 = []
274
+ for i in range(n_downsampling):
275
+ mult = multiple**(n_downsampling - i)
276
+ model2 += [nn.ReLU(),
277
+ nn.ConvTranspose2d(int(ngf * mult), int(ngf * mult / multiple),
278
+ kernel_size=4, stride=2,
279
+ padding=1, bias=use_bias),
280
+ norm_layer(int(ngf * mult / multiple))]
281
+ model2 += [nn.ReLU()]
282
+ model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
283
+ model2 += [nn.Tanh()]
284
+ self.decoder = nn.Sequential(*model2)
285
+
286
+ def forward(self, x):
287
+ ax = self.encoder(x) # b, 512, 6, 6
288
+ ax = ax.view(ax.size(0), -1) # view -- reshape
289
+ ax = self.relu(self.fc1(ax))
290
+ ax = self.fc2(ax)
291
+ ax = ax.view(ax.size(0),-1,self.rh,self.rw)
292
+ y = self.decoder(ax)
293
+ return y, ax
294
+
295
+ class AutoEncoderWithFC2(nn.Module):
296
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, multiple=2,latent_dim=1024, h=96, w=96):
297
+ super(AutoEncoderWithFC2, self).__init__()
298
+ self.input_nc = input_nc
299
+ self.output_nc = output_nc
300
+ self.ngf = ngf
301
+ if type(norm_layer) == functools.partial:
302
+ use_bias = norm_layer.func == nn.InstanceNorm2d
303
+ else:
304
+ use_bias = norm_layer == nn.InstanceNorm2d
305
+
306
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
307
+ n_downsampling = 2
308
+ #multiple = 2
309
+ for i in range(n_downsampling):
310
+ mult = multiple**i
311
+ model += [nn.LeakyReLU(0.2),
312
+ nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
313
+ stride=2, padding=1, bias=use_bias),
314
+ norm_layer(int(ngf * mult * multiple))]
315
+ self.encoder = nn.Sequential(*model)
316
+ self.fc1 = nn.Linear(int(ngf*(multiple**n_downsampling)*h/8*w/8),latent_dim)
317
+ self.relu = nn.ReLU(latent_dim)
318
+ self.fc2 = nn.Linear(latent_dim,int(ngf*(multiple**n_downsampling)*h/8*w/8))
319
+ self.rh = h/8
320
+ self.rw = w/8
321
+ model2 = []
322
+ for i in range(n_downsampling):
323
+ mult = multiple**(n_downsampling - i)
324
+ model2 += [nn.ReLU(),
325
+ nn.ConvTranspose2d(int(ngf * mult), int(ngf * mult / multiple),
326
+ kernel_size=4, stride=2,
327
+ padding=1, bias=use_bias),
328
+ norm_layer(int(ngf * mult / multiple))]
329
+ model2 += [nn.ReLU()]
330
+ model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
331
+ model2 += [nn.Tanh()]
332
+ self.decoder = nn.Sequential(*model2)
333
+
334
+ def forward(self, x):
335
+ ax = self.encoder(x) # b, 256, 12, 12
336
+ ax = ax.view(ax.size(0), -1) # view -- reshape
337
+ ax = self.relu(self.fc1(ax))
338
+ ax = self.fc2(ax)
339
+ ax = ax.view(ax.size(0),-1,self.rh,self.rw)
340
+ y = self.decoder(ax)
341
+ return y, ax
342
+
343
+ class VAE(nn.Module):
344
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, multiple=2,latent_dim=1024, h=96, w=96):
345
+ super(VAE, self).__init__()
346
+ self.input_nc = input_nc
347
+ self.output_nc = output_nc
348
+ self.ngf = ngf
349
+ if type(norm_layer) == functools.partial:
350
+ use_bias = norm_layer.func == nn.InstanceNorm2d
351
+ else:
352
+ use_bias = norm_layer == nn.InstanceNorm2d
353
+
354
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
355
+ n_downsampling = 3
356
+ for i in range(n_downsampling):
357
+ mult = multiple**i
358
+ model += [nn.LeakyReLU(0.2),
359
+ nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
360
+ stride=2, padding=1, bias=use_bias),
361
+ norm_layer(int(ngf * mult * multiple))]
362
+ self.encoder_cnn = nn.Sequential(*model)
363
+
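+ # flattened bottleneck size: ngf*multiple^3 channels over an (h/16)x(w/16) map (one stride-2 conv plus n_downsampling=3 halvings)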
364
+ self.c_dim = int(ngf*(multiple**n_downsampling)*h/16*w/16)
365
+ self.rh = int(h/16)# cast to int so view() receives integer dims
+ self.rw = int(w/16)
367
+ self.fc1 = nn.Linear(self.c_dim,latent_dim)
368
+ self.fc2 = nn.Linear(self.c_dim,latent_dim)
369
+ self.fc3 = nn.Linear(latent_dim,self.c_dim)
370
+ self.relu = nn.ReLU()
371
+
372
+ model2 = []
373
+ for i in range(n_downsampling):
374
+ mult = multiple**(n_downsampling - i)
375
+ model2 += [nn.ReLU(),
376
+ nn.ConvTranspose2d(int(ngf * mult), int(ngf * mult / multiple),
377
+ kernel_size=4, stride=2,
378
+ padding=1, bias=use_bias),
379
+ norm_layer(int(ngf * mult / multiple))]
380
+ model2 += [nn.ReLU()]
381
+ model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
382
+ model2 += [nn.Tanh()]#[-1,1]
383
+ self.decoder_cnn = nn.Sequential(*model2)
384
+
385
+ def encode(self, x):
386
+ h1 = self.encoder_cnn(x)
387
+ r1 = h1.view(h1.size(0), -1)
388
+ return self.fc1(r1), self.fc2(r1)
389
+
390
+ def reparameterize(self, mu, logvar):# sampling is stochastic, so not deterministic even in test mode
+ std = torch.exp(0.5*logvar)
+ eps = torch.randn_like(std)# torch.randn_like returns a tensor with the same size as input,
+ # filled with random numbers from a standard normal distribution N(0,1)
394
+ return eps.mul(std).add_(mu)
395
+
396
+ def decode(self, z):
397
+ h4 = self.relu(self.fc3(z))
398
+ r3 = h4.view(h4.size(0),-1,self.rh,self.rw)
399
+ return self.decoder_cnn(r3)
400
+
401
+ def forward(self, x):
402
+ mu, logvar = self.encode(x)
403
+ z = self.reparameterize(mu, logvar)
404
+ reconx = self.decode(z)
405
+ return reconx, mu, logvar
406
+
407
+ class Classifier(nn.Module):
408
+ def __init__(self, input_nc, classes, ngf=64, num_downs=3, norm_layer=nn.BatchNorm2d, use_dropout=False,
409
+ h=96, w=96):
410
+ super(Classifier, self).__init__()
411
+ self.input_nc = input_nc
412
+ self.ngf = ngf
413
+ if type(norm_layer) == functools.partial:
414
+ use_bias = norm_layer.func == nn.InstanceNorm2d
415
+ else:
416
+ use_bias = norm_layer == nn.InstanceNorm2d
417
+
418
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
419
+ multiple = 2
420
+ for i in range(num_downs):
421
+ mult = multiple**i
422
+ model += [nn.LeakyReLU(0.2),
423
+ nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
424
+ stride=2, padding=1, bias=use_bias),
425
+ norm_layer(int(ngf * mult * multiple))]
426
+ self.encoder = nn.Sequential(*model)
427
+ strides = 2**(num_downs+1)
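+ # encoder output has ngf*2^num_downs channels on an (h/strides)x(w/strides) grid, so the flattened size is ngf*h*w/(strides*2)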
428
+ self.fc1 = nn.Linear(int(ngf*h*w/(strides*2)), classes)
429
+
430
+ def forward(self, x):
431
+ ax = self.encoder(x) # b, 512, 6, 6
432
+ ax = ax.view(ax.size(0), -1) # view -- reshape
433
+ return self.fc1(ax)
434
+
435
+ class Regressor(nn.Module):
436
+ def __init__(self, input_nc, ngf=64, norm_layer=nn.BatchNorm2d, arch=1):
437
+ super(Regressor, self).__init__()
438
+ # if use BatchNorm2d,
439
+ # no need to use bias as BatchNorm2d has affine parameters
440
+
441
+ self.arch = arch
442
+
443
+ if arch == 1:
444
+ use_bias = True
445
+ sequence = [
446
+ nn.Conv2d(input_nc, ngf, kernel_size=3, stride=2, padding=0, bias=use_bias),#11->5
447
+ nn.LeakyReLU(0.2, True),
448
+ nn.Conv2d(ngf, 1, kernel_size=5, stride=1, padding=0, bias=use_bias),#5->1
449
+ ]
450
+ elif arch == 2:
451
+ if type(norm_layer) == functools.partial:
452
+ use_bias = norm_layer.func == nn.InstanceNorm2d
453
+ else:
454
+ use_bias = norm_layer == nn.InstanceNorm2d
455
+ sequence = [
456
+ nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=0, bias=use_bias),#11->9
457
+ nn.LeakyReLU(0.2, True),
458
+ nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=1, padding=0, bias=use_bias),#9->7
459
+ norm_layer(ngf*2),
460
+ nn.LeakyReLU(0.2, True),
461
+ nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=1, padding=0, bias=use_bias),#7->5
462
+ norm_layer(ngf*4),
463
+ nn.LeakyReLU(0.2, True),
464
+ nn.Conv2d(ngf*4, 1, kernel_size=5, stride=1, padding=0, bias=use_bias),#5->1
465
+ ]
466
+ elif arch == 3:
467
+ use_bias = True
468
+ sequence = [
469
+ nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
470
+ nn.LeakyReLU(0.2, True),
471
+ nn.Conv2d(ngf, 1, kernel_size=11, stride=1, padding=0, bias=use_bias),#11->1
472
+ ]
473
+ elif arch == 4:
474
+ use_bias = True
475
+ sequence = [
476
+ nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
477
+ nn.LeakyReLU(0.2, True),
478
+ nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
479
+ nn.LeakyReLU(0.2, True),
480
+ nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
481
+ nn.LeakyReLU(0.2, True),
482
+ nn.Conv2d(ngf*4, 1, kernel_size=11, stride=1, padding=0, bias=use_bias),#11->1
483
+ ]
484
+ elif arch == 5:
485
+ use_bias = True
486
+ sequence = [
487
+ nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
488
+ nn.LeakyReLU(0.2, True),
489
+ nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
490
+ nn.LeakyReLU(0.2, True),
491
+ nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
492
+ nn.LeakyReLU(0.2, True),
493
+ ]
494
+ fc = [
495
+ nn.Linear(ngf*4*11*11, 4096),
496
+ nn.ReLU(True),
497
+ nn.Dropout(),
498
+ nn.Linear(4096, 1),
499
+ ]
500
+ self.fc = nn.Sequential(*fc)
501
+
502
+ self.model = nn.Sequential(*sequence)
503
+
504
+ def forward(self, x):
505
+ if self.arch <= 4:
506
+ return self.model(x)
507
+ else:
508
+ x = self.model(x)
509
+ x = x.view(x.size(0), -1)
510
+ x = self.fc(x)
511
+ return x
512
+
513
+
514
+ # Defines the generator that consists of Resnet blocks between a few
515
+ # downsampling/upsampling operations.
516
+ # Code and idea originally from Justin Johnson's architecture.
517
+ # https://github.com/jcjohnson/fast-neural-style/
518
+ class ResnetGenerator(nn.Module):
519
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
520
+ assert(n_blocks >= 0)
521
+ super(ResnetGenerator, self).__init__()
522
+ self.input_nc = input_nc
523
+ self.output_nc = output_nc
524
+ self.ngf = ngf
525
+ if type(norm_layer) == functools.partial:
526
+ use_bias = norm_layer.func == nn.InstanceNorm2d
527
+ else:
528
+ use_bias = norm_layer == nn.InstanceNorm2d
529
+
530
+ model = [nn.ReflectionPad2d(3),
531
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
532
+ bias=use_bias),
533
+ norm_layer(ngf),
534
+ nn.ReLU(True)]
535
+
536
+ n_downsampling = 2
537
+ for i in range(n_downsampling):
538
+ mult = 2**i
539
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
540
+ stride=2, padding=1, bias=use_bias),
541
+ norm_layer(ngf * mult * 2),
542
+ nn.ReLU(True)]
543
+
544
+ mult = 2**n_downsampling
545
+ for i in range(n_blocks):
546
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
547
+
548
+ for i in range(n_downsampling):
549
+ mult = 2**(n_downsampling - i)
550
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
551
+ kernel_size=3, stride=2,
552
+ padding=1, output_padding=1,
553
+ bias=use_bias),
554
+ norm_layer(int(ngf * mult / 2)),
555
+ nn.ReLU(True)]
556
+ model += [nn.ReflectionPad2d(3)]
557
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
558
+ model += [nn.Tanh()]
559
+
560
+ self.model = nn.Sequential(*model)
561
+
562
+ def forward(self, input):
563
+ return self.model(input)
564
+
565
+ class ResnetStyle2Generator(nn.Module):
566
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
567
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
568
+ """
569
+
570
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
571
+ """Construct a Resnet-based generator
572
+
573
+ Parameters:
574
+ input_nc (int) -- the number of channels in input images
575
+ output_nc (int) -- the number of channels in output images
576
+ ngf (int) -- the number of filters in the last conv layer
577
+ norm_layer -- normalization layer
578
+ use_dropout (bool) -- if use dropout layers
579
+ n_blocks (int) -- the number of ResNet blocks
580
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
581
+ """
582
+ assert(n_blocks >= 0)
583
+ super(ResnetStyle2Generator, self).__init__()
584
+ if type(norm_layer) == functools.partial:
585
+ use_bias = norm_layer.func == nn.InstanceNorm2d
586
+ else:
587
+ use_bias = norm_layer == nn.InstanceNorm2d
588
+
589
+ model0 = [nn.ReflectionPad2d(3),
590
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
591
+ norm_layer(ngf),
592
+ nn.ReLU(True)]
593
+
594
+ n_downsampling = 2
595
+ for i in range(n_downsampling): # add downsampling layers
596
+ mult = 2 ** i
597
+ model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
598
+ norm_layer(ngf * mult * 2),
599
+ nn.ReLU(True)]
600
+
601
+ mult = 2 ** n_downsampling
602
+ for i in range(model0_res): # add ResNet blocks
603
+ model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
604
+
605
+ model = []
606
+ model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
607
+ norm_layer(ngf * mult),
608
+ nn.ReLU(True)]
609
+
610
+ for i in range(n_blocks-model0_res): # add ResNet blocks
611
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
612
+
613
+ for i in range(n_downsampling): # add upsampling layers
614
+ mult = 2 ** (n_downsampling - i)
615
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
616
+ kernel_size=3, stride=2,
617
+ padding=1, output_padding=1,
618
+ bias=use_bias),
619
+ norm_layer(int(ngf * mult / 2)),
620
+ nn.ReLU(True)]
621
+ model += [nn.ReflectionPad2d(3)]
622
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
623
+ model += [nn.Tanh()]
624
+
625
+ self.model0 = nn.Sequential(*model0)
626
+ self.model = nn.Sequential(*model)
627
+ print(list(self.modules()))
628
+
629
+ def forward(self, input1, input2): # input2 [bs,c]
630
+ """Standard forward"""
631
+ f1 = self.model0(input1)
632
+ [bs,c,h,w] = f1.shape
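+ # tile the [bs,c] style one-hot over the spatial dims to [bs,c,h,w] so it can be concatenated channel-wise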
633
+ input2 = input2.repeat(h,w,1,1).permute([2,3,0,1])
634
+ y1 = torch.cat([f1, input2], 1)
635
+ return self.model(y1)
636
+
637
+ class Combiner(nn.Module):
638
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
639
+ assert(n_blocks >= 0)
640
+ super(Combiner, self).__init__()
641
+ self.input_nc = input_nc
642
+ self.output_nc = output_nc
643
+ self.ngf = ngf
644
+ if type(norm_layer) == functools.partial:
645
+ use_bias = norm_layer.func == nn.InstanceNorm2d
646
+ else:
647
+ use_bias = norm_layer == nn.InstanceNorm2d
648
+
649
+ model = [nn.ReflectionPad2d(3),
650
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
651
+ bias=use_bias),
652
+ norm_layer(ngf),
653
+ nn.ReLU(True)]
654
+
655
+ for i in range(n_blocks):
656
+ model += [ResnetBlock(ngf, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
657
+
658
+ model += [nn.ReflectionPad2d(3)]
659
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
660
+ model += [nn.Tanh()]
661
+
662
+ self.model = nn.Sequential(*model)
663
+
664
+ def forward(self, input):
665
+ return self.model(input)
666
+
667
+ class Combiner2(nn.Module):
668
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
669
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
670
+ super(Combiner2, self).__init__()
671
+
672
+ # construct unet structure
673
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
674
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
675
+
676
+ self.model = unet_block
677
+
678
+ def forward(self, input):
679
+ return self.model(input)
680
+
681
+
682
+ # Define a resnet block
683
+ class ResnetBlock(nn.Module):
684
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
685
+ super(ResnetBlock, self).__init__()
686
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
687
+
688
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
689
+ conv_block = []
690
+ p = 0
691
+ if padding_type == 'reflect':
692
+ conv_block += [nn.ReflectionPad2d(1)]
693
+ elif padding_type == 'replicate':
694
+ conv_block += [nn.ReplicationPad2d(1)]
695
+ elif padding_type == 'zero':
696
+ p = 1
697
+ else:
698
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
699
+
700
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
701
+ norm_layer(dim),
702
+ nn.ReLU(True)]
703
+ if use_dropout:
704
+ conv_block += [nn.Dropout(0.5)]
705
+
706
+ p = 0
707
+ if padding_type == 'reflect':
708
+ conv_block += [nn.ReflectionPad2d(1)]
709
+ elif padding_type == 'replicate':
710
+ conv_block += [nn.ReplicationPad2d(1)]
711
+ elif padding_type == 'zero':
712
+ p = 1
713
+ else:
714
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
715
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
716
+ norm_layer(dim)]
717
+
718
+ return nn.Sequential(*conv_block)
719
+
720
+ def forward(self, x):
721
+ out = x + self.conv_block(x)
722
+ return out
723
+
724
+
725
+ # Defines the Unet generator.
726
+ # |num_downs|: number of downsamplings in UNet. For example,
727
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
728
+ # at the bottleneck
729
+ class UnetGenerator(nn.Module):
730
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
731
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
732
+ super(UnetGenerator, self).__init__()
733
+
734
+ # construct unet structure
735
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
736
+ for i in range(num_downs - 5):
737
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
738
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
739
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
740
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
741
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
742
+
743
+ self.model = unet_block
744
+
745
+ def forward(self, input):
746
+ return self.model(input)
747
+
748
+ class UnetResGenerator(nn.Module):
749
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
750
+ norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
751
+ super(UnetResGenerator, self).__init__()
752
+
753
+ # construct unet structure
754
+ unet_block = UnetSkipConnectionResBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, nres=nres)
755
+ for i in range(num_downs - 5):
756
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
757
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
758
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
759
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
760
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
761
+
762
+ self.model = unet_block
763
+
764
+ def forward(self, input):
765
+ return self.model(input)
766
+
767
+ class PartUnet(nn.Module):
768
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
769
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
770
+ super(PartUnet, self).__init__()
771
+
772
+ # construct unet structure
773
+ # 3 downs
774
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
775
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
776
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
777
+
778
+ self.model = unet_block
779
+
780
+ def forward(self, input):
781
+ return self.model(input)
782
+
783
+ class PartUnetRes(nn.Module):
784
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
785
+ norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
786
+ super(PartUnetRes, self).__init__()
787
+
788
+ # construct unet structure
789
+ # 3 downs
790
+ unet_block = UnetSkipConnectionResBlock(ngf * 2, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, nres=nres)
791
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
792
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
793
+
794
+ self.model = unet_block
795
+
796
+ def forward(self, input):
797
+ return self.model(input)
798
+
799
+ class PartUnet2(nn.Module):
800
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
801
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
802
+ super(PartUnet2, self).__init__()
803
+
804
+ # construct unet structure
805
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
806
+ for i in range(num_downs - 3):
807
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
808
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
809
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
810
+
811
+ self.model = unet_block
812
+
813
+ def forward(self, input):
814
+ return self.model(input)
815
+
816
+ class PartUnet2Res(nn.Module):
817
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
818
+ norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
819
+ super(PartUnet2Res, self).__init__()
820
+
821
+ # construct unet structure
822
+ unet_block = UnetSkipConnectionResBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, nres=nres)
823
+ for i in range(num_downs - 3):
824
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
825
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
826
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
827
+
828
+ self.model = unet_block
829
+
830
+ def forward(self, input):
831
+ return self.model(input)
832
+
833
+ class PartUnet2Style(nn.Module):
834
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, extra_channel=2,
835
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
836
+ super(PartUnet2Style, self).__init__()
837
+ # construct unet structure
838
+ unet_block = UnetSkipConnectionStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, extra_channel=extra_channel)
839
+ for i in range(num_downs - 3):
840
+ unet_block = UnetSkipConnectionStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, extra_channel=extra_channel)
841
+ unet_block = UnetSkipConnectionStyleBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, extra_channel=extra_channel)
842
+ unet_block = UnetSkipConnectionStyleBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, extra_channel=extra_channel)
843
+
844
+ self.model = unet_block
845
+
846
+ def forward(self, input, cate):
847
+ return self.model(input, cate)
848
+
849
+ class PartUnet2ResStyle(nn.Module):
850
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, extra_channel=2,
851
+ norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
852
+ super(PartUnet2ResStyle, self).__init__()
853
+ # construct unet structure
854
+ unet_block = UnetSkipConnectionResStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, extra_channel=extra_channel, nres=nres)
855
+ for i in range(num_downs - 3):
856
+ unet_block = UnetSkipConnectionStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, extra_channel=extra_channel)
857
+ unet_block = UnetSkipConnectionStyleBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, extra_channel=extra_channel)
858
+ unet_block = UnetSkipConnectionStyleBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, extra_channel=extra_channel)
859
+
860
+ self.model = unet_block
861
+
862
+ def forward(self, input, cate):
863
+ return self.model(input, cate)
864
+
865
+
866
+ # Defines the submodule with skip connection.
867
+ # X -------------------identity---------------------- X
868
+ # |-- downsampling -- |submodule| -- upsampling --|
869
+ class UnetSkipConnectionBlock(nn.Module):
870
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
871
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
872
+ super(UnetSkipConnectionBlock, self).__init__()
873
+ self.outermost = outermost
874
+ if type(norm_layer) == functools.partial:
875
+ use_bias = norm_layer.func == nn.InstanceNorm2d
876
+ else:
877
+ use_bias = norm_layer == nn.InstanceNorm2d
878
+ if input_nc is None:
879
+ input_nc = outer_nc
880
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
881
+ stride=2, padding=1, bias=use_bias)
882
+ downrelu = nn.LeakyReLU(0.2, True)
883
+ downnorm = norm_layer(inner_nc)
884
+ uprelu = nn.ReLU(True)
885
+ upnorm = norm_layer(outer_nc)
886
+
887
+ if outermost:
888
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
889
+ kernel_size=4, stride=2,
890
+ padding=1)
891
+ down = [downconv]
892
+ up = [uprelu, upconv, nn.Tanh()]
893
+ model = down + [submodule] + up
894
+ elif innermost:
895
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
896
+ kernel_size=4, stride=2,
897
+ padding=1, bias=use_bias)
898
+ down = [downrelu, downconv]
899
+ up = [uprelu, upconv, upnorm]
900
+ model = down + up
901
+ else:
902
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
903
+ kernel_size=4, stride=2,
904
+ padding=1, bias=use_bias)
905
+ down = [downrelu, downconv, downnorm]
906
+ up = [uprelu, upconv, upnorm]
907
+
908
+ if use_dropout:
909
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
910
+ else:
911
+ model = down + [submodule] + up
912
+
913
+ self.model = nn.Sequential(*model)
914
+
915
+ def forward(self, x):
916
+ if self.outermost:
917
+ return self.model(x)
918
+ else:
919
+ return torch.cat([x, self.model(x)], 1)
920
+
921
+ class UnetSkipConnectionResBlock(nn.Module):
922
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
923
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
924
+ super(UnetSkipConnectionResBlock, self).__init__()
925
+ self.outermost = outermost
926
+ if type(norm_layer) == functools.partial:
927
+ use_bias = norm_layer.func == nn.InstanceNorm2d
928
+ else:
929
+ use_bias = norm_layer == nn.InstanceNorm2d
930
+ if input_nc is None:
931
+ input_nc = outer_nc
932
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
933
+ stride=2, padding=1, bias=use_bias)
934
+ downrelu = nn.LeakyReLU(0.2, True)
935
+ downnorm = norm_layer(inner_nc)
936
+ uprelu = nn.ReLU(True)
937
+ upnorm = norm_layer(outer_nc)
938
+
939
+ if outermost:
940
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
941
+ kernel_size=4, stride=2,
942
+ padding=1)
943
+ down = [downconv]
944
+ up = [uprelu, upconv, nn.Tanh()]
945
+ model = down + [submodule] + up
946
+ elif innermost:
947
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
948
+ kernel_size=4, stride=2,
949
+ padding=1, bias=use_bias)
950
+ down = [downrelu, downconv, downrelu]
951
+ up = [upconv, upnorm]
952
+ model = down
953
+ # resblock: conv norm relu conv norm +
954
+ for i in range(nres):
955
+ model += [ResnetBlock(inner_nc, padding_type='reflect', norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
956
+ model += up
957
+ #model = down + [submodule] + up
958
+ print('UnetSkipConnectionResBlock','nres',nres,'inner_nc',inner_nc)
959
+ else:
960
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
961
+ kernel_size=4, stride=2,
962
+ padding=1, bias=use_bias)
963
+ down = [downrelu, downconv, downnorm]
964
+ up = [uprelu, upconv, upnorm]
965
+
966
+ if use_dropout:
967
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
968
+ else:
969
+ model = down + [submodule] + up
970
+
971
+ self.model = nn.Sequential(*model)
972
+
973
+ def forward(self, x):
974
+ if self.outermost:
975
+ return self.model(x)
976
+ else:
977
+ return torch.cat([x, self.model(x)], 1)
978
+
979
+ class UnetSkipConnectionStyleBlock(nn.Module):
980
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
981
+ submodule=None, outermost=False, innermost=False,
982
+ extra_channel=2, norm_layer=nn.BatchNorm2d, use_dropout=False):
983
+ super(UnetSkipConnectionStyleBlock, self).__init__()
984
+ self.outermost = outermost
985
+ self.innermost = innermost
986
+ self.extra_channel = extra_channel
987
+ if type(norm_layer) == functools.partial:
988
+ use_bias = norm_layer.func == nn.InstanceNorm2d
989
+ else:
990
+ use_bias = norm_layer == nn.InstanceNorm2d
991
+ if input_nc is None:
992
+ input_nc = outer_nc
993
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
994
+ stride=2, padding=1, bias=use_bias)
995
+ downrelu = nn.LeakyReLU(0.2, True)
996
+ downnorm = norm_layer(inner_nc)
997
+ uprelu = nn.ReLU(True)
998
+ upnorm = norm_layer(outer_nc)
999
+
1000
+ if outermost:
1001
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1002
+ kernel_size=4, stride=2,
1003
+ padding=1)
1004
+ down = [downconv]
1005
+ up = [uprelu, upconv, nn.Tanh()]
1006
+ model = down + [submodule] + up
1007
+ elif innermost:
1008
+ upconv = nn.ConvTranspose2d(inner_nc+extra_channel, outer_nc,
1009
+ kernel_size=4, stride=2,
1010
+ padding=1, bias=use_bias)
1011
+ down = [downrelu, downconv]
1012
+ up = [uprelu, upconv, upnorm]
1013
+ model = down + up
1014
+ else:
1015
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1016
+ kernel_size=4, stride=2,
1017
+ padding=1, bias=use_bias)
1018
+ down = [downrelu, downconv, downnorm]
1019
+ up = [uprelu, upconv, upnorm]
1020
+
1021
+ if use_dropout:
1022
+ up = up + [nn.Dropout(0.5)]
1023
+ model = down + [submodule] + up
1024
+
1025
+ self.model = nn.Sequential(*model)
1026
+
1027
+ self.downmodel = nn.Sequential(*down)
1028
+ self.upmodel = nn.Sequential(*up)
1029
+ self.submodule = submodule
1030
+
1031
+ def forward(self, x, cate):# cate [bs,c]
1032
+ if self.innermost:
1033
+ y1 = self.downmodel(x)
1034
+ [bs,c,h,w] = y1.shape
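+ # broadcast the category vector over the bottleneck's spatial grid, then append it as extra channels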
1035
+ map = cate.repeat(h,w,1,1).permute([2,3,0,1])
1036
+ y2 = torch.cat([y1,map], 1)
1037
+ y3 = self.upmodel(y2)
1038
+ return torch.cat([x, y3], 1)
1039
+ else:
1040
+ y1 = self.downmodel(x)
1041
+ y2 = self.submodule(y1,cate)
1042
+ y3 = self.upmodel(y2)
1043
+ if self.outermost:
1044
+ return y3
1045
+ else:
1046
+ return torch.cat([x, y3], 1)
1047
+
1048
+ class UnetSkipConnectionResStyleBlock(nn.Module):
1049
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
1050
+ submodule=None, outermost=False, innermost=False,
1051
+ extra_channel=2, norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
1052
+ super(UnetSkipConnectionResStyleBlock, self).__init__()
1053
+ self.outermost = outermost
1054
+ self.innermost = innermost
1055
+ self.extra_channel = extra_channel
1056
+ if type(norm_layer) == functools.partial:
1057
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1058
+ else:
1059
+ use_bias = norm_layer == nn.InstanceNorm2d
1060
+ if input_nc is None:
1061
+ input_nc = outer_nc
1062
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
1063
+ stride=2, padding=1, bias=use_bias)
1064
+ downrelu = nn.LeakyReLU(0.2, True)
1065
+ downnorm = norm_layer(inner_nc)
1066
+ uprelu = nn.ReLU(True)
1067
+ upnorm = norm_layer(outer_nc)
1068
+
1069
+ if outermost:
1070
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1071
+ kernel_size=4, stride=2,
1072
+ padding=1)
1073
+ down = [downconv]
1074
+ up = [uprelu, upconv, nn.Tanh()]
1075
+ model = down + [submodule] + up
1076
+ elif innermost:
1077
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
1078
+ kernel_size=4, stride=2,
1079
+ padding=1, bias=use_bias)
1080
+ down = [downrelu, downconv, downrelu]
1081
+ up = [nn.Conv2d(inner_nc+extra_channel, inner_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
1082
+ norm_layer(inner_nc),
1083
+ nn.ReLU(True)]
1084
+ for i in range(nres):
1085
+ up += [ResnetBlock(inner_nc, padding_type='reflect', norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1086
+ up += [ upconv, upnorm]
1087
+ model = down + up
1088
+ print('UnetSkipConnectionResStyleBlock','nres',nres,'inner_nc',inner_nc)
1089
+ else:
1090
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1091
+ kernel_size=4, stride=2,
1092
+ padding=1, bias=use_bias)
1093
+ down = [downrelu, downconv, downnorm]
1094
+ up = [uprelu, upconv, upnorm]
1095
+
1096
+ if use_dropout:
1097
+ up = up + [nn.Dropout(0.5)]
1098
+ model = down + [submodule] + up
1099
+
1100
+ self.model = nn.Sequential(*model)
1101
+
1102
+ self.downmodel = nn.Sequential(*down)
1103
+ self.upmodel = nn.Sequential(*up)
1104
+ self.submodule = submodule
1105
+
1106
+ def forward(self, x, cate):# cate [bs,c]
1107
+ # concatenate in the innermost block
1108
+ if self.innermost:
1109
+ y1 = self.downmodel(x)
1110
+ [bs,c,h,w] = y1.shape
1111
+ map = cate.repeat(h,w,1,1).permute([2,3,0,1])
1112
+ y2 = torch.cat([y1,map], 1)
1113
+ y3 = self.upmodel(y2)
1114
+ return torch.cat([x, y3], 1)
1115
+ else:
1116
+ y1 = self.downmodel(x)
1117
+ y2 = self.submodule(y1,cate)
1118
+ y3 = self.upmodel(y2)
1119
+ if self.outermost:
1120
+ return y3
1121
+ else:
1122
+ return torch.cat([x, y3], 1)
1123
+
1124
+ # Defines the PatchGAN discriminator with the specified arguments.
1125
+ class NLayerDiscriminator(nn.Module):
1126
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
1127
+ super(NLayerDiscriminator, self).__init__()
1128
+ if type(norm_layer) == functools.partial:
1129
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1130
+ else:
1131
+ use_bias = norm_layer == nn.InstanceNorm2d
1132
+
1133
+ kw = 4
1134
+ padw = 1
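+ # stacked 4x4 stride-2 convs form a PatchGAN that scores overlapping image patches rather than the whole image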
1135
+ sequence = [
1136
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
1137
+ nn.LeakyReLU(0.2, True)
1138
+ ]
1139
+
1140
+ nf_mult = 1
1141
+ nf_mult_prev = 1
1142
+ for n in range(1, n_layers):
1143
+ nf_mult_prev = nf_mult
1144
+ nf_mult = min(2**n, 8)
1145
+ sequence += [
1146
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1147
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1148
+ norm_layer(ndf * nf_mult),
1149
+ nn.LeakyReLU(0.2, True)
1150
+ ]
1151
+
1152
+ nf_mult_prev = nf_mult
1153
+ nf_mult = min(2**n_layers, 8)
1154
+ sequence += [
1155
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1156
+ kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1157
+ norm_layer(ndf * nf_mult),
1158
+ nn.LeakyReLU(0.2, True)
1159
+ ]
1160
+
1161
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
1162
+
1163
+ if use_sigmoid:#no_lsgan, use sigmoid before calculating bceloss(binary cross entropy)
1164
+ sequence += [nn.Sigmoid()]
1165
+
1166
+ self.model = nn.Sequential(*sequence)
1167
+
1168
+ def forward(self, input):
1169
+ return self.model(input)
1170
+
1171
+
1172
+ class PixelDiscriminator(nn.Module):
1173
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
1174
+ super(PixelDiscriminator, self).__init__()
1175
+ if type(norm_layer) == functools.partial:
1176
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1177
+ else:
1178
+ use_bias = norm_layer == nn.InstanceNorm2d
1179
+
1180
+ self.net = [
1181
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
1182
+ nn.LeakyReLU(0.2, True),
1183
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
1184
+ norm_layer(ndf * 2),
1185
+ nn.LeakyReLU(0.2, True),
1186
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
1187
+
1188
+ if use_sigmoid:
1189
+ self.net.append(nn.Sigmoid())
1190
+
1191
+ self.net = nn.Sequential(*self.net)
1192
+
1193
+ def forward(self, input):
1194
+ return self.net(input)
APDrawingGAN2/models/test_model.py ADDED
@@ -0,0 +1,214 @@
1
+ from .base_model import BaseModel
2
+ from . import networks
3
+ import torch
4
+
5
+
6
+ class TestModel(BaseModel):
7
+ def name(self):
8
+ return 'TestModel'
9
+
10
+ @staticmethod
11
+ def modify_commandline_options(parser, is_train=True):
12
+ assert not is_train, 'TestModel cannot be used in train mode'
13
+ # uncomment because default CycleGAN did not use dropout ( parser.set_defaults(no_dropout=True) )
14
+ # parser = CycleGANModel.modify_commandline_options(parser, is_train=False)
15
+ parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False
16
+ parser.set_defaults(dataset_mode='single')
17
+ parser.set_defaults(auxiliary_root='auxiliaryeye2o')
18
+ parser.set_defaults(use_local=True, hair_local=True, bg_local=True)
19
+ parser.set_defaults(nose_ae=True, others_ae=True, compactmask=True, MOUTH_H=56)
20
+ parser.set_defaults(soft_border=1)
21
+ parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier')
22
+ parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator')
23
+
24
+ parser.add_argument('--model_suffix', type=str, default='',
25
+ help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will'
26
+ ' be loaded as the generator of TestModel')
27
+
28
+ return parser
29
+
30
+ def initialize(self, opt):
31
+ assert(not opt.isTrain)
32
+ BaseModel.initialize(self, opt)
33
+
34
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses
35
+ self.loss_names = []
36
+ # specify the images you want to save/display. The program will call base_model.get_current_visuals
37
+ self.visual_names = ['real_A', 'fake_B']
38
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
39
+ self.model_names = ['G' + opt.model_suffix]
40
+ self.auxiliary_model_names = []
41
+ if self.opt.use_local:
42
+ self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine']
43
+ self.auxiliary_model_names += ['CLm','CLh']
44
+ # auxiliary nets for local output refinement
45
+ if self.opt.nose_ae:
46
+ self.auxiliary_model_names += ['AE']
47
+ if self.opt.others_ae:
48
+ self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack']
49
+ print('model_names', self.model_names)
50
+ print('auxiliary_model_names', self.auxiliary_model_names)
51
+
52
+ # load/define networks
53
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
54
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
55
+ opt.nnG)
56
+ print('netG', opt.netG)
57
+ if self.opt.use_local:
58
+ netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
59
+ netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
60
+ netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
61
+ self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
62
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
63
+ self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
64
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
65
+ self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
66
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
67
+ self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
68
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
69
+ self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
70
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
71
+ extra_channel=3)
72
+ self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
73
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
74
+ # by default combiner_type is combiner, which uses resnet
75
+ print('combiner_type', self.opt.combiner_type)
76
+ self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm,
77
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2)
78
+ # auxiliary classifiers for mouth and hair
79
+ ratio = self.opt.fineSize / 256
80
+ self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
81
+ self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
82
+ self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
83
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
84
+ nnG = 3, ae_h = self.MOUTH_H, ae_w = self.MOUTH_W)
85
+ self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
86
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
87
+ nnG = opt.nnG_hairc, ae_h = opt.fineSize, ae_w = opt.fineSize)
88
+ # ==================================auxiliary nets (loaded, parameters fixed)=============================
89
+ if self.opt.use_local and self.opt.nose_ae:
90
+ ratio = self.opt.fineSize / 256
91
+ NOSE_H = self.opt.NOSE_H * ratio
92
+ NOSE_W = self.opt.NOSE_W * ratio
93
+ self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
94
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
95
+ latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W)
96
+ self.set_requires_grad(self.netAE, False)
97
+ if self.opt.use_local and self.opt.others_ae:
98
+ ratio = self.opt.fineSize / 256
99
+ EYE_H = self.opt.EYE_H * ratio
100
+ EYE_W = self.opt.EYE_W * ratio
101
+ MOUTH_H = self.opt.MOUTH_H * ratio
102
+ MOUTH_W = self.opt.MOUTH_W * ratio
103
+ self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
104
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
105
+ latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
106
+ self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
107
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
108
+ latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
109
+ self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
110
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
111
+ latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
112
+ self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
113
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
114
+ latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
115
+ self.set_requires_grad(self.netAEel, False)
116
+ self.set_requires_grad(self.netAEer, False)
117
+ self.set_requires_grad(self.netAEmowhite, False)
118
+ self.set_requires_grad(self.netAEmoblack, False)
119
+
120
+ # assigns the model to self.netG_[suffix] so that it can be loaded
121
+ # please see BaseModel.load_networks
122
+ setattr(self, 'netG' + opt.model_suffix, self.netG)
123
+
124
+ def set_input(self, input):
125
+ # we need to use single_dataset mode
126
+ self.real_A = input['A'].to(self.device)
127
+ self.image_paths = input['A_paths']
128
+ self.batch_size = len(self.image_paths)
129
+ if self.opt.use_local:
130
+ self.real_A_eyel = input['eyel_A'].to(self.device)
131
+ self.real_A_eyer = input['eyer_A'].to(self.device)
132
+ self.real_A_nose = input['nose_A'].to(self.device)
133
+ self.real_A_mouth = input['mouth_A'].to(self.device)
134
+ self.center = input['center']
135
+ if self.opt.soft_border:
136
+ self.softel = input['soft_eyel_mask'].to(self.device)
137
+ self.softer = input['soft_eyer_mask'].to(self.device)
138
+ self.softno = input['soft_nose_mask'].to(self.device)
139
+ self.softmo = input['soft_mouth_mask'].to(self.device)
140
+ if self.opt.compactmask:
141
+ self.cmask = input['cmask'].to(self.device)
142
+ self.cmask1 = self.cmask*2-1#[0,1]->[-1,1]
143
+ self.cmaskel = input['cmaskel'].to(self.device)
144
+ self.cmask1el = self.cmaskel*2-1
145
+ self.cmasker = input['cmasker'].to(self.device)
146
+ self.cmask1er = self.cmasker*2-1
147
+ self.cmaskmo = input['cmaskmo'].to(self.device)
148
+ self.cmask1mo = self.cmaskmo*2-1
149
+ self.real_A_hair = input['hair_A'].to(self.device)
150
+ self.mask = input['mask'].to(self.device) # mask for non-eyes,nose,mouth
151
+ self.mask2 = input['mask2'].to(self.device) # mask for non-bg
152
+ self.real_A_bg = input['bg_A'].to(self.device)
153
+
154
+ def getonehot(self,outputs,classes):
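+ # convert classifier logits to a one-hot vector by scattering a 1 at the argmax class index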
155
+ [maxv,index] = torch.max(outputs,1)
156
+ y = torch.unsqueeze(index,1)
157
+ onehot = torch.FloatTensor(self.batch_size,classes).to(self.device)
158
+ onehot.zero_()
159
+ onehot.scatter_(1,y,1)
160
+ return onehot
161
+
162
+ def forward(self):
163
+ if not self.opt.use_local:
164
+ self.fake_B = self.netG(self.real_A)
165
+ else:
166
+ self.fake_B0 = self.netG(self.real_A)
167
+ # EYES, MOUTH
168
+ outputs1 = self.netCLm(self.real_A_mouth)
169
+ onehot1 = self.getonehot(outputs1,2)
170
+
171
+ if not self.opt.others_ae:
172
+ fake_B_eyel = self.netGLEyel(self.real_A_eyel)
173
+ fake_B_eyer = self.netGLEyer(self.real_A_eyer)
174
+ fake_B_mouth = self.netGLMouth(self.real_A_mouth)
175
+ else: # use AE that only contains the compact region, need cmask!
176
+ self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
177
+ self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
178
+ self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
179
+ self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1)
180
+ self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1)
181
+ # USE 2 AEs
182
+ self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,self.MOUTH_W).to(self.device)
183
+ for i in range(self.batch_size):
184
+ if onehot1[i][0] == 1:
185
+ self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
186
+ #print('AEmowhite')
187
+ elif onehot1[i][1] == 1:
188
+ self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
189
+ #print('AEmoblack')
190
+ fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel)
191
+ fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker)
192
+ fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo)
193
+ # NOSE
194
+ if not self.opt.nose_ae:
195
+ fake_B_nose = self.netGLNose(self.real_A_nose)
196
+ else: # use AE that only contains the compact region, need cmask!
197
+ self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
198
+ self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1)
199
+ fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask)
200
+
201
+ # HAIR, BG AND PARTCOMBINE
202
+ outputs2 = self.netCLh(self.real_A_hair)
203
+ onehot2 = self.getonehot(outputs2,3)
204
+
205
+ fake_B_hair = self.netGLHair(self.real_A_hair,onehot2)
206
+ fake_B_bg = self.netGLBG(self.real_A_bg)
207
+ self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2)
208
+ self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2))
209
+ if not self.opt.compactmask:
210
+ self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op)
211
+ else:
212
+ self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op,self.opt.region_enm,self.cmaskel,self.cmasker,self.cmask,self.cmaskmo)
213
+
214
+ self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1))
APDrawingGAN2/options/__init__.py ADDED
File without changes
APDrawingGAN2/options/base_options.py ADDED
@@ -0,0 +1,192 @@
1
+ import argparse
2
+ import os
3
+ from util import util
4
+ import torch
5
+ import models
6
+ import data
7
+
8
+
9
+ class BaseOptions():
10
+ def __init__(self):
11
+ self.initialized = False
12
+
13
+ def initialize(self, parser):
14
+ parser.add_argument('--dataroot', type=str, default='', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
15
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
16
+ parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
17
+ parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
18
+ parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
19
+ parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels')
20
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
21
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
22
+ parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD')
23
+ parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG')
24
+ parser.add_argument('--nnG', type=int, default=9, help='specify nblocks for resnet_nblocks, or ndown for unet_ndown')
25
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
26
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
27
+ parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU')
28
+ parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
29
+ parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
30
+ parser.add_argument('--model', type=str, default='apdrawing',
31
+ help='chooses which model to use. cycle_gan, pix2pix, test, autoencoder')
32
+ parser.add_argument('--use_local', action='store_true', help='use local part network')
33
+ parser.add_argument('--lm_dir', type=str, default='dataset/landmark/', help='path to facial landmarks')
34
+ parser.add_argument('--nose_ae', action='store_true', help='use nose autoencoder')
35
+ parser.add_argument('--others_ae', action='store_true', help='use autoencoder for eyes and mouth too')
36
+ parser.add_argument('--nose_ae_net', type=str, default='autoencoderfc', help='net for nose autoencoder [autoencoder | autoencoderfc]')
37
+ parser.add_argument('--comb_op', type=int, default=1, help='use min-pooling(1) or max-pooling(0) for overlapping regions')
38
+ parser.add_argument('--hair_local', action='store_true', help='add hair part')
39
+ parser.add_argument('--bg_local', action='store_true', help='use background mask to separate background')
40
+ parser.add_argument('--bg_dir', default='dataset/mask/bg/', type=str, help='choose bg_dir')
41
+ parser.add_argument('--region_enm', type=int, default=0, help='region type for eyes nose mouth: 0 for rectangle, 1 for compact mask within rectangle, 2 for mask without rectangle (1,2 must have compactmask, 0 uses compactmask for AE)')
42
+ parser.add_argument('--soft_border', type=int, default=0, help='use mask with soft border')
43
+ parser.add_argument('--EYE_H', type=int, default=40, help='EYE_H')
44
+ parser.add_argument('--EYE_W', type=int, default=56, help='EYE_W')
45
+ parser.add_argument('--NOSE_H', type=int, default=48, help='NOSE_H')
46
+ parser.add_argument('--NOSE_W', type=int, default=48, help='NOSE_W')
47
+ parser.add_argument('--MOUTH_H', type=int, default=40, help='MOUTH_H')
48
+ parser.add_argument('--MOUTH_W', type=int, default=64, help='MOUTH_W')
49
+ parser.add_argument('--average_pos', action='store_true', help='use avg pos in partCombiner')
50
+ parser.add_argument('--combiner_type', type=str, default='combiner', help='choose combiner type')
51
+ parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
52
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
53
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
54
+ parser.add_argument('--auxiliary_root', type=str, default='auxiliary', help='auxiliary model folder')
55
+ parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
56
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
57
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
58
+ parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
59
+ parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
60
+ parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
61
+ parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
62
+ parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
63
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
64
+ parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
65
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
66
+ parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
67
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
68
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
69
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')
70
+ # compact mask
71
+ parser.add_argument('--compactmask', action='store_true', help='use compact mask as input and apply to loss')# "when you calculate the (ae) loss, you should also restrict to nose pixels"
72
+ parser.add_argument('--cmask_dir', type=str, default='dataset/mask/', help='compact mask directory')
73
+ parser.add_argument('--ae_latentno', type=int, default=1024 ,help='latent space dim for pretrained NOSE AEwithfc')
74
+ parser.add_argument('--ae_latentmo', type=int, default=1024 ,help='latent space dim for pretrained MOUTH AEwithfc')
75
+ parser.add_argument('--ae_latenteye', type=int, default=1024 ,help='latent space dim for pretrained EYEL/EYER AEwithfc')
76
+ parser.add_argument('--ae_small', type=int, default=0 ,help='use latent dim smaller than default 1024 in 4 AEs')
77
+ # below for autoencoder
78
+ parser.add_argument('--ae_latent', type=int, default=1024 ,help='latent space dim for autoencoderfc')
79
+ parser.add_argument('--ae_multiple', type=float, default=2 ,help='filter number change in ae encoder')
80
+ parser.add_argument('--ae_h', type=int, default=96 ,help='ae input h')
81
+ parser.add_argument('--ae_w', type=int, default=96 ,help='ae input w')
82
+ parser.add_argument('--ae_region', type=str, default='nose' ,help='autoencoder for which region')
83
+ parser.add_argument('--no_ae', action='store_true', help='no ae')
84
+ self.initialized = True
85
+ return parser
86
+
87
+ def gather_options(self):
88
+ # initialize parser with basic options
89
+ if not self.initialized:
90
+ parser = argparse.ArgumentParser(
91
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
92
+ parser = self.initialize(parser)
93
+
94
+ # get the basic options
95
+ opt, _ = parser.parse_known_args()
96
+
97
+ # modify model-related parser options
98
+ model_name = opt.model
99
+ model_option_setter = models.get_option_setter(model_name)
100
+ parser = model_option_setter(parser, self.isTrain)
101
+ opt, _ = parser.parse_known_args() # parse again with the new defaults
102
+
103
+ # modify dataset-related parser options
104
+ dataset_name = opt.dataset_mode
105
+ dataset_option_setter = data.get_option_setter(dataset_name)
106
+ parser = dataset_option_setter(parser, self.isTrain)
107
+
108
+ self.parser = parser
109
+
110
+ return parser.parse_args()
111
+
112
+ def print_options(self, opt):
113
+ message = ''
114
+ message += '----------------- Options ---------------\n'
115
+ for k, v in sorted(vars(opt).items()):
116
+ comment = ''
117
+ default = self.parser.get_default(k)
118
+ if v != default:
119
+ comment = '\t[default: %s]' % str(default)
120
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
121
+ message += '----------------- End -------------------'
122
+ print(message)
123
+
124
+ # save to the disk
125
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
126
+ util.mkdirs(expr_dir)
127
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
128
+ with open(file_name, 'wt') as opt_file:
129
+ opt_file.write(message)
130
+ opt_file.write('\n')
131
+
132
+ def parse(self, print=True):
133
+
134
+ opt = self.gather_options()
135
+ if opt.use_local:
136
+ opt.loadSize = opt.fineSize
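+ # keep loadSize equal to fineSize so that no cropping is applied when local part networks are used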
137
+ if opt.region_enm in [1,2]:
138
+ opt.compactmask = True
139
+ if opt.nose_ae or opt.others_ae:
140
+ opt.compactmask = True
141
+ if opt.ae_latentno < 1024 and opt.ae_latentmo < 1024 and opt.ae_latenteye < 1024:
142
+ opt.ae_small = 1
143
+ opt.isTrain = self.isTrain # train or test
144
+
145
+ # process opt.suffix
146
+ if opt.suffix:
147
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
148
+ opt.name = opt.name + suffix
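+ # e.g. --suffix {model}_{netG}_size{loadSize} expands the placeholders from opt and appends the result (with a leading '_') to opt.name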
149
+
150
+ if self.isTrain and opt.pretrain:
151
+ opt.nose_ae = False
152
+ opt.others_ae = False
153
+ opt.compactmask = False
154
+ opt.chamfer_loss = False
155
+ if not self.isTrain and opt.pretrain:
156
+ opt.nose_ae = False
157
+ opt.others_ae = False
158
+ opt.compactmask = False
159
+ if opt.no_ae:
160
+ opt.nose_ae = False
161
+ opt.others_ae = False
162
+ opt.compactmask = False
163
+ if self.isTrain and opt.no_dtremap:
164
+ opt.dt_nonlinear = ''
165
+ opt.lambda_chamfer = 0.1
166
+ opt.lambda_chamfer2 = 0.1
167
+ if self.isTrain and opt.no_dt:
168
+ opt.chamfer_loss = False
169
+
170
+ if print:
171
+ self.print_options(opt)
172
+
173
+ # set gpu ids
174
+ str_ids = opt.gpu_ids.split(',')
175
+ opt.gpu_ids = []
176
+ for str_id in str_ids:
177
+ id = int(str_id)
178
+ if id >= 0:
179
+ opt.gpu_ids.append(id)
180
+ if len(opt.gpu_ids) > 0:
181
+ torch.cuda.set_device(opt.gpu_ids[0])
182
+
183
+ # set second group of gpu ids (gpu_ids_p)
184
+ str_ids = opt.gpu_ids_p.split(',')
185
+ opt.gpu_ids_p = []
186
+ for str_id in str_ids:
187
+ id = int(str_id)
188
+ if id >= 0:
189
+ opt.gpu_ids_p.append(id)
190
+
191
+ self.opt = opt
192
+ return self.opt
APDrawingGAN2/options/test_options.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TestOptions(BaseOptions):
5
+ def initialize(self, parser):
6
+ parser = BaseOptions.initialize(self, parser)
7
+ parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8
+ parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
9
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11
+ parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
12
+ parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
13
+ parser.add_argument('--test_continuity_loss', action='store_true', help='get continuity value in test')
14
+ parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line')
15
+ parser.add_argument('--save2', action='store_true', help='only save real_A and fake_B')
16
+ parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images')
17
+ parser.add_argument('--pretrain', action='store_true', help='pretrain stage, no dt loss, no ae')
18
+
19
+ parser.set_defaults(model='test')
20
+ # To avoid cropping, the loadSize should be the same as fineSize
21
+ parser.set_defaults(loadSize=parser.get_default('fineSize'))
22
+ self.isTrain = False
23
+ return parser
APDrawingGAN2/options/train_options.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TrainOptions(BaseOptions):
5
+ def initialize(self, parser):
6
+ parser = BaseOptions.initialize(self, parser)
7
+ parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
8
+ parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
9
+ parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
10
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
11
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
12
+ parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
13
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
14
+ parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
15
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
16
+ parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
17
+ parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
18
+ parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
19
+ parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
20
+ parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
21
+ parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
22
+ parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
23
+ parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
24
+ parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
25
+ parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
26
+ # ============================================loss=========================================================
27
+ # chamfer loss
28
+ parser.add_argument('--chamfer_loss', action='store_true', help='use chamfer loss')
29
+ parser.add_argument('--chamfer_2way', action='store_true', help='use chamfer loss 2 way')
30
+ parser.add_argument('--chamfer_only_line', action='store_true', help='use chamfer only on lines')
31
+ parser.add_argument('--lambda_chamfer', type=float, default=0.1, help='weight for chamfer loss')
32
+ parser.add_argument('--lambda_chamfer2', type=float, default=0.1, help='weight for chamfer loss2')
33
+ parser.add_argument('--dt_nonlinear', type=str, default='', help='nonlinear remap on dt [atan | sigmoid | tanh]')
34
+ parser.add_argument('--dt_xmax', type=float, default=10, help='first multiply dt to range [0,xmax], then apply atan/sigmoid/tanh etc. to get more nonlinearity (there is not much nonlinearity in range [0,1])')
35
+ # line continuity loss
36
+ parser.add_argument('--continuity_loss', action='store_true', help='use line continuity loss')
37
+ parser.add_argument('--lambda_continuity', type=float, default=10.0, help='weight for continuity loss')
38
+ parser.add_argument('--emphasis_conti_face', action='store_true', help='constrain continuity loss to pixels on original lines (avoid applying it to background etc.)')
39
+ parser.add_argument('--facemask_dir', type=str, default='dataset/mask/face/', help='mask folder to constrain conti loss to pixels in original lines')
40
+ # =====================================auxiliary net structure===============================================
41
+ # dt & line net structure
42
+ parser.add_argument('--netG_dt', type=str, default='unet_512', help='selects model to use for netG_dt, for chamfer loss')
43
+ parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line, for chamfer loss')
44
+ # multiple discriminators
45
+ parser.add_argument('--discriminator_local', action='store_true', help='use six different local discriminators for 6 local regions')
46
+ parser.add_argument('--gan_loss_strategy', type=int, default=2, help='specify how to calculate the GAN loss for G; 1: average global and local discriminators; 2: keep the global discriminator weight unchanged and use 0.25 for local')
47
+ parser.add_argument('--addw_eye', type=float, default=1.0, help='additional weight for eye region')
48
+ parser.add_argument('--addw_nose', type=float, default=1.0, help='additional weight for nose region')
49
+ parser.add_argument('--addw_mouth', type=float, default=1.0, help='additional weight for mouth region')
50
+ parser.add_argument('--addw_hair', type=float, default=1.0, help='additional weight for hair region')
51
+ parser.add_argument('--addw_bg', type=float, default=1.0, help='additional weight for bg region')
52
+ # ==========================================ablation========================================================
53
+ parser.add_argument('--no_l1_loss', action='store_true', help='no l1 loss')
54
+ parser.add_argument('--no_G_local_loss', action='store_true', help='not using local transfer loss for local generator output')
55
+ parser.add_argument('--no_dtremap', action='store_true', help='no dt remap')
56
+ parser.add_argument('--no_dt', action='store_true', help='no dt')
57
+
58
+ parser.add_argument('--pretrain', action='store_true', help='pretrain stage, no dt loss, no ae')
59
+
60
+
61
+ self.isTrain = True
62
+ return parser
APDrawingGAN2/preprocess/combine_A_and_B.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import argparse
5
+
6
+ parser = argparse.ArgumentParser('create image pairs')
7
+ parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
8
+ parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
9
+ parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
10
+ parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
11
+ parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
12
+ args = parser.parse_args()
13
+
14
+ for arg in vars(args):
15
+ print('[%s] = ' % arg, getattr(args, arg))
16
+
17
+ splits = os.listdir(args.fold_A)
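+ # each subfolder of fold_A (e.g. train, test) is a split; the same split is expected under fold_B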
18
+
19
+ for sp in splits:
20
+ img_fold_A = os.path.join(args.fold_A, sp)
21
+ img_fold_B = os.path.join(args.fold_B, sp)
22
+ img_list = os.listdir(img_fold_A)
23
+ if args.use_AB:
24
+ img_list = [img_path for img_path in img_list if '_A.' in img_path]
25
+
26
+ num_imgs = min(args.num_imgs, len(img_list))
27
+ print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
28
+ img_fold_AB = os.path.join(args.fold_AB, sp)
29
+ if not os.path.isdir(img_fold_AB):
30
+ os.makedirs(img_fold_AB)
31
+ print('split = %s, number of images = %d' % (sp, num_imgs))
32
+ for n in range(num_imgs):
33
+ name_A = img_list[n]
34
+ path_A = os.path.join(img_fold_A, name_A)
35
+ if args.use_AB:
36
+ name_B = name_A.replace('_A.', '_B.')
37
+ else:
38
+ name_B = name_A
39
+ path_B = os.path.join(img_fold_B, name_B)
40
+ if os.path.isfile(path_A) and os.path.isfile(path_B):
41
+ name_AB = name_A
42
+ if args.use_AB:
43
+ name_AB = name_AB.replace('_A.', '.') # remove _A
44
+ path_AB = os.path.join(img_fold_AB, name_AB)
45
+ im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
46
+ im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
47
+ im_AB = np.concatenate([im_A, im_B], 1)
48
+ cv2.imwrite(path_AB, im_AB)
APDrawingGAN2/preprocess/example/img_1701.jpg ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned.png ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 194 248
2
+ 314 249
3
+ 261 312
4
+ 209 368
5
+ 302 371
APDrawingGAN2/preprocess/example/img_1701_aligned_68lm.txt ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 120 261
2
+ 124 294
3
+ 129 326
4
+ 133 358
5
+ 142 388
6
+ 162 412
7
+ 190 430
8
+ 220 445
9
+ 253 449
10
+ 287 447
11
+ 317 432
12
+ 344 411
13
+ 362 385
14
+ 370 354
15
+ 375 322
16
+ 382 291
17
+ 385 258
18
+ 142 225
19
+ 161 209
20
+ 188 204
21
+ 215 208
22
+ 242 218
23
+ 269 218
24
+ 296 208
25
+ 324 206
26
+ 351 213
27
+ 369 231
28
+ 256 244
29
+ 256 264
30
+ 256 284
31
+ 256 305
32
+ 232 324
33
+ 244 328
34
+ 256 332
35
+ 267 329
36
+ 277 325
37
+ 172 252
38
+ 186 243
39
+ 203 243
40
+ 218 253
41
+ 203 257
42
+ 186 257
43
+ 290 254
44
+ 305 244
45
+ 322 246
46
+ 336 255
47
+ 322 260
48
+ 305 259
49
+ 210 368
50
+ 229 358
51
+ 245 352
52
+ 256 354
53
+ 267 352
54
+ 283 358
55
+ 300 368
56
+ 284 382
57
+ 268 388
58
+ 255 389
59
+ 244 388
60
+ 228 381
61
+ 220 368
62
+ 245 363
63
+ 256 364
64
+ 267 364
65
+ 290 368
66
+ 267 370
67
+ 255 372
68
+ 244 371
APDrawingGAN2/preprocess/example/img_1701_aligned_bgmask.png ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_eyelmask.png ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_eyermask.png ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_facemask.png ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_mouthmask.png ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_nosemask.png ADDED
APDrawingGAN2/preprocess/example/img_1701_facial5point.mat ADDED
Binary file (230 Bytes). View file
 
APDrawingGAN2/preprocess/face_align_512.m ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function [trans_img,trans_facial5point]=face_align_512(impath,facial5point,savedir)
2
+ % align the faces by similarity transformation.
3
+ % using 5 facial landmarks: 2 eyes, nose, 2 mouth corners.
4
+ % impath: path to image
5
+ % facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN
6
+ % savedir: savedir for cropped image and transformed facial landmarks
7
+
8
+ %% alignment settings
9
+ imgSize = [512,512];
10
+ coord5point = [180,230;
11
+ 300,230;
12
+ 240,301;
13
+ 186,365.6;
14
+ 294,365.6];%480x480
15
+ coord5point = (coord5point-240)/560 * 512 + 256;
16
+
17
+ %% face alignment
18
+
19
+ % load and align, resize image to imgSize
20
+ img = imread(impath);
21
+ facial5point = double(facial5point);
22
+ transf = cp2tform(facial5point, coord5point, 'similarity');
23
+ trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],...
24
+ 'YData', [1 imgSize(1)],...
25
+ 'Size', imgSize,...
26
+ 'FillValues', [255;255;255]);
27
+ trans_facial5point = round(tformfwd(transf,facial5point));
28
+
29
+
30
+ %% save results
31
+ if ~exist(savedir,'dir')
32
+ mkdir(savedir)
33
+ end
34
+ [~,name,~] = fileparts(impath);
35
+ % save trans_img
36
+ imwrite(trans_img, fullfile(savedir,[name,'_aligned.png']));
37
+ fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_aligned.png']));
38
+ % save trans_facial5point
39
+ write_5pt(fullfile(savedir, [name, '_aligned.txt']), trans_facial5point);
40
+ fprintf('write transformed facial landmark to %s\n',fullfile(savedir,[name,'_aligned.txt']));
41
+
42
+ %% show results
43
+ imshow(trans_img); hold on;
44
+ plot(trans_facial5point(:,1),trans_facial5point(:,2),'b');
45
+ plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+');
46
+
47
+ end
48
+
49
+ function [] = write_5pt(fn, trans_pt)
50
+ fid = fopen(fn, 'w');
51
+ for i = 1:5
52
+ fprintf(fid, '%d %d\n', trans_pt(i,1), trans_pt(i,2));%will be read as np.int32
53
+ end
54
+ fclose(fid);
55
+ end
APDrawingGAN2/preprocess/get_partmask.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os, glob, csv, shutil
3
+ import numpy as np
4
+ import dlib
5
+ import math
6
+ from shapely.geometry import Point
7
+ from shapely.geometry import Polygon
8
+ import sys
9
+
10
+ detector = dlib.get_frontal_face_detector()
11
+ predictor = dlib.shape_predictor('../checkpoints/shape_predictor_68_face_landmarks.dat')
12
+
13
+ def getfeats(featpath):
14
+ trans_points = np.empty([68,2],dtype=np.int64)
15
+ with open(featpath, 'r') as csvfile:
16
+ reader = csv.reader(csvfile, delimiter=' ')
17
+ for ind,row in enumerate(reader):
18
+ trans_points[ind,:] = row
19
+ return trans_points
20
+
21
+ def getinternal(lm1,lm2):
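+ # enumerate the integer grid points strictly between landmarks lm1 and lm2, stepping along the axis with the larger extent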
22
+ lminternal = []
23
+ if abs(lm1[1]-lm2[1]) > abs(lm1[0]-lm2[0]):
24
+ if lm1[1] > lm2[1]:
25
+ tmp = lm1
26
+ lm1 = lm2
27
+ lm2 = tmp
28
+ for y in range(lm1[1]+1,lm2[1]):
29
+ x = int(round(float(y-lm1[1])/(lm2[1]-lm1[1])*(lm2[0]-lm1[0])+lm1[0]))
30
+ lminternal.append((x,y))
31
+ else:
32
+ if lm1[0] > lm2[0]:
33
+ tmp = lm1
34
+ lm1 = lm2
35
+ lm2 = tmp
36
+ for x in range(lm1[0]+1,lm2[0]):
37
+ y = int(round(float(x-lm1[0])/(lm2[0]-lm1[0])*(lm2[1]-lm1[1])+lm1[1]))
38
+ lminternal.append((x,y))
39
+ return lminternal
40
+
41
+ def mulcross(p,x_1,x):#p-x_1,x-x_1
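+ # z-component of the 2D cross product (p-x_1) x (x-x_1); its sign tells on which side of the edge x_1->x the point p lies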
42
+ vp = [p[0]-x_1[0],p[1]-x_1[1]]
43
+ vq = [x[0]-x_1[0],x[1]-x_1[1]]
44
+ return vp[0]*vq[1]-vp[1]*vq[0]
45
+
46
+ def shape_to_np(shape, dtype="int"):
47
+ # initialize the list of (x, y)-coordinates
48
+ coords = np.zeros((shape.num_parts, 2), dtype=dtype)
49
+ # loop over all facial landmarks and convert them
50
+ # to a 2-tuple of (x, y)-coordinates
51
+ for i in range(0, shape.num_parts):
52
+ coords[i] = (shape.part(i).x, shape.part(i).y)
53
+ # return the list of (x, y)-coordinates
54
+ return coords
55
+
56
+ def get_68lm(imgfile,savepath):
57
+ image = cv2.imread(imgfile)
58
+ rgbImg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
59
+ rects = detector(rgbImg, 1)
60
+ for (i, rect) in enumerate(rects):
61
+ landmarks = predictor(rgbImg, rect)
62
+ landmarks = shape_to_np(landmarks)
63
+ f = open(savepath,'w')
64
+ for i in range(len(landmarks)):
65
+ lm = landmarks[i]
66
+ print(lm[0], lm[1], file=f)
67
+ f.close()
68
+
69
+ def get_partmask(imgfile,part,lmpath,savefile):
70
+ img = cv2.imread(imgfile)
71
+ mask = np.zeros(img.shape, np.uint8)
72
+ lms = getfeats(lmpath)
73
+
74
+ if os.path.exists(savefile):
75
+ return
76
+
77
+ if part == 'nose':
78
+ # 27,31....,35 -> up, left, right, lower5 -- eight points
79
+ up = [int(round(1.2*lms[27][0]-0.2*lms[33][0])),int(round(1.2*lms[27][1]-0.2*lms[33][1]))]
80
+ lower5 = [[0,0]]*5
81
+ for i in range(31,36):
82
+ lower5[i-31] = [int(round(1.1*lms[i][0]-0.1*lms[27][0])),int(round(1.1*lms[i][1]-0.1*lms[27][1]))]
83
+ ratio = 2.5
84
+ left = [int(round(ratio*lower5[0][0]-(ratio-1)*lower5[1][0])),int(round(ratio*lower5[0][1]-(ratio-1)*lower5[1][1]))]
85
+ right = [int(round(ratio*lower5[4][0]-(ratio-1)*lower5[3][0])),int(round(ratio*lower5[4][1]-(ratio-1)*lower5[3][1]))]
86
+ loop = [up,left,lower5[0],lower5[1],lower5[2],lower5[3],lower5[4],right]
87
+ elif part == 'eyel':
88
+ height = max(lms[41][1]-lms[37][1],lms[40][1]-lms[38][1])
89
+ width = lms[39][0]-lms[36][0]
90
+ ratio = 0.1
91
+ gap = int(math.ceil(width*ratio))
92
+ ratio2 = 0.6
93
+ gaph = int(math.ceil(height*ratio2))
94
+ ratio3 = 1.5
95
+ gaph2 = int(math.ceil(height*ratio3))
96
+ upper = [[lms[17][0]-2*gap,lms[17][1]],[lms[17][0]-2*gap,lms[17][1]-gaph],[lms[18][0],lms[18][1]-gaph],[lms[19][0],lms[19][1]-gaph],[lms[20][0],lms[20][1]-gaph],[lms[21][0]+gap*2,lms[21][1]-gaph]]
97
+ lower = [[lms[39][0]+gap,lms[40][1]+gaph2],[lms[40][0],lms[40][1]+gaph2],[lms[41][0],lms[41][1]+gaph2],[lms[36][0]-2*gap,lms[41][1]+gaph2]]
98
+ loop = upper + lower
99
+ loop.reverse()
100
+ elif part == 'eyer':
101
+ height = max(lms[47][1]-lms[43][1],lms[46][1]-lms[44][1])
102
+ width = lms[45][0]-lms[42][0]
103
+ ratio = 0.1
104
+ gap = int(math.ceil(width*ratio))
105
+ ratio2 = 0.6
106
+ gaph = int(math.ceil(height*ratio2))
107
+ ratio3 = 1.5
108
+ gaph2 = int(math.ceil(height*ratio3))
109
+ upper = [[lms[22][0]-2*gap,lms[22][1]],[lms[22][0]-2*gap,lms[22][1]-gaph],[lms[23][0],lms[23][1]-gaph],[lms[24][0],lms[24][1]-gaph],[lms[25][0],lms[25][1]-gaph],[lms[26][0]+gap*2,lms[26][1]-gaph]]
110
+ lower = [[lms[45][0]+2*gap,lms[46][1]+gaph2],[lms[46][0],lms[46][1]+gaph2],[lms[47][0],lms[47][1]+gaph2],[lms[42][0]-gap,lms[42][1]+gaph2]]
111
+ loop = upper + lower
112
+ loop.reverse()
113
+ elif part == 'mouth':
114
+ height = lms[62][1]-lms[51][1]
115
+ width = lms[54][0]-lms[48][0]
116
+ ratio = 1
117
+ ratio2 = 0.2#0.1
118
+ gaph = int(math.ceil(ratio*height))
119
+ gapw = int(math.ceil(ratio2*width))
120
+ left = [(lms[48][0]-gapw,lms[48][1])]
121
+ upper = [(lms[i][0], lms[i][1]-gaph) for i in range(48,55)]
122
+ right = [(lms[54][0]+gapw,lms[54][1])]
123
+ lower = [(lms[i][0], lms[i][1]+gaph) for i in list(range(54,60))+[48]]
124
+ loop = left + upper + right + lower
125
+ loop.reverse()
126
+ pl = Polygon(loop)
127
+
128
+ for i in range(mask.shape[0]):
129
+ for j in range(mask.shape[1]):
130
+ if part != 'mouth' and part != 'jaw':
131
+ p = [j,i]
132
+ flag = 1
133
+ for k in range(len(loop)):
134
+ if mulcross(p,loop[k],loop[(k+1)%len(loop)]) < 0:#y downside... >0 represents counter-clockwise, <0 clockwise
135
+ flag = 0
136
+ break
137
+ else:
138
+ p = Point(j,i)
139
+ flag = pl.contains(p)
140
+ if flag:
141
+ mask[i,j] = [255,255,255]
142
+ if not os.path.exists(os.path.dirname(savefile)):
143
+ os.mkdir(os.path.dirname(savefile))
144
+ cv2.imwrite(savefile,mask)
145
+
146
+ if __name__ == '__main__':
147
+ imgfile = 'example/img_1701_aligned.png'
148
+ lmfile = 'example/img_1701_aligned_68lm.txt'
149
+ get_68lm(imgfile,lmfile)
150
+ for part in ['eyel','eyer','nose','mouth']:
151
+ savepath = 'example/img_1701_aligned_'+part+'mask.png'
152
+ get_partmask(imgfile,part,lmfile,savepath)
APDrawingGAN2/preprocess/readme.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Preprocessing steps
2
+
3
+ Both training and testing images need:
4
+
5
+ - align to 512x512
6
+ - facial landmarks
7
+ - masks for eyes, nose, mouth, background
8
+
9
+ Training images additionally need:
10
+
11
+ - mask for face region
12
+
13
+
14
+ ### 1. Align, resize, crop images to 512x512, and get facial landmarks
15
+
16
+ All training and testing images in our model are aligned using facial landmarks, and the landmarks after alignment are needed by our code.
17
+
18
+ - First, 5 facial landmarks need to be detected for the face photo (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment) (MTCNNv1)).
19
+
20
+ - Then, we provide a MATLAB function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align the image.
21
+ For example, for `img_1701.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/img_1701_facial5point.mat`. Call the following in MATLAB:
22
+ ```matlab
23
+ load('example/img_1701_facial5point.mat');
24
+ [trans_img,trans_facial5point]=face_align_512('example/img_1701.jpg',facial5point,'example');
25
+ ```
26
+
27
+ This will align the image and output the aligned image and the transformed facial landmarks (in txt format) to the `example` folder.
28
+ See `face_align_512.m` for more instructions; a rough Python equivalent is sketched below.
29
+
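+ If MATLAB is not available, the same similarity-transform alignment can be approximated in Python. This is only a sketch under the assumption that `scikit-image` is installed and that `facial5point` is a 5x2 array of (x, y) MTCNN landmarks; the official preprocessing uses the MATLAB function above.
+ ```python
+ import numpy as np
+ from skimage import io, transform
+
+ def face_align_512_py(impath, facial5point, out_path):
+     # template positions of the 5 landmarks (same constants as face_align_512.m)
+     dst = np.array([[180, 230], [300, 230], [240, 301],
+                     [186, 365.6], [294, 365.6]], dtype=np.float64)
+     dst = (dst - 240) / 560 * 512 + 256          # rescale the 480x480 template to 512x512
+     src = np.asarray(facial5point, dtype=np.float64)
+     tform = transform.SimilarityTransform()
+     tform.estimate(src, dst)                      # fit the src -> dst similarity transform
+     img = io.imread(impath)
+     aligned = transform.warp(img, tform.inverse,  # warp expects the output->input map
+                              output_shape=(512, 512), cval=1.0)
+     io.imsave(out_path, (aligned * 255).astype(np.uint8))
+     trans5 = np.rint(tform(src)).astype(np.int32) # transformed landmarks, saved as ints
+     return aligned, trans5
+ ```
+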
30
+ The saved transformed facial landmarks need to be copied to `dataset/landmark/` and must have the **same filename** as the aligned face photo (e.g. `dataset/data/test_single/31.png` should have landmark file `dataset/landmark/31.txt`).
31
+
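+ The landmark file simply contains five "x y" integer pairs, one per line, so it can be loaded directly, e.g. (a minimal sketch):
+ ```python
+ import numpy as np
+ lm5 = np.loadtxt('dataset/landmark/31.txt', dtype=np.int32)  # shape (5, 2): 2 eyes, nose, 2 mouth corners
+ ```
+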
32
+ ### 2. Prepare background masks
33
+
34
+ In our work, the background mask is segmented using the method in
35
+ "Automatic Portrait Segmentation for Image Stylization"
36
+ Xiaoyong Shen, Aaron Hertzmann, Jiaya Jia, Sylvain Paris, Brian Price, Eli Shechtman, Ian Sachs. Computer Graphics Forum, 35(2)(Proc. Eurographics), 2016.
37
+
38
+ We use the code at http://xiaoyongshen.me/webpage_portrait/index.html to detect background masks for aligned face photos.
39
+ An example background mask is shown in `example/img_1701_aligned_bgmask.png`.
40
+
41
+ The background masks need to be copied to `dataset/mask/bg/` and must have the **same filename** as the aligned face photo (e.g. `dataset/data/test_single/31.png` should have background mask `dataset/mask/bg/31.png`).
42
+
43
+ ### 3. Prepare eyes/nose/mouth masks
44
+
45
+ We use dlib to extract 68 landmarks for aligned face photos, and use these landmarks to get masks for local regions.
46
+ See the example in `get_partmask.py`: the eye, nose and mouth masks for `example/img_1701_aligned.png` are `example/img_1701_aligned_[part]mask.png`, where part is one of [eyel,eyer,nose,mouth].
47
+
48
+ The part masks need to be copied to `dataset/mask/[part]/` and must have the **same filename** as the aligned face photo.
49
+
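+ For your own aligned photos, the masks can be produced by calling the two helpers in `get_partmask.py` the same way its `__main__` block does (a minimal sketch; the output paths below follow the `dataset/mask/[part]/` convention and the landmark filename is illustrative):
+ ```python
+ from get_partmask import get_68lm, get_partmask
+
+ imgfile = 'dataset/data/test_single/31.png'   # aligned 512x512 face photo
+ lmfile = '31_68lm.txt'                        # hypothetical path for the 68-landmark file
+ get_68lm(imgfile, lmfile)                     # detect 68 dlib landmarks and save them
+ for part in ['eyel', 'eyer', 'nose', 'mouth']:
+     get_partmask(imgfile, part, lmfile, 'dataset/mask/%s/31.png' % part)
+ ```
+ Note that `get_partmask.py` expects the dlib shape predictor at `../checkpoints/shape_predictor_68_face_landmarks.dat` (see the top of the script).
+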
50
+ ### 4. (For training) Prepare face masks
51
+
52
+ We use the face parsing net in https://github.com/cientgu/Mask_Guided_Portrait_Editing to detect the face region.
53
+ The face parsing net labels each face image into 11 classes: 0 is background, 10 is hair, and 1~9 are face regions.
54
+ An example face mask is shown in `example/img_1701_aligned_facemask.png`.
55
+
56
+ The face masks need to be copied to `dataset/mask/face/` and must have the **same filename** as the aligned face photo.
57
+
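+ Given a per-pixel parsing result with labels 0-10, the binary face mask can be obtained by keeping classes 1~9 (a minimal sketch; the parsing-result path is illustrative):
+ ```python
+ import cv2
+ import numpy as np
+
+ parsing = cv2.imread('31_parsing.png', cv2.IMREAD_GRAYSCALE)  # hypothetical label map, values 0-10
+ facemask = np.where((parsing >= 1) & (parsing <= 9), 255, 0).astype(np.uint8)
+ cv2.imwrite('dataset/mask/face/31.png', facemask)
+ ```
+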
58
+ ### 5. (For training) Combine A and B
59
+
60
+ We provide a Python script to generate training data in the form of pairs of images {A,B}, i.e. pairs {face photo, drawing}. This script concatenates each pair of images horizontally into a single image, so that we can learn to translate A to B:
61
+
62
+ Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `test`, etc. In `/path/to/data/A/train`, put training face photos. In `/path/to/data/B/train`, put the corresponding artist drawings. Repeat the same for `test`.
63
+
64
+ Corresponding images in a pair {A,B} must both be aligned images of size 512x512 and have the same filename, e.g., `/path/to/data/A/train/1.png` is considered to correspond to `/path/to/data/B/train/1.png`.
65
+
66
+ Once the data is formatted this way, call:
67
+ ```bash
68
+ python preprocess/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
69
+ ```
70
+
71
+ This will combine each pair of images (A,B) into a single image file, ready for training.
APDrawingGAN2/readme.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # APDrawingGAN++
3
+
4
+ We provide PyTorch implementations for our TPAMI paper "Line Drawings for Face Portraits from Photos using Global and Local Structure based GANs".
5
+ It is a journal extension of our previous CVPR 2019 work [APDrawingGAN](https://github.com/yiranran/APDrawingGAN).
6
+
7
+ This project generates artistic portrait drawings from face photos using a GAN-based model.
8
+ You may find useful information in [preprocessing steps](preprocess/readme.md) and [training/testing tips](docs/tips.md).
9
+
10
+ [[Jittor implementation]](https://github.com/yiranran/APDrawingGAN2-Jittor)
11
+
12
+ ## Our Proposed Framework
13
+
14
+ <img src = 'imgs/architecture-pami.jpg'>
15
+
16
+ ## Sample Results
17
+ Up: input, Down: output
18
+ <p>
19
+ <img src='imgs/sample/140_large-img_1696_real_A.png' width="16%"/>
20
+ <img src='imgs/sample/140_large-img_1615_real_A.png' width="16%"/>
21
+ <img src='imgs/sample/140_large-img_1684_real_A.png' width="16%"/>
22
+ <img src='imgs/sample/140_large-img_1616_real_A.png' width="16%"/>
23
+ <img src='imgs/sample/140_large-img_1673_real_A.png' width="16%"/>
24
+ <img src='imgs/sample/140_large-img_1701_real_A.png' width="16%"/>
25
+ </p>
26
+ <p>
27
+ <img src='imgs/sample/140_large-img_1696_fake_B.png' width="16%"/>
28
+ <img src='imgs/sample/140_large-img_1615_fake_B.png' width="16%"/>
29
+ <img src='imgs/sample/140_large-img_1684_fake_B.png' width="16%"/>
30
+ <img src='imgs/sample/140_large-img_1616_fake_B.png' width="16%"/>
31
+ <img src='imgs/sample/140_large-img_1673_fake_B.png' width="16%"/>
32
+ <img src='imgs/sample/140_large-img_1701_fake_B.png' width="16%"/>
33
+ </p>
34
+
35
+ ## Citation
36
+ If you use this code for your research, please cite our paper.
37
+ ```
38
+ @article{YiXLLR20,
39
+ title = {Line Drawings for Face Portraits from Photos using Global and Local Structure based {GAN}s},
40
+ author = {Yi, Ran and Xia, Mengfei and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},
41
+ journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
42
+ doi = {10.1109/TPAMI.2020.2987931},
43
+ year = {2020}
44
+ }
45
+ ```
46
+
47
+ ## Prerequisites
48
+ - Linux or macOS
49
+ - Python 2 or 3
50
+ - CPU or NVIDIA GPU + CUDA CuDNN
51
+
52
+
53
+ ## Getting Started
54
+ ### 1. Installation
55
+ ```bash
56
+ pip install -r requirements.txt
57
+ ```
58
+
59
+ ### 2. Quick Start (Apply a Pre-trained Model)
60
+ - Download APDrawing dataset from [BaiduYun](https://pan.baidu.com/s/1cN5gEYJ2tnE9WboLA79Z5g)(extract code:0zuv) or [YandexDrive](https://yadi.sk/d/4vWhi8-ZQj_nRw), and extract to `dataset`.
61
+
62
+ - Download pre-trained models and auxiliary nets from [BaiduYun](https://pan.baidu.com/s/1nrtCHQmgcwbSGxWuAVzWhA)(extract code:imqp) or [YandexDrive](https://yadi.sk/d/DS4271lbEPhGVQ), and extract to `checkpoints`.
63
+
64
+ - Generate artistic portrait drawings for example photos in `dataset/test_single` using
65
+ ``` bash
66
+ python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
67
+ ```
68
+ The test results will be saved to an HTML file here: `./results/apdrawinggan++_author/test_150/index-single.html`.
69
+
70
+ - If you want to test on your own data, please first align your pictures and prepare their facial landmarks and masks according to the tutorial in [preprocessing steps](preprocess/readme.md), then change the --dataroot flag above to your directory of aligned photos.
71
+
72
+ ### 3. Train
73
+ - Run `python -m visdom.server`
74
+ - Train a model (with pre-training as initialization):
75
+ first copy "pre2" models into checkpoints dir of current experiment, e.g. `checkpoints/apdrawinggan++_1`.
76
+ ```bash
77
+ mkdir checkpoints/apdrawinggan++_1/
78
+ cp checkpoints/pre2/*.pt checkpoints/apdrawinggan++_1/
79
+ python train.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_1 --model apdrawingpp_style --use_resnet --netG resnet_9blocks --continue_train --continuity_loss --lambda_continuity 40.0 --gpu_ids 0 --gpu_ids_p 1 --display_env apdrawinggan++_1 --niter 200 --niter_decay 0 --lr 0.0001 --batch_size 1 --emphasis_conti_face --auxiliary_root auxiliaryeye2o
80
+ ```
81
+ - To view training results and loss plots, click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/apdrawinggan++_1/web/index.html`
82
+
83
+ ### 4. Test
84
+ - To test the model on the test set:
85
+ ```bash
86
+ python test.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_author --model apdrawingpp_style --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-apd70
87
+ ```
88
+ The test results will be saved to an HTML file: `./results/apdrawinggan++_author/test_150/index-apd70.html`.
89
+
90
+ - To test the model on images without paired ground truth, proceed as in step 2 (Apply a Pre-trained Model) above.
91
+
92
+ You can find these scripts in the `script` directory.
93
+
94
+
95
+ ## [Preprocessing Steps](preprocess/readme.md)
96
+ Preprocessing steps for your own data (either for testing or training).
97
+
98
+
99
+ ## [Training/Test Tips](docs/tips.md)
100
+ Best practice for training and testing your models.
101
+
102
+ You can contact yr16@mails.tsinghua.edu.cn with any questions.
103
+
104
+ ## Acknowledgments
105
+ Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
APDrawingGAN2/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.1.0
2
+ torchvision==0.4.0
3
+ dominate==2.4.0
4
+ visdom==0.1.8.9
5
+ scipy==1.1.0
6
+ numpy==1.16.4
7
+ Pillow==4.3.0
8
+ opencv-python==4.1.0.25
9
+ dlib==19.18.0
10
+ shapely==1.7.0
APDrawingGAN2/script/test.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ set -ex
2
+ python test.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_author --model apdrawingpp_style --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-apd70
APDrawingGAN2/script/test_single.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ set -ex
2
+ python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
APDrawingGAN2/script/train.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ set -ex
2
+ python train.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_1 --model apdrawingpp_style --use_resnet --netG resnet_9blocks --continue_train --continuity_loss --lambda_continuity 40.0 --gpu_ids 0 --gpu_ids_p 1 --display_env apdrawinggan++_1 --niter 200 --niter_decay 0 --lr 0.0001 --batch_size 1 --emphasis_conti_face --auxiliary_root auxiliaryeye2o
3
+
APDrawingGAN2/test.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from options.test_options import TestOptions
3
+ from data import CreateDataLoader
4
+ from models import create_model
5
+ from util.visualizer import save_images
6
+ from util import html
7
+
8
+
9
+ if __name__ == '__main__':
10
+ opt = TestOptions().parse()
11
+ opt.num_threads = 1 # test code only supports num_threads = 1
12
+ opt.batch_size = 1 # test code only supports batch_size = 1
13
+ opt.serial_batches = True # no shuffle
14
+ opt.no_flip = True # no flip
15
+ opt.display_id = -1 # no visdom display
16
+ data_loader = CreateDataLoader(opt)
17
+ dataset = data_loader.load_data()
18
+ model = create_model(opt)
19
+ model.setup(opt)
20
+ # create website
21
+ web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
22
+ #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
23
+ webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch),reflesh=0, folder=opt.imagefolder)
24
+ if opt.test_continuity_loss:
25
+ file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt')
26
+ file_name1 = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity-r.txt')
27
+ if os.path.exists(file_name):
28
+ os.remove(file_name)
29
+ if os.path.exists(file_name1):
30
+ os.remove(file_name1)
31
+ # test
32
+ #model.eval()
33
+ for i, data in enumerate(dataset):
34
+ if i >= opt.how_many:#test code only supports batch_size = 1, how_many means how many test images to run
35
+ break
36
+ model.set_input(data)
37
+ model.test()
38
+ visuals = model.get_current_visuals()#in test the loadSize is set to the same as fineSize
39
+ img_path = model.get_image_paths()
40
+ #if i % 5 == 0:
41
+ # print('processing (%04d)-th image... %s' % (i, img_path))
42
+ save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
43
+
44
+ webpage.save()
45
+ if opt.model == 'regressor':
46
+ print(model.cnt)
47
+ print(model.value/model.cnt)
48
+ print(model.minval)
49
+ print(model.avg/model.cnt)
50
+ print(model.max)
51
+ html = os.path.join(web_dir,'cindex'+opt.imagefolder[6:]+'.html')
52
+ f=open(html,'w')
53
+ print('<table border="1" style=\"text-align:center;\">',file=f,end='')
54
+ print('<tr>',file=f,end='')
55
+ print('<td>image name</td>',file=f,end='')
56
+ print('<td>realA</td>',file=f,end='')
57
+ print('<td>realB</td>',file=f,end='')
58
+ print('<td>fakeB</td>',file=f,end='')
59
+ print('</tr>',file=f,end='')
60
+ for info in model.info:
61
+ basen = os.path.basename(info[0])[:-4]
62
+ print('<tr>',file=f,end='')
63
+ print('<td>%s</td>'%basen,file=f,end='')
64
+ print('<td><img src=\"%s/%s_real_A.png\" style=\"width:44px\"></td>'%(opt.imagefolder,basen),file=f,end='')
65
+ print('<td>%.4f</td>'%info[1],file=f,end='')
66
+ print('<td>%.4f</td>'%info[2],file=f,end='')
67
+ print('</tr>',file=f,end='')
68
+ print('</table>',file=f,end='')
69
+ f.close()
APDrawingGAN2/train.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from options.train_options import TrainOptions
3
+ from data import CreateDataLoader
4
+ from models import create_model
5
+ from util.visualizer import Visualizer
6
+
7
+ if __name__ == '__main__':
8
+ start = time.time()
9
+ opt = TrainOptions().parse()
10
+ data_loader = CreateDataLoader(opt)
11
+ dataset = data_loader.load_data()
12
+ dataset_size = len(data_loader)
13
+ print('#training images = %d' % dataset_size)
14
+
15
+ model = create_model(opt)
16
+ model.setup(opt)
17
+ visualizer = Visualizer(opt)
18
+ total_steps = 0
19
+ model.save_networks2(opt.which_epoch)
20
+
21
+ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
22
+ epoch_start_time = time.time()
23
+ iter_data_time = time.time()
24
+ epoch_iter = 0
25
+
26
+ for i, data in enumerate(dataset):
27
+ iter_start_time = time.time()
28
+ if total_steps % opt.print_freq == 0:
29
+ t_data = iter_start_time - iter_data_time
30
+ visualizer.reset()
31
+ total_steps += opt.batch_size
32
+ epoch_iter += opt.batch_size
33
+ model.set_input(data)
34
+ model.optimize_parameters()
35
+
36
+ if total_steps % opt.display_freq == 0:
37
+ save_result = total_steps % opt.update_html_freq == 0
38
+ visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
39
+ #print('display',total_steps)
40
+
41
+ if total_steps % opt.print_freq == 0:#print freq 100
42
+ losses = model.get_current_losses()
43
+ t = (time.time() - iter_start_time) / opt.batch_size
44
+ visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
45
+ if opt.display_id > 0:
46
+ visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
47
+
48
+ if total_steps % opt.save_latest_freq == 0:
49
+ print('saving the latest model (epoch %d, total_steps %d)' %
50
+ (epoch, total_steps))
51
+ #model.save_networks('latest')
52
+ model.save_networks2('latest')
53
+
54
+ iter_data_time = time.time()
55
+ if epoch % opt.save_epoch_freq == 0:
56
+ print('saving the model at the end of epoch %d, iters %d' %
57
+ (epoch, total_steps))
58
+ #model.save_networks('latest')
59
+ #model.save_networks(epoch)
60
+ model.save_networks2('latest')
61
+ model.save_networks2(epoch)
62
+
63
+ print('End of epoch %d / %d \t Time Taken: %d sec' %
64
+ (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
65
+ model.update_learning_rate()
66
+
67
+ print('Total Time Taken: %d sec' % (time.time() - start))
APDrawingGAN2/util/__init__.py ADDED
File without changes
APDrawingGAN2/util/get_data.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import tarfile
4
+ import requests
5
+ from warnings import warn
6
+ from zipfile import ZipFile
7
+ from bs4 import BeautifulSoup
8
+ from os.path import abspath, isdir, join, basename
9
+
10
+
11
+ class GetData(object):
12
+ """
13
+
14
+ Download CycleGAN or Pix2Pix Data.
15
+
16
+ Args:
17
+ technique : str
18
+ One of: 'cyclegan' or 'pix2pix'.
19
+ verbose : bool
20
+ If True, print additional information.
21
+
22
+ Examples:
23
+ >>> from util.get_data import GetData
24
+ >>> gd = GetData(technique='cyclegan')
25
+ >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
26
+
27
+ """
28
+
29
+ def __init__(self, technique='cyclegan', verbose=True):
30
+ url_dict = {
31
+ 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
32
+ 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
33
+ }
34
+ self.url = url_dict.get(technique.lower())
35
+ self._verbose = verbose
36
+
37
+ def _print(self, text):
38
+ if self._verbose:
39
+ print(text)
40
+
41
+ @staticmethod
42
+ def _get_options(r):
43
+ soup = BeautifulSoup(r.text, 'lxml')
44
+ options = [h.text for h in soup.find_all('a', href=True)
45
+ if h.text.endswith(('.zip', 'tar.gz'))]
46
+ return options
47
+
48
+ def _present_options(self):
49
+ r = requests.get(self.url)
50
+ options = self._get_options(r)
51
+ print('Options:\n')
52
+ for i, o in enumerate(options):
53
+ print("{0}: {1}".format(i, o))
54
+ choice = input("\nPlease enter the number of the "
55
+ "dataset above you wish to download:")
56
+ return options[int(choice)]
57
+
58
+ def _download_data(self, dataset_url, save_path):
59
+ if not isdir(save_path):
60
+ os.makedirs(save_path)
61
+
62
+ base = basename(dataset_url)
63
+ temp_save_path = join(save_path, base)
64
+
65
+ with open(temp_save_path, "wb") as f:
66
+ r = requests.get(dataset_url)
67
+ f.write(r.content)
68
+
69
+ if base.endswith('.tar.gz'):
70
+ obj = tarfile.open(temp_save_path)
71
+ elif base.endswith('.zip'):
72
+ obj = ZipFile(temp_save_path, 'r')
73
+ else:
74
+ raise ValueError("Unknown File Type: {0}.".format(base))
75
+
76
+ self._print("Unpacking Data...")
77
+ obj.extractall(save_path)
78
+ obj.close()
79
+ os.remove(temp_save_path)
80
+
81
+ def get(self, save_path, dataset=None):
82
+ """
83
+
84
+ Download a dataset.
85
+
86
+ Args:
87
+ save_path : str
88
+ A directory to save the data to.
89
+ dataset : str, optional
90
+ A specific dataset to download.
91
+ Note: this must include the file extension.
92
+ If None, options will be presented for you
93
+ to choose from.
94
+
95
+ Returns:
96
+ save_path_full : str
97
+ The absolute path to the downloaded data.
98
+
99
+ """
100
+ if dataset is None:
101
+ selected_dataset = self._present_options()
102
+ else:
103
+ selected_dataset = dataset
104
+
105
+ save_path_full = join(save_path, selected_dataset.split('.')[0])
106
+
107
+ if isdir(save_path_full):
108
+ warn("\n'{0}' already exists. Voiding Download.".format(
109
+ save_path_full))
110
+ else:
111
+ self._print('Downloading Data...')
112
+ url = "{0}/{1}".format(self.url, selected_dataset)
113
+ self._download_data(url, save_path=save_path)
114
+
115
+ return abspath(save_path_full)
APDrawingGAN2/util/html.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dominate
2
+ from dominate.tags import *
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ def __init__(self, web_dir, title, reflesh=0, folder='images'):
8
+ self.title = title
9
+ self.web_dir = web_dir
10
+ #self.img_dir = os.path.join(self.web_dir, 'images')
11
+ self.img_dir = os.path.join(self.web_dir, folder)
12
+ self.folder = folder
13
+ if not os.path.exists(self.web_dir):
14
+ os.makedirs(self.web_dir)
15
+ if not os.path.exists(self.img_dir):
16
+ os.makedirs(self.img_dir)
17
+ # print(self.img_dir)
18
+
19
+ self.doc = dominate.document(title=title)
20
+ if reflesh > 0:
21
+ with self.doc.head:
22
+ meta(http_equiv="reflesh", content=str(reflesh))
23
+
24
+ def get_image_dir(self):
25
+ return self.img_dir
26
+
27
+ def add_header(self, str):
28
+ with self.doc:
29
+ h3(str)
30
+
31
+ def add_table(self, border=1):
32
+ self.t = table(border=border, style="table-layout: fixed;")
33
+ self.doc.add(self.t)
34
+
35
+ def add_images(self, ims, txts, links, width=400):
36
+ self.add_table()
37
+ with self.t:
38
+ with tr():
39
+ for im, txt, link in zip(ims, txts, links):
40
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
41
+ with p():
42
+ with a(href=os.path.join(self.folder, link)):
43
+ #img(style="width:%dpx" % width, src=os.path.join('images', im))
44
+ img(style="width:%dpx" % width, src=os.path.join(self.folder, im))
45
+ br()
46
+ p(txt)
47
+
48
+ def save(self):
49
+ #html_file = '%s/index.html' % self.web_dir
50
+ html_file = '%s/index%s.html' % (self.web_dir, self.folder[6:])
51
+ f = open(html_file, 'wt')
52
+ f.write(self.doc.render())
53
+ f.close()
54
+
55
+
56
+ if __name__ == '__main__':
57
+ html = HTML('web/', 'test_html')
58
+ html.add_header('hello world')
59
+
60
+ ims = []
61
+ txts = []
62
+ links = []
63
+ for n in range(4):
64
+ ims.append('image_%d.png' % n)
65
+ txts.append('text_%d' % n)
66
+ links.append('image_%d.png' % n)
67
+ html.add_images(ims, txts, links)
68
+ html.save()
APDrawingGAN2/util/image_pool.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+
4
+
5
+ class ImagePool():
6
+ def __init__(self, pool_size):
7
+ self.pool_size = pool_size
8
+ if self.pool_size > 0:
9
+ self.num_imgs = 0
10
+ self.images = []
11
+
12
+ def query(self, images):
13
+ if self.pool_size == 0:
14
+ return images
15
+ return_images = []
16
+ for image in images:
17
+ image = torch.unsqueeze(image.data, 0)
18
+ if self.num_imgs < self.pool_size:
19
+ self.num_imgs = self.num_imgs + 1
20
+ self.images.append(image)
21
+ return_images.append(image)
22
+ else:
23
+ p = random.uniform(0, 1)
24
+ if p > 0.5:
25
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
26
+ tmp = self.images[random_id].clone()
27
+ self.images[random_id] = image
28
+ return_images.append(tmp)
29
+ else:
30
+ return_images.append(image)
31
+ return_images = torch.cat(return_images, 0)
32
+ return return_images
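+
+ # Usage sketch (not part of image_pool.py): the pool is typically queried when
+ # computing the discriminator's loss on fakes, so D sees a mix of the newest
+ # generated image and older ones from the buffer (names below are illustrative).
+ #
+ # fake_pool = ImagePool(opt.pool_size)          # e.g. --pool_size 50
+ # fake_for_D = fake_pool.query(fake_B)          # either the new fake or a stored one
+ # loss_D_fake = criterionGAN(netD(fake_for_D.detach()), False)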
APDrawingGAN2/util/util.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import os
6
+
7
+
8
+ # Converts a Tensor into an image array (numpy)
9
+ # |imtype|: the desired type of the converted numpy array
10
+ def tensor2im(input_image, imtype=np.uint8):
11
+ if isinstance(input_image, torch.Tensor):
12
+ image_tensor = input_image.data
13
+ else:
14
+ return input_image
15
+ image_numpy = image_tensor[0].cpu().float().numpy()
16
+ if image_numpy.shape[0] == 1:
17
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
18
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
19
+ return image_numpy.astype(imtype)
20
+
21
+
22
+ def diagnose_network(net, name='network'):
23
+ mean = 0.0
24
+ count = 0
25
+ for param in net.parameters():
26
+ if param.grad is not None:
27
+ mean += torch.mean(torch.abs(param.grad.data))
28
+ count += 1
29
+ if count > 0:
30
+ mean = mean / count
31
+ print(name)
32
+ print(mean)
33
+
34
+
35
+ def save_image(image_numpy, image_path):
36
+ image_pil = Image.fromarray(image_numpy)
37
+ image_pil.save(image_path)
38
+
39
+
40
+ def print_numpy(x, val=True, shp=False):
41
+ x = x.astype(np.float64)
42
+ if shp:
43
+ print('shape,', x.shape)
44
+ if val:
45
+ x = x.flatten()
46
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
47
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
48
+
49
+
50
+ def mkdirs(paths):
51
+ if isinstance(paths, list) and not isinstance(paths, str):
52
+ for path in paths:
53
+ mkdir(path)
54
+ else:
55
+ mkdir(paths)
56
+
57
+
58
+ def mkdir(path):
59
+ if not os.path.exists(path):
60
+ os.makedirs(path)
APDrawingGAN2/util/visualizer.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import ntpath
4
+ import time
5
+ from . import util
6
+ from . import html
7
+ from scipy.misc import imresize
8
+
9
+
10
+ # save image to the disk
11
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
12
+ image_dir = webpage.get_image_dir()
13
+ short_path = ntpath.basename(image_path[0])
14
+ name = os.path.splitext(short_path)[0]
15
+
16
+ webpage.add_header(name)
17
+ ims, txts, links = [], [], []
18
+
19
+ for label, im_data in visuals.items():
20
+ im = util.tensor2im(im_data)#tensor to numpy array [-1,1]->[0,1]->[0,255]
21
+ image_name = '%s_%s.png' % (name, label)
22
+ save_path = os.path.join(image_dir, image_name)
23
+ h, w, _ = im.shape
24
+ if aspect_ratio > 1.0:
25
+ im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
26
+ if aspect_ratio < 1.0:
27
+ im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
28
+ util.save_image(im, save_path)
29
+
30
+ ims.append(image_name)
31
+ txts.append(label)
32
+ links.append(image_name)
33
+ webpage.add_images(ims, txts, links, width=width)
34
+
35
+
36
+ class Visualizer():
37
+ def __init__(self, opt):
38
+ self.display_id = opt.display_id
39
+ self.use_html = opt.isTrain and not opt.no_html
40
+ self.win_size = opt.display_winsize
41
+ self.name = opt.name
42
+ self.opt = opt
43
+ self.saved = False
44
+ if self.display_id > 0:
45
+ import visdom
46
+ self.ncols = opt.display_ncols
47
+ self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True)
48
+
49
+ if self.use_html:
50
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
51
+ self.img_dir = os.path.join(self.web_dir, 'images')
52
+ print('create web directory %s...' % self.web_dir)
53
+ util.mkdirs([self.web_dir, self.img_dir])
54
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
55
+ with open(self.log_name, "a") as log_file:
56
+ now = time.strftime("%c")
57
+ log_file.write('================ Training Loss (%s) ================\n' % now)
58
+
59
+ def reset(self):
60
+ self.saved = False
61
+
62
+ def throw_visdom_connection_error(self):
63
+ print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n')
64
+ exit(1)
65
+
66
+ # |visuals|: dictionary of images to display or save
67
+ def display_current_results(self, visuals, epoch, save_result):
68
+ if self.display_id > 0: # show images in the browser
69
+ ncols = self.ncols
70
+ if ncols > 0:
71
+ ncols = min(ncols, len(visuals))
72
+ h, w = next(iter(visuals.values())).shape[:2]
73
+ table_css = """<style>
74
+ table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
75
+ table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
76
+ </style>""" % (w, h)
77
+ title = self.name
78
+ label_html = ''
79
+ label_html_row = ''
80
+ images = []
81
+ idx = 0
82
+ for label, image in visuals.items():
83
+ image_numpy = util.tensor2im(image)
84
+ label_html_row += '<td>%s</td>' % label
85
+ images.append(image_numpy.transpose([2, 0, 1]))
86
+ idx += 1
87
+ if idx % ncols == 0:
88
+ label_html += '<tr>%s</tr>' % label_html_row
89
+ label_html_row = ''
90
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
91
+ while idx % ncols != 0:
92
+ images.append(white_image)
93
+ label_html_row += '<td></td>'
94
+ idx += 1
95
+ if label_html_row != '':
96
+ label_html += '<tr>%s</tr>' % label_html_row
97
+ # pane col = image row
98
+ try:
99
+ self.vis.images(images, nrow=ncols, win=self.display_id + 1,
100
+ padding=2, opts=dict(title=title + ' images'))
101
+ label_html = '<table>%s</table>' % label_html
102
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
103
+ opts=dict(title=title + ' labels'))
104
+ except ConnectionError:
105
+ self.throw_visdom_connection_error()
106
+
107
+ else:
108
+ idx = 1
109
+ for label, image in visuals.items():
110
+ image_numpy = util.tensor2im(image)
111
+ self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
112
+ win=self.display_id + idx)
113
+ idx += 1
114
+
115
+ if self.use_html and (save_result or not self.saved): # save images to a html file
116
+ self.saved = True
117
+ for label, image in visuals.items():
118
+ image_numpy = util.tensor2im(image)
119
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
120
+ util.save_image(image_numpy, img_path)
121
+ # update website
122
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
123
+ for n in range(epoch, 0, -1):
124
+ webpage.add_header('epoch [%d]' % n)
125
+ ims, txts, links = [], [], []
126
+
127
+ for label, image_numpy in visuals.items():
128
+ image_numpy = util.tensor2im(image_numpy)
129
+ img_path = 'epoch%.3d_%s.png' % (n, label)
130
+ ims.append(img_path)
131
+ txts.append(label)
132
+ links.append(img_path)
133
+ webpage.add_images(ims, txts, links, width=self.win_size)
134
+ webpage.save()
135
+
136
+ def save_current_results1(self, visuals, epoch, epoch_iter):
137
+ if not os.path.exists(self.img_dir+'/detailed'):
138
+ os.mkdir(self.img_dir+'/detailed')
139
+ for label, image in visuals.items():
140
+ image_numpy = util.tensor2im(image)
141
+ img_path = os.path.join(self.img_dir, 'detailed', 'epoch%.3d_%.3d_%s.png' % (epoch, epoch_iter, label))
142
+ util.save_image(image_numpy, img_path)
143
+
144
+ # losses: dictionary of error labels and values
145
+ def plot_current_losses(self, epoch, counter_ratio, opt, losses):
146
+ if not hasattr(self, 'plot_data'):
147
+ self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
148
+ self.plot_data['X'].append(epoch + counter_ratio)
149
+ self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
150
+ try:
151
+ self.vis.line(
152
+ X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
153
+ Y=np.array(self.plot_data['Y']),
154
+ opts={
155
+ 'title': self.name + ' loss over time',
156
+ 'legend': self.plot_data['legend'],
157
+ 'xlabel': 'epoch',
158
+ 'ylabel': 'loss'},
159
+ win=self.display_id)
160
+ except ConnectionError:
161
+ self.throw_visdom_connection_error()
162
+
163
+ # losses: same format as |losses| of plot_current_losses
164
+ def print_current_losses(self, epoch, i, losses, t, t_data):
165
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
166
+ for k, v in losses.items():
167
+ message += '%s: %.6f ' % (k, v)
168
+
169
+ print(message)
170
+ with open(self.log_name, "a") as log_file:
171
+ log_file.write('%s\n' % message)
README.md CHANGED
@@ -1,4 +1,5 @@
1
  ---
 
2
  title: Apdrawing
3
  emoji: 💻
4
  colorFrom: indigo
 
1
  ---
2
+ python_version: 3.7
3
  title: Apdrawing
4
  emoji: 💻
5
  colorFrom: indigo
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+ import argparse
+ import functools
+ import os
+ import pathlib
+ import sys
+ from typing import Callable
+ import uuid
+
+ sys.path.insert(0, 'APDrawingGAN2')
+
+ import gradio as gr
+ import huggingface_hub
+ import numpy as np
+ import PIL.Image
+
+ from io import BytesIO
+ import shutil
+
+ from options.test_options import TestOptions
+ from data import CreateDataLoader
+ from models import create_model
+
+ from util import html
+
+ import ntpath
+ from util import util
+
+
+ ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
+ TITLE = 'yiranran/APDrawingGAN2'
+ DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
+
+ """
+ ARTICLE = """
+
+ """
+
+
+ MODEL_REPO = 'hylee/apdrawing_model'
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--live', action='store_true')
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     parser.add_argument('--allow-flagging', type=str, default='never')
+     parser.add_argument('--allow-screenshot', action='store_true')
+     return parser.parse_args()
+
+
+ def load_checkpoint():
+     dir = 'checkpoint'
+     checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                       'checkpoints.zip',
+                                                       force_filename='checkpoints.zip')
+     print(checkpoint_path)
+     shutil.unpack_archive(checkpoint_path, extract_dir=dir)
+
+     print(os.listdir(dir+'/checkpoints'))
+
+     return dir+'/checkpoints'
+
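+ # Note: checkpoints.zip is assumed to unpack into checkpoint/checkpoints/EXPERIMENT_NAME/,
+ # which is later passed to the model via opt.checkpoints_dir (with opt.name selecting the
+ # subfolder, here 'apdrawinggan++_author'). A quick local sanity check could be:
+ #   assert os.path.isdir(os.path.join(load_checkpoint(), 'apdrawinggan++_author'))
+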
+ # save image to the disk
+ def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
+     short_path = ntpath.basename(image_path[0])
+     name = os.path.splitext(short_path)[0]
+
+     imgs = []
+
+     for label, im_data in visuals.items():
+         im = util.tensor2im(im_data)  # tensor to numpy array [-1,1]->[0,1]->[0,255]
+         image_name = '%s_%s.png' % (name, label)
+         save_path = os.path.join(image_dir, image_name)
+         h, w, _ = im.shape
+         # PIL.Image.resize expects a (width, height) tuple, whereas im.shape is (h, w, c)
+         if aspect_ratio > 1.0:
+             im = np.array(PIL.Image.fromarray(im).resize((int(w * aspect_ratio), h), PIL.Image.BICUBIC))
+         if aspect_ratio < 1.0:
+             im = np.array(PIL.Image.fromarray(im).resize((w, int(h / aspect_ratio)), PIL.Image.BICUBIC))
+         util.save_image(im, save_path)
+         imgs.append(save_path)
+
+     return imgs
+
+
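+ # For reference (illustrative): `visuals` is an OrderedDict of image tensors keyed by label,
+ # e.g. {'real_A': tensor_A, 'fake_B': tensor_B}, and `image_path` is a list whose first entry
+ # is the source file path, so a typical result file is results_dir/NAME_fake_B.png.
+ # The exact labels depend on the model; 'fake_B' is an assumption based on pix2pix-style naming.
+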
+ SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
+ def compress_UUID():
+     '''
+     Following http://www.ietf.org/rfc/rfc1738.txt, re-encode a UUID over a larger
+     character set to produce a shorter string.
+     Character set: [0-9a-zA-Z\-_], 64 characters in total.
+     Length: (32-2)/3*2 = 20 characters.
+     Note: the key space (2^120) is large enough that collisions are practically impossible.
+     :return: String
+     '''
+     row = str(uuid.uuid4()).replace('-', '')
+     safe_code = ''
+     for i in range(10):
+         enbin = "%012d" % int(bin(int(row[i * 3] + row[i * 3 + 1] + row[i * 3 + 2], 16))[2:], 10)
+         safe_code += (SAFEHASH[int(enbin[0:6], 2)] + SAFEHASH[int(enbin[6:12], 2)])
+     safe_code = safe_code.replace('-', '')
+     return safe_code
+
+
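+ # Example (illustrative): compress_UUID() returns a 20-character token drawn from SAFEHASH,
+ # e.g. a string along the lines of 'k3TzqW0dxF9xAbCdE2hJ'; it is only used below to create a
+ # unique per-request working directory under images/.
+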
+ def run(
+     image,
+     model,
+     opt,
+ ) -> tuple[PIL.Image.Image]:
+
+     dataroot = 'images/'+compress_UUID()
+     opt.dataroot = os.path.join(dataroot, 'src/')
+     os.makedirs(opt.dataroot, exist_ok=True)
+     opt.results_dir = os.path.join(dataroot, 'results/')
+     os.makedirs(opt.results_dir, exist_ok=True)
+
+     shutil.copy(image.name, opt.dataroot)
+
+     data_loader = CreateDataLoader(opt)
+     dataset = data_loader.load_data()
+
+     imgs = [image.name]
+     # test
+     # model.eval()
+     for i, data in enumerate(dataset):
+         if i >= opt.how_many:  # test code only supports batch_size = 1, how_many means how many test images to run
+             break
+         model.set_input(data)
+         model.test()
+         visuals = model.get_current_visuals()  # in test the loadSize is set to the same as fineSize
+         img_path = model.get_image_paths()
+         # if i % 5 == 0:
+         #     print('processing (%04d)-th image... %s' % (i, img_path))
+         imgs = save_images2(opt.results_dir, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
+
+     print(imgs)
+     return PIL.Image.open(imgs[0])
+
+
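+ # Note: `image` arrives from gr.inputs.Image(type='file') below, i.e. a tempfile-like object,
+ # which is why the code copies `image.name` (a path on disk) into the per-request dataroot.
+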
+ def main():
+     gr.close_all()
+
+     args = parse_args()
+
+     checkpoint_dir = load_checkpoint()
+
+     opt = TestOptions().parse()
+     opt.num_threads = 1        # test code only supports num_threads = 1
+     opt.batch_size = 1         # test code only supports batch_size = 1
+     opt.serial_batches = True  # no shuffle
+     opt.no_flip = True         # no flip
+     opt.display_id = -1        # no visdom display
+
+     '''
+     python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
+     '''
+     opt.dataroot = 'dataset/test_single'
+     opt.name = 'apdrawinggan++_author'
+     opt.model = 'test'
+     opt.use_resnet = True
+     opt.netG = 'resnet_9blocks'
+     opt.which_epoch = 150
+     opt.how_many = 1000
+     opt.gpu_ids = -1
+     opt.gpu_ids_p = -1
+     opt.imagefolder = 'images-single'
+
+     opt.checkpoints_dir = checkpoint_dir
+
+     model = create_model(opt)
+     model.setup(opt)
+
+     func = functools.partial(run, model=model, opt=opt)
+     func = functools.update_wrapper(func, run)
+
+     gr.Interface(
+         func,
+         [
+             gr.inputs.Image(type='file', label='Input Image'),
+         ],
+         [
+             gr.outputs.Image(
+                 type='pil',
+                 label='Result'),
+         ],
+         #examples=examples,
+         theme=args.theme,
+         title=TITLE,
+         description=DESCRIPTION,
+         article=ARTICLE,
+         allow_screenshot=args.allow_screenshot,
+         allow_flagging=args.allow_flagging,
+         live=args.live,
+     ).launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
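+
+ # Usage sketch (assuming the packages in requirements.txt are installed and the
+ # hylee/apdrawing_model checkpoint repo is reachable):
+ #   python app.py            # run the demo locally with the default Gradio settings
+ #   python app.py --share    # additionally create a public Gradio share link
+ # The --port, --theme, --live and --disable-queue flags map directly to parse_args() above.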
packages.txt ADDED
@@ -0,0 +1,2 @@
+
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=0.4.0
+ torchvision>=0.2.1
+ dominate>=2.3.1
+ visdom>=0.1.8.3
+ scipy>=1.1.0
+ numpy>=1.14.1
+ Pillow>=5.0.0
+ opencv-python>=3.4.2