ethanNeuralImage committed on
Commit c85e4eb
1 Parent(s): 5238ef9

Adding in metrics
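A minimal usage sketch of the FaceMetric helpers this commit wires into app.py (not part of the commit itself; the image paths below are placeholders, and the pretrained weights under ./pretrained_models/ are assumed to be present, as in the app):

# Sketch only: how the new metrics are called, based on the diff below.
# FaceMetric compares two aligned PIL face images and returns (face_score, hair_score);
# run_metrics() in app.py keeps the face score via [0].
from PIL import Image
from metrics import FaceMetric

device = 'cuda'
lpips_metric = FaceMetric(metric_type='lpips', device=device)
ssim_metric = FaceMetric(metric_type='ms-ssim', device=device)
id_metric = FaceMetric(metric_type='id', device=device)

base_img = Image.open('inverted.png').convert('RGB')    # placeholder path
edited_img = Image.open('edited.png').convert('RGB')    # placeholder path

lpips_score = lpips_metric(base_img, edited_img)[0]
ssim_score = ssim_metric(base_img, edited_img)[0]
id_score = id_metric(base_img, edited_img)[0]
print(f'LPIPS: {lpips_score}\tMS-SSIM: {ssim_score}\tID: {id_score}')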

app.py CHANGED
@@ -26,6 +26,10 @@ import ris.spherical_kmeans as spherical_kmeans
from ris.blend import blend_latents
from ris.model import Generator as RIS_Generator

+ from metrics import FaceMetric
+ from metrics.criteria.clip_loss import CLIPLoss
+ import clip
+
from PIL import Image

opts_args = ['--no_fine_mapper']
@@ -70,6 +74,11 @@ ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)

+ lpips_metric = FaceMetric(metric_type='lpips', device=device)
+ ssim_metric = FaceMetric(metric_type='ms-ssim', device=device)
+ id_metric = FaceMetric(metric_type='id', device=device)
+ clip_hair = FaceMetric(metric_type='cliphair', device=device)
+ clip_text = CLIPLoss(hyperstyle_args)

with gr.Blocks() as demo:
with gr.Row() as row:
@@ -104,14 +113,14 @@ with gr.Blocks() as demo:
output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
output_hyperstyle_ris = gr.Image(type='pil', label='Hyperstyle RIS', visible=False)
with gr.Row() as hyperstyle_metrics:
- output_hypersyle_metrics = gr.Text()
+ output_hypersyle_metrics = gr.Text(label='Hyperstyle Metrics')
with gr.Row(visible=False) as e4e_images:
output_e4e_invert = gr.Image(type='pil', label="E4E Inverted", visible=False)
output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)
output_e4e_ris = gr.Image(type='pil', label='E4E RIS', visible=False)
- with gr.Row() as e4e_metrics:
- output_e4e_metrics = gr.Text()
+ with gr.Row(visible=False) as e4e_metrics:
+ output_e4e_metrics = gr.Text(label='E4E Metrics')
def n_iter_change(number):
if number < 0:
return 0
@@ -124,7 +133,9 @@ with gr.Blocks() as demo:
hyperstyle_bool = 'Hyperstyle' in bools
return {
hyperstyle_images: gr.update(visible=hyperstyle_bool),
+ hyperstyle_metrics: gr.update(visible=hyperstyle_bool),
e4e_images: gr.update(visible=e4e_bool),
+ e4e_metrics: gr.update(visible=e4e_bool),
n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool)
}
def outp_toggles(bool):
@@ -153,7 +164,7 @@ with gr.Blocks() as demo:

n_hyperstyle_iterations.change(n_iter_change, n_hyperstyle_iterations, n_hyperstyle_iterations)
mapper_choice.change(mapper_change, mapper_choice, [target_text])
- inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, e4e_images, n_hyperstyle_iterations])
+ inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, hyperstyle_metrics, e4e_images, e4e_metrics, n_hyperstyle_iterations])
invert_bool.change(outp_toggles, invert_bool, [output_hyperstyle_invert, output_e4e_invert])
mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])
@@ -173,6 +184,17 @@ with gr.Blocks() as demo:
randomize_noise=False, truncation=1, weights_deltas=weight_deltas)
result_batch = (x_hat, w_hat)
return result_batch
+ def run_metrics(base_img, edited_img):
+ lpips_score = lpips_metric(base_img, edited_img)[0]
+ ssim_score = ssim_metric(base_img, edited_img)[0]
+ id_score = id_metric(base_img, edited_img)[0]
+
+ return lpips_score, ssim_score, id_score
+ def clip_text_metric(tensor, text):
+ clip_embed = torch.cat([clip.tokenize(text)]).cuda()
+ clip_score = 1-clip_text(tensor.unsqueeze(0), clip_embed).item()
+ return clip_score
+
def submit(
src, align_img, inverter_bools, n_iterations, invert_bool,
mapper_bool, mapper_choice, mapper_alpha,
@@ -188,6 +210,7 @@ with gr.Blocks() as demo:
mapper = StyleCLIPMapper(mapper_args)
mapper.eval()
mapper.to(device)
+ resize_to = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
with torch.no_grad():
output_imgs = []
if align_img:
@@ -208,7 +231,7 @@ with gr.Blocks() as demo:
else:
ref_input = Image.open(src).convert('RGB')
ref_input = im2tensor_transforms(ref_input).to(device)
-
+ hyperstyle_metrics_text = ''
if 'Hyperstyle' in inverter_bools:
hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
if invert_bool:
@@ -217,13 +240,19 @@
invert_hyperstyle = None
if mapper_bool:
mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
+ #clip_score = clip_text_metric(mapped_hyperstyle[0], mapper_args.description)
mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
+ #lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), mapped_hyperstyle.resize(resize_to))
+ #hyperstyle_metrics_text += f'Mapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
else:
mapped_hyperstyle = None

if gd_bool:
- gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)[0]
- gd_hyperstyle = tensor2im(gd_hyperstyle)
+ gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
+ #clip_score = clip_text_metric(gd_hyperstyle[0], opts.target_text)
+ gd_hyperstyle = tensor2im(gd_hyperstyle[0])
+ #lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), gd_hyperstyle.resize(resize_to))
+ #hyperstyle_metrics_text += f'Global Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
else:
gd_hyperstyle = None

@@ -237,10 +266,11 @@
else:
ris_hyperstyle=None

- hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle,gd_hyperstyle, ris_hyperstyle]
+ hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle,gd_hyperstyle, ris_hyperstyle, hyperstyle_metrics_text]
else:
- hyperstyle_output = [None, None, None, None]
+ hyperstyle_output = [None, None, None, None, hyperstyle_metrics_text]
output_imgs.extend(hyperstyle_output)
+ e4e_metrics_text = ''
if 'E4E' in inverter_bools:
e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
e4e_deltas = None
@@ -250,13 +280,21 @@
invert_e4e = None
if mapper_bool:
mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
+ #clip_score = clip_text_metric(mapped_e4e[0], mapper_args.description)
mapped_e4e = tensor2im(mapped_e4e[0])
+ #lpips_score, ssim_score, id_score = run_metrics(invert_e4e, mapped_e4e)
+ #e4e_metrics_text += f'Mapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
+
else:
mapped_e4e = None

if gd_bool:
- gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)[0]
- gd_e4e = tensor2im(gd_e4e)
+ gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
+ clip_score = clip_text_metric(gd_e4e[0], opts.target_text)
+ gd_e4e = tensor2im(gd_e4e[0])
+ lpips_score, ssim_score, id_score = run_metrics(invert_e4e, gd_e4e)
+ e4e_metrics_text += f'Global Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
+
else:
gd_e4e = None

@@ -270,9 +308,9 @@
else:
ris_e4e=None

- e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e]
+ e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e, e4e_metrics_text]
else:
- e4e_output = [None, None, None, None]
+ e4e_output = [None, None, None, None, e4e_metrics_text]
output_imgs.extend(e4e_output)
return output_imgs
submit_button.click(
@@ -283,8 +321,8 @@
gd_bool, neutral_text, target_text, alpha, beta,
ris_bool, ref_img
],
- [output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris,
- output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris]
+ [output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris, output_hypersyle_metrics,
+ output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris, output_e4e_metrics]
)

demo.launch()
metrics/__init__.py ADDED
@@ -0,0 +1 @@
+ from .face_eval import FaceMetric
metrics/criteria/__init__.py ADDED
File without changes
metrics/criteria/clip_loss.py ADDED
@@ -0,0 +1,17 @@
+
+ import torch
+ import clip
+
+
+ class CLIPLoss(torch.nn.Module):
+
+ def __init__(self, opts):
+ super(CLIPLoss, self).__init__()
+ self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+ self.upsample = torch.nn.Upsample(scale_factor=7)
+ self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
+
+ def forward(self, image, text):
+ image = self.avg_pool(self.upsample(image))
+ similarity = 1 - self.model(image, text)[0] / 100
+ return similarity
metrics/criteria/id_loss.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from models.facial_recognition.model_irse import Backbone
5
+
6
+
7
+ class IDLoss(nn.Module):
8
+ def __init__(self, opts):
9
+ super(IDLoss, self).__init__()
10
+ print('Loading ResNet ArcFace')
11
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12
+ self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
13
+ self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
14
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
15
+ self.facenet.eval()
16
+ self.facenet.cuda()
17
+ self.opts = opts
18
+
19
+ def extract_feats(self, x):
20
+ if x.shape[2] != 256:
21
+ x = self.pool(x)
22
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
23
+ x = self.face_pool(x)
24
+ x_feats = self.facenet(x)
25
+ return x_feats
26
+
27
+ def forward(self, y_hat, y):
28
+ n_samples = y.shape[0]
29
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
30
+ y_hat_feats = self.extract_feats(y_hat)
31
+ y_feats = y_feats.detach()
32
+ loss = 0
33
+ sim_improvement = 0
34
+ count = 0
35
+ for i in range(n_samples):
36
+ diff_target = y_hat_feats[i].dot(y_feats[i])
37
+ loss += 1 - diff_target
38
+ count += 1
39
+
40
+ return loss / count, sim_improvement / count
metrics/criteria/parse_related_loss/average_lab_color_loss.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from torch import nn
3
+ from criteria.parse_related_loss.unet import unet
4
+
5
+ class AvgLabLoss(nn.Module):
6
+ def __init__(self, opts):
7
+ super(AvgLabLoss, self).__init__()
8
+ self.criterion = nn.L1Loss()
9
+ self.M = torch.tensor([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169], [0.019334, 0.119193, 0.950227]])
10
+ print('Loading UNet for AvgLabLoss')
11
+ self.parsenet = unet()
12
+ self.parsenet.load_state_dict(torch.load(opts.parsenet_weights))
13
+ self.parsenet.eval()
14
+ self.shrink = torch.nn.AdaptiveAvgPool2d((512, 512))
15
+ self.magnify = torch.nn.AdaptiveAvgPool2d((1024, 1024))
16
+
17
+ def gen_hair_mask(self, input_image):
18
+ labels_predict = self.parsenet(self.shrink(input_image)).detach()
19
+ mask_512 = (torch.unsqueeze(torch.max(labels_predict, 1)[1], 1)==13).float()
20
+ mask_1024 = self.magnify(mask_512)
21
+ return mask_1024
22
+
23
+ # cal lab written by liuqk
24
+ def f(self, input):
25
+ output = input * 1
26
+ mask = input > 0.008856
27
+ output[mask] = torch.pow(input[mask], 1 / 3)
28
+ output[~mask] = 7.787 * input[~mask] + 0.137931
29
+ return output
30
+
31
+ def rgb2xyz(self, input):
32
+ assert input.size(1) == 3
33
+ M_tmp = self.M.to(input.device).unsqueeze(0)
34
+ M_tmp = M_tmp.repeat(input.size(0), 1, 1) # BxCxC
35
+ output = torch.einsum('bnc,bchw->bnhw', M_tmp, input) # BxCxHxW
36
+ M_tmp = M_tmp.sum(dim=2, keepdim=True) # BxCx1
37
+ M_tmp = M_tmp.unsqueeze(3) # BxCx1x1
38
+ return output / M_tmp
39
+
40
+ def xyz2lab(self, input):
41
+ assert input.size(1) == 3
42
+ output = input * 1
43
+ xyz_f = self.f(input)
44
+ # compute l
45
+ mask = input[:, 1, :, :] > 0.008856
46
+ output[:, 0, :, :][mask] = 116 * xyz_f[:, 1, :, :][mask] - 16
47
+ output[:, 0, :, :][~mask] = 903.3 * input[:, 1, :, :][~mask]
48
+ # compute a
49
+ output[:, 1, :, :] = 500 * (xyz_f[:, 0, :, :] - xyz_f[:, 1, :, :])
50
+ # compute b
51
+ output[:, 2, :, :] = 200 * (xyz_f[:, 1, :, :] - xyz_f[:, 2, :, :])
52
+ return output
53
+ def cal_hair_avg(self, input, mask):
54
+ x = input * mask
55
+ sum = torch.sum(torch.sum(x, dim=2, keepdim=True), dim=3, keepdim=True) # [n,3,1,1]
56
+ mask_sum = torch.sum(torch.sum(mask, dim=2, keepdim=True), dim=3, keepdim=True) # [n,1,1,1]
57
+ mask_sum[mask_sum == 0] = 1
58
+ avg = sum / mask_sum
59
+ return avg
60
+
61
+ def forward(self, fake, real):
62
+ # the mask is [n,1,h,w]
63
+ # normalize to 0~1
64
+ mask_fake = self.gen_hair_mask(fake)
65
+ mask_real = self.gen_hair_mask(real)
66
+ fake_RGB = (fake + 1) / 2.0
67
+ real_RGB = (real + 1) / 2.0
68
+ # from RGB to Lab by liuqk
69
+ fake_xyz = self.rgb2xyz(fake_RGB)
70
+ fake_Lab = self.xyz2lab(fake_xyz)
71
+ real_xyz = self.rgb2xyz(real_RGB)
72
+ real_Lab = self.xyz2lab(real_xyz)
73
+ # cal average value
74
+ fake_Lab_avg = self.cal_hair_avg(fake_Lab, mask_fake)
75
+ real_Lab_avg = self.cal_hair_avg(real_Lab, mask_real)
76
+
77
+ loss = self.criterion(fake_Lab_avg, real_Lab_avg)
78
+ return loss
metrics/criteria/parse_related_loss/bg_loss.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+ from torch import nn
3
+ from criteria.parse_related_loss.unet import unet
4
+
5
+ class BackgroundLoss(nn.Module):
6
+ def __init__(self, opts):
7
+ super(BackgroundLoss, self).__init__()
8
+ print('Loading UNet for Background Loss')
9
+ self.parsenet = unet()
10
+ self.parsenet.load_state_dict(torch.load(opts.parsenet_weights))
11
+ self.parsenet.eval()
12
+ self.bg_mask_l2_loss = torch.nn.MSELoss()
13
+ self.shrink = torch.nn.AdaptiveAvgPool2d((512, 512))
14
+ self.magnify = torch.nn.AdaptiveAvgPool2d((1024, 1024))
15
+
16
+
17
+ def gen_bg_mask(self, input_image):
18
+ labels_predict = self.parsenet(self.shrink(input_image)).detach()
19
+ mask_512 = (torch.unsqueeze(torch.max(labels_predict, 1)[1], 1)!=13).float()
20
+ mask_1024 = self.magnify(mask_512)
21
+ return mask_1024
22
+
23
+ def forward(self, x, x_hat):
24
+ x_bg_mask = self.gen_bg_mask(x)
25
+ x_hat_bg_mask = self.gen_bg_mask(x_hat)
26
+ bg_mask = ((x_bg_mask+x_hat_bg_mask)==2).float()
27
+ loss = self.bg_mask_l2_loss(x * bg_mask, x_hat * bg_mask) / self.bg_mask_l2_loss(bg_mask, torch.zeros_like(bg_mask))
28
+ return loss
29
+
metrics/criteria/parse_related_loss/model_utils.py ADDED
@@ -0,0 +1,851 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+
6
+ class conv2DBatchNorm(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_channels,
10
+ n_filters,
11
+ k_size,
12
+ stride,
13
+ padding,
14
+ bias=True,
15
+ dilation=1,
16
+ is_batchnorm=True,
17
+ ):
18
+ super(conv2DBatchNorm, self).__init__()
19
+
20
+ conv_mod = nn.Conv2d(int(in_channels),
21
+ int(n_filters),
22
+ kernel_size=k_size,
23
+ padding=padding,
24
+ stride=stride,
25
+ bias=bias,
26
+ dilation=dilation,)
27
+
28
+ if is_batchnorm:
29
+ self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)))
30
+ else:
31
+ self.cb_unit = nn.Sequential(conv_mod)
32
+
33
+ def forward(self, inputs):
34
+ outputs = self.cb_unit(inputs)
35
+ return outputs
36
+
37
+
38
+ class conv2DGroupNorm(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels,
42
+ n_filters,
43
+ k_size,
44
+ stride,
45
+ padding,
46
+ bias=True,
47
+ dilation=1,
48
+ n_groups=16,
49
+ ):
50
+ super(conv2DGroupNorm, self).__init__()
51
+
52
+ conv_mod = nn.Conv2d(int(in_channels),
53
+ int(n_filters),
54
+ kernel_size=k_size,
55
+ padding=padding,
56
+ stride=stride,
57
+ bias=bias,
58
+ dilation=dilation,)
59
+
60
+ self.cg_unit = nn.Sequential(conv_mod,
61
+ nn.GroupNorm(n_groups, int(n_filters)))
62
+
63
+ def forward(self, inputs):
64
+ outputs = self.cg_unit(inputs)
65
+ return outputs
66
+
67
+
68
+ class deconv2DBatchNorm(nn.Module):
69
+ def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
70
+ super(deconv2DBatchNorm, self).__init__()
71
+
72
+ self.dcb_unit = nn.Sequential(
73
+ nn.ConvTranspose2d(
74
+ int(in_channels),
75
+ int(n_filters),
76
+ kernel_size=k_size,
77
+ padding=padding,
78
+ stride=stride,
79
+ bias=bias,
80
+ ),
81
+ nn.BatchNorm2d(int(n_filters)),
82
+ )
83
+
84
+ def forward(self, inputs):
85
+ outputs = self.dcb_unit(inputs)
86
+ return outputs
87
+
88
+
89
+ class conv2DBatchNormRelu(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels,
93
+ n_filters,
94
+ k_size,
95
+ stride,
96
+ padding,
97
+ bias=True,
98
+ dilation=1,
99
+ is_batchnorm=True,
100
+ ):
101
+ super(conv2DBatchNormRelu, self).__init__()
102
+
103
+ conv_mod = nn.Conv2d(int(in_channels),
104
+ int(n_filters),
105
+ kernel_size=k_size,
106
+ padding=padding,
107
+ stride=stride,
108
+ bias=bias,
109
+ dilation=dilation,)
110
+
111
+ if is_batchnorm:
112
+ self.cbr_unit = nn.Sequential(conv_mod,
113
+ nn.BatchNorm2d(int(n_filters)),
114
+ nn.ReLU(inplace=True))
115
+ else:
116
+ self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))
117
+
118
+ def forward(self, inputs):
119
+ outputs = self.cbr_unit(inputs)
120
+ return outputs
121
+
122
+
123
+ class conv2DGroupNormRelu(nn.Module):
124
+ def __init__(
125
+ self,
126
+ in_channels,
127
+ n_filters,
128
+ k_size,
129
+ stride,
130
+ padding,
131
+ bias=True,
132
+ dilation=1,
133
+ n_groups=16,
134
+ ):
135
+ super(conv2DGroupNormRelu, self).__init__()
136
+
137
+ conv_mod = nn.Conv2d(int(in_channels),
138
+ int(n_filters),
139
+ kernel_size=k_size,
140
+ padding=padding,
141
+ stride=stride,
142
+ bias=bias,
143
+ dilation=dilation,)
144
+
145
+ self.cgr_unit = nn.Sequential(conv_mod,
146
+ nn.GroupNorm(n_groups, int(n_filters)),
147
+ nn.ReLU(inplace=True))
148
+
149
+ def forward(self, inputs):
150
+ outputs = self.cgr_unit(inputs)
151
+ return outputs
152
+
153
+
154
+
155
+ class deconv2DBatchNormRelu(nn.Module):
156
+ def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
157
+ super(deconv2DBatchNormRelu, self).__init__()
158
+
159
+ self.dcbr_unit = nn.Sequential(
160
+ nn.ConvTranspose2d(
161
+ int(in_channels),
162
+ int(n_filters),
163
+ kernel_size=k_size,
164
+ padding=padding,
165
+ stride=stride,
166
+ bias=bias,
167
+ ),
168
+ nn.BatchNorm2d(int(n_filters)),
169
+ nn.ReLU(inplace=True),
170
+ )
171
+
172
+ def forward(self, inputs):
173
+ outputs = self.dcbr_unit(inputs)
174
+ return outputs
175
+
176
+
177
+ class unetConv2(nn.Module):
178
+ def __init__(self, in_size, out_size, is_batchnorm):
179
+ super(unetConv2, self).__init__()
180
+
181
+ if is_batchnorm:
182
+ self.conv1 = nn.Sequential(
183
+ nn.Conv2d(in_size, out_size, 3, 1, 1),
184
+ nn.BatchNorm2d(out_size),
185
+ nn.ReLU(),
186
+ )
187
+ self.conv2 = nn.Sequential(
188
+ nn.Conv2d(out_size, out_size, 3, 1, 1),
189
+ nn.BatchNorm2d(out_size),
190
+ nn.ReLU(),
191
+ )
192
+ else:
193
+ self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), nn.ReLU())
194
+ self.conv2 = nn.Sequential(
195
+ nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU()
196
+ )
197
+
198
+ def forward(self, inputs):
199
+ outputs = self.conv1(inputs)
200
+ #print (outputs.shape)
201
+ outputs = self.conv2(outputs)
202
+ #print (outputs.shape)
203
+ return outputs
204
+
205
+
206
+ class unetUp(nn.Module):
207
+ def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
208
+ super(unetUp, self).__init__()
209
+ self.conv = unetConv2(in_size, out_size, is_batchnorm)
210
+ if is_deconv:
211
+ self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
212
+ else:
213
+ self.up = nn.UpsamplingBilinear2d(scale_factor=2)
214
+
215
+ def forward(self, inputs1, inputs2):
216
+ outputs2 = self.up(inputs2)
217
+ offset = outputs2.size()[2] - inputs1.size()[2]
218
+ padding = 2 * [offset // 2, offset // 2]
219
+ outputs1 = F.pad(inputs1, padding)
220
+
221
+ return self.conv(torch.cat([outputs1, outputs2], 1))
222
+
223
+
224
+ class segnetDown2(nn.Module):
225
+ def __init__(self, in_size, out_size):
226
+ super(segnetDown2, self).__init__()
227
+ self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
228
+ self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
229
+ self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
230
+
231
+ def forward(self, inputs):
232
+ outputs = self.conv1(inputs)
233
+ outputs = self.conv2(outputs)
234
+ unpooled_shape = outputs.size()
235
+ outputs, indices = self.maxpool_with_argmax(outputs)
236
+ return outputs, indices, unpooled_shape
237
+
238
+
239
+ class segnetDown3(nn.Module):
240
+ def __init__(self, in_size, out_size):
241
+ super(segnetDown3, self).__init__()
242
+ self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
243
+ self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
244
+ self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
245
+ self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
246
+
247
+ def forward(self, inputs):
248
+ outputs = self.conv1(inputs)
249
+ outputs = self.conv2(outputs)
250
+ outputs = self.conv3(outputs)
251
+ unpooled_shape = outputs.size()
252
+ outputs, indices = self.maxpool_with_argmax(outputs)
253
+ return outputs, indices, unpooled_shape
254
+
255
+
256
+ class segnetUp2(nn.Module):
257
+ def __init__(self, in_size, out_size):
258
+ super(segnetUp2, self).__init__()
259
+ self.unpool = nn.MaxUnpool2d(2, 2)
260
+ self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
261
+ self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
262
+
263
+ def forward(self, inputs, indices, output_shape):
264
+ outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
265
+ outputs = self.conv1(outputs)
266
+ outputs = self.conv2(outputs)
267
+ return outputs
268
+
269
+
270
+ class segnetUp3(nn.Module):
271
+ def __init__(self, in_size, out_size):
272
+ super(segnetUp3, self).__init__()
273
+ self.unpool = nn.MaxUnpool2d(2, 2)
274
+ self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
275
+ self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
276
+ self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
277
+
278
+ def forward(self, inputs, indices, output_shape):
279
+ outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
280
+ outputs = self.conv1(outputs)
281
+ outputs = self.conv2(outputs)
282
+ outputs = self.conv3(outputs)
283
+ return outputs
284
+
285
+
286
+ class residualBlock(nn.Module):
287
+ expansion = 1
288
+
289
+ def __init__(self, in_channels, n_filters, stride=1, downsample=None):
290
+ super(residualBlock, self).__init__()
291
+
292
+ self.convbnrelu1 = conv2DBatchNormRelu(
293
+ in_channels, n_filters, 3, stride, 1, bias=False
294
+ )
295
+ self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
296
+ self.downsample = downsample
297
+ self.stride = stride
298
+ self.relu = nn.ReLU(inplace=True)
299
+
300
+ def forward(self, x):
301
+ residual = x
302
+
303
+ out = self.convbnrelu1(x)
304
+ out = self.convbn2(out)
305
+
306
+ if self.downsample is not None:
307
+ residual = self.downsample(x)
308
+
309
+ out += residual
310
+ out = self.relu(out)
311
+ return out
312
+
313
+
314
+ class residualBottleneck(nn.Module):
315
+ expansion = 4
316
+
317
+ def __init__(self, in_channels, n_filters, stride=1, downsample=None):
318
+ super(residualBottleneck, self).__init__()
319
+ self.convbn1 = nn.Conv2DBatchNorm(in_channels, n_filters, k_size=1, bias=False)
320
+ self.convbn2 = nn.Conv2DBatchNorm(
321
+ n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False
322
+ )
323
+ self.convbn3 = nn.Conv2DBatchNorm(
324
+ n_filters, n_filters * 4, k_size=1, bias=False
325
+ )
326
+ self.relu = nn.ReLU(inplace=True)
327
+ self.downsample = downsample
328
+ self.stride = stride
329
+
330
+ def forward(self, x):
331
+ residual = x
332
+
333
+ out = self.convbn1(x)
334
+ out = self.convbn2(out)
335
+ out = self.convbn3(out)
336
+
337
+ if self.downsample is not None:
338
+ residual = self.downsample(x)
339
+
340
+ out += residual
341
+ out = self.relu(out)
342
+
343
+ return out
344
+
345
+
346
+ class linknetUp(nn.Module):
347
+ def __init__(self, in_channels, n_filters):
348
+ super(linknetUp, self).__init__()
349
+
350
+ # B, 2C, H, W -> B, C/2, H, W
351
+ self.convbnrelu1 = conv2DBatchNormRelu(
352
+ in_channels, n_filters / 2, k_size=1, stride=1, padding=1
353
+ )
354
+
355
+ # B, C/2, H, W -> B, C/2, H, W
356
+ self.deconvbnrelu2 = nn.deconv2DBatchNormRelu(
357
+ n_filters / 2, n_filters / 2, k_size=3, stride=2, padding=0
358
+ )
359
+
360
+ # B, C/2, H, W -> B, C, H, W
361
+ self.convbnrelu3 = conv2DBatchNormRelu(
362
+ n_filters / 2, n_filters, k_size=1, stride=1, padding=1
363
+ )
364
+
365
+ def forward(self, x):
366
+ x = self.convbnrelu1(x)
367
+ x = self.deconvbnrelu2(x)
368
+ x = self.convbnrelu3(x)
369
+ return x
370
+
371
+
372
+ class FRRU(nn.Module):
373
+ """
374
+ Full Resolution Residual Unit for FRRN
375
+ """
376
+
377
+ def __init__(self,
378
+ prev_channels,
379
+ out_channels,
380
+ scale,
381
+ group_norm=False,
382
+ n_groups=None):
383
+ super(FRRU, self).__init__()
384
+ self.scale = scale
385
+ self.prev_channels = prev_channels
386
+ self.out_channels = out_channels
387
+ self.group_norm = group_norm
388
+ self.n_groups = n_groups
389
+
390
+
391
+ if self.group_norm:
392
+ conv_unit = conv2DGroupNormRelu
393
+ self.conv1 = conv_unit(
394
+ prev_channels + 32, out_channels, k_size=3,
395
+ stride=1, padding=1, bias=False, n_groups=self.n_groups
396
+ )
397
+ self.conv2 = conv_unit(
398
+ out_channels, out_channels, k_size=3,
399
+ stride=1, padding=1, bias=False, n_groups=self.n_groups
400
+ )
401
+
402
+ else:
403
+ conv_unit = conv2DBatchNormRelu
404
+ self.conv1 = conv_unit(prev_channels + 32, out_channels, k_size=3,
405
+ stride=1, padding=1, bias=False,)
406
+ self.conv2 = conv_unit(out_channels, out_channels, k_size=3,
407
+ stride=1, padding=1, bias=False,)
408
+
409
+ self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)
410
+
411
+ def forward(self, y, z):
412
+ x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1)
413
+ y_prime = self.conv1(x)
414
+ y_prime = self.conv2(y_prime)
415
+
416
+ x = self.conv_res(y_prime)
417
+ upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
418
+ x = F.upsample(x, size=upsample_size, mode="nearest")
419
+ z_prime = z + x
420
+
421
+ return y_prime, z_prime
422
+
423
+
424
+ class RU(nn.Module):
425
+ """
426
+ Residual Unit for FRRN
427
+ """
428
+
429
+ def __init__(self,
430
+ channels,
431
+ kernel_size=3,
432
+ strides=1,
433
+ group_norm=False,
434
+ n_groups=None):
435
+ super(RU, self).__init__()
436
+ self.group_norm = group_norm
437
+ self.n_groups = n_groups
438
+
439
+ if self.group_norm:
440
+ self.conv1 = conv2DGroupNormRelu(
441
+ channels, channels, k_size=kernel_size,
442
+ stride=strides, padding=1, bias=False,n_groups=self.n_groups)
443
+ self.conv2 = conv2DGroupNorm(
444
+ channels, channels, k_size=kernel_size,
445
+ stride=strides, padding=1, bias=False,n_groups=self.n_groups)
446
+
447
+ else:
448
+ self.conv1 = conv2DBatchNormRelu(
449
+ channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False,)
450
+ self.conv2 = conv2DBatchNorm(
451
+ channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False,)
452
+
453
+ def forward(self, x):
454
+ incoming = x
455
+ x = self.conv1(x)
456
+ x = self.conv2(x)
457
+ return x + incoming
458
+
459
+
460
+ class residualConvUnit(nn.Module):
461
+ def __init__(self, channels, kernel_size=3):
462
+ super(residualConvUnit, self).__init__()
463
+
464
+ self.residual_conv_unit = nn.Sequential(
465
+ nn.ReLU(inplace=True),
466
+ nn.Conv2d(channels, channels, kernel_size=kernel_size),
467
+ nn.ReLU(inplace=True),
468
+ nn.Conv2d(channels, channels, kernel_size=kernel_size),
469
+ )
470
+
471
+ def forward(self, x):
472
+ input = x
473
+ x = self.residual_conv_unit(x)
474
+ return x + input
475
+
476
+
477
+ class multiResolutionFusion(nn.Module):
478
+ def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
479
+ super(multiResolutionFusion, self).__init__()
480
+
481
+ self.up_scale_high = up_scale_high
482
+ self.up_scale_low = up_scale_low
483
+
484
+ self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3)
485
+
486
+ if low_shape is not None:
487
+ self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3)
488
+
489
+ def forward(self, x_high, x_low):
490
+ high_upsampled = F.upsample(
491
+ self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear"
492
+ )
493
+
494
+ if x_low is None:
495
+ return high_upsampled
496
+
497
+ low_upsampled = F.upsample(
498
+ self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear"
499
+ )
500
+
501
+ return low_upsampled + high_upsampled
502
+
503
+
504
+ class chainedResidualPooling(nn.Module):
505
+ def __init__(self, channels, input_shape):
506
+ super(chainedResidualPooling, self).__init__()
507
+
508
+ self.chained_residual_pooling = nn.Sequential(
509
+ nn.ReLU(inplace=True),
510
+ nn.MaxPool2d(5, 1, 2),
511
+ nn.Conv2d(input_shape[1], channels, kernel_size=3),
512
+ )
513
+
514
+ def forward(self, x):
515
+ input = x
516
+ x = self.chained_residual_pooling(x)
517
+ return x + input
518
+
519
+
520
+ class pyramidPooling(nn.Module):
521
+ def __init__(
522
+ self,
523
+ in_channels,
524
+ pool_sizes,
525
+ model_name="pspnet",
526
+ fusion_mode="cat",
527
+ is_batchnorm=True,
528
+ ):
529
+ super(pyramidPooling, self).__init__()
530
+
531
+ bias = not is_batchnorm
532
+
533
+ self.paths = []
534
+ for i in range(len(pool_sizes)):
535
+ self.paths.append(
536
+ conv2DBatchNormRelu(
537
+ in_channels,
538
+ int(in_channels / len(pool_sizes)),
539
+ 1,
540
+ 1,
541
+ 0,
542
+ bias=bias,
543
+ is_batchnorm=is_batchnorm,
544
+ )
545
+ )
546
+
547
+ self.path_module_list = nn.ModuleList(self.paths)
548
+ self.pool_sizes = pool_sizes
549
+ self.model_name = model_name
550
+ self.fusion_mode = fusion_mode
551
+
552
+ def forward(self, x):
553
+ h, w = x.shape[2:]
554
+
555
+ if self.training or self.model_name != "icnet": # general settings or pspnet
556
+ k_sizes = []
557
+ strides = []
558
+ for pool_size in self.pool_sizes:
559
+ k_sizes.append((int(h / pool_size), int(w / pool_size)))
560
+ strides.append((int(h / pool_size), int(w / pool_size)))
561
+ else: # eval mode and icnet: pre-trained for 1025 x 2049
562
+ k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
563
+ strides = [(5, 10), (10, 20), (16, 32), (33, 65)]
564
+
565
+ if self.fusion_mode == "cat": # pspnet: concat (including x)
566
+ output_slices = [x]
567
+
568
+ for i, (module, pool_size) in enumerate(
569
+ zip(self.path_module_list, self.pool_sizes)
570
+ ):
571
+ out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
572
+ # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
573
+ if self.model_name != "icnet":
574
+ out = module(out)
575
+ out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
576
+ output_slices.append(out)
577
+
578
+ return torch.cat(output_slices, dim=1)
579
+ else: # icnet: element-wise sum (including x)
580
+ pp_sum = x
581
+
582
+ for i, (module, pool_size) in enumerate(
583
+ zip(self.path_module_list, self.pool_sizes)
584
+ ):
585
+ out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
586
+ # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
587
+ if self.model_name != "icnet":
588
+ out = module(out)
589
+ out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
590
+ pp_sum = pp_sum + out
591
+
592
+ return pp_sum
593
+
594
+
595
+ class bottleNeckPSP(nn.Module):
596
+ def __init__(
597
+ self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True
598
+ ):
599
+ super(bottleNeckPSP, self).__init__()
600
+
601
+ bias = not is_batchnorm
602
+
603
+ self.cbr1 = conv2DBatchNormRelu(
604
+ in_channels,
605
+ mid_channels,
606
+ 1,
607
+ stride=1,
608
+ padding=0,
609
+ bias=bias,
610
+ is_batchnorm=is_batchnorm,
611
+ )
612
+ if dilation > 1:
613
+ self.cbr2 = conv2DBatchNormRelu(
614
+ mid_channels,
615
+ mid_channels,
616
+ 3,
617
+ stride=stride,
618
+ padding=dilation,
619
+ bias=bias,
620
+ dilation=dilation,
621
+ is_batchnorm=is_batchnorm,
622
+ )
623
+ else:
624
+ self.cbr2 = conv2DBatchNormRelu(
625
+ mid_channels,
626
+ mid_channels,
627
+ 3,
628
+ stride=stride,
629
+ padding=1,
630
+ bias=bias,
631
+ dilation=1,
632
+ is_batchnorm=is_batchnorm,
633
+ )
634
+ self.cb3 = conv2DBatchNorm(
635
+ mid_channels,
636
+ out_channels,
637
+ 1,
638
+ stride=1,
639
+ padding=0,
640
+ bias=bias,
641
+ is_batchnorm=is_batchnorm,
642
+ )
643
+ self.cb4 = conv2DBatchNorm(
644
+ in_channels,
645
+ out_channels,
646
+ 1,
647
+ stride=stride,
648
+ padding=0,
649
+ bias=bias,
650
+ is_batchnorm=is_batchnorm,
651
+ )
652
+
653
+ def forward(self, x):
654
+ conv = self.cb3(self.cbr2(self.cbr1(x)))
655
+ residual = self.cb4(x)
656
+ return F.relu(conv + residual, inplace=True)
657
+
658
+
659
+ class bottleNeckIdentifyPSP(nn.Module):
660
+ def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True):
661
+ super(bottleNeckIdentifyPSP, self).__init__()
662
+
663
+ bias = not is_batchnorm
664
+
665
+ self.cbr1 = conv2DBatchNormRelu(
666
+ in_channels,
667
+ mid_channels,
668
+ 1,
669
+ stride=1,
670
+ padding=0,
671
+ bias=bias,
672
+ is_batchnorm=is_batchnorm,
673
+ )
674
+ if dilation > 1:
675
+ self.cbr2 = conv2DBatchNormRelu(
676
+ mid_channels,
677
+ mid_channels,
678
+ 3,
679
+ stride=1,
680
+ padding=dilation,
681
+ bias=bias,
682
+ dilation=dilation,
683
+ is_batchnorm=is_batchnorm,
684
+ )
685
+ else:
686
+ self.cbr2 = conv2DBatchNormRelu(
687
+ mid_channels,
688
+ mid_channels,
689
+ 3,
690
+ stride=1,
691
+ padding=1,
692
+ bias=bias,
693
+ dilation=1,
694
+ is_batchnorm=is_batchnorm,
695
+ )
696
+ self.cb3 = conv2DBatchNorm(
697
+ mid_channels,
698
+ in_channels,
699
+ 1,
700
+ stride=1,
701
+ padding=0,
702
+ bias=bias,
703
+ is_batchnorm=is_batchnorm,
704
+ )
705
+
706
+ def forward(self, x):
707
+ residual = x
708
+ x = self.cb3(self.cbr2(self.cbr1(x)))
709
+ return F.relu(x + residual, inplace=True)
710
+
711
+
712
+ class residualBlockPSP(nn.Module):
713
+ def __init__(
714
+ self,
715
+ n_blocks,
716
+ in_channels,
717
+ mid_channels,
718
+ out_channels,
719
+ stride,
720
+ dilation=1,
721
+ include_range="all",
722
+ is_batchnorm=True,
723
+ ):
724
+ super(residualBlockPSP, self).__init__()
725
+
726
+ if dilation > 1:
727
+ stride = 1
728
+
729
+ # residualBlockPSP = convBlockPSP + identityBlockPSPs
730
+ layers = []
731
+ if include_range in ["all", "conv"]:
732
+ layers.append(
733
+ bottleNeckPSP(
734
+ in_channels,
735
+ mid_channels,
736
+ out_channels,
737
+ stride,
738
+ dilation,
739
+ is_batchnorm=is_batchnorm,
740
+ )
741
+ )
742
+ if include_range in ["all", "identity"]:
743
+ for i in range(n_blocks - 1):
744
+ layers.append(
745
+ bottleNeckIdentifyPSP(
746
+ out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm
747
+ )
748
+ )
749
+
750
+ self.layers = nn.Sequential(*layers)
751
+
752
+ def forward(self, x):
753
+ return self.layers(x)
754
+
755
+
756
+ class cascadeFeatureFusion(nn.Module):
757
+ def __init__(
758
+ self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True
759
+ ):
760
+ super(cascadeFeatureFusion, self).__init__()
761
+
762
+ bias = not is_batchnorm
763
+
764
+ self.low_dilated_conv_bn = conv2DBatchNorm(
765
+ low_in_channels,
766
+ out_channels,
767
+ 3,
768
+ stride=1,
769
+ padding=2,
770
+ bias=bias,
771
+ dilation=2,
772
+ is_batchnorm=is_batchnorm,
773
+ )
774
+ self.low_classifier_conv = nn.Conv2d(
775
+ int(low_in_channels),
776
+ int(n_classes),
777
+ kernel_size=1,
778
+ padding=0,
779
+ stride=1,
780
+ bias=True,
781
+ dilation=1,
782
+ ) # Train only
783
+ self.high_proj_conv_bn = conv2DBatchNorm(
784
+ high_in_channels,
785
+ out_channels,
786
+ 1,
787
+ stride=1,
788
+ padding=0,
789
+ bias=bias,
790
+ is_batchnorm=is_batchnorm,
791
+ )
792
+
793
+ def forward(self, x_low, x_high):
794
+ x_low_upsampled = F.interpolate(
795
+ x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
796
+ )
797
+
798
+ low_cls = self.low_classifier_conv(x_low_upsampled)
799
+
800
+ low_fm = self.low_dilated_conv_bn(x_low_upsampled)
801
+ high_fm = self.high_proj_conv_bn(x_high)
802
+ high_fused_fm = F.relu(low_fm + high_fm, inplace=True)
803
+
804
+ return high_fused_fm, low_cls
805
+
806
+
807
+ def get_interp_size(input, s_factor=1, z_factor=1): # for caffe
808
+ ori_h, ori_w = input.shape[2:]
809
+
810
+ # shrink (s_factor >= 1)
811
+ ori_h = (ori_h - 1) / s_factor + 1
812
+ ori_w = (ori_w - 1) / s_factor + 1
813
+
814
+ # zoom (z_factor >= 1)
815
+ ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
816
+ ori_w = ori_w + (ori_w - 1) * (z_factor - 1)
817
+
818
+ resize_shape = (int(ori_h), int(ori_w))
819
+ return resize_shape
820
+
821
+
822
+ def interp(input, output_size, mode="bilinear"):
823
+ n, c, ih, iw = input.shape
824
+ oh, ow = output_size
825
+
826
+ # normalize to [-1, 1]
827
+ h = torch.arange(0, oh, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (oh - 1) * 2 - 1
828
+ w = torch.arange(0, ow, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (ow - 1) * 2 - 1
829
+
830
+ grid = torch.zeros(oh, ow, 2, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu')
831
+ grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
832
+ grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
833
+ grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2]
834
+
835
+ return F.grid_sample(input, grid, mode=mode)
836
+
837
+
838
+ def get_upsampling_weight(in_channels, out_channels, kernel_size):
839
+ """Make a 2D bilinear kernel suitable for upsampling"""
840
+ factor = (kernel_size + 1) // 2
841
+ if kernel_size % 2 == 1:
842
+ center = factor - 1
843
+ else:
844
+ center = factor - 0.5
845
+ og = np.ogrid[:kernel_size, :kernel_size]
846
+ filt = (1 - abs(og[0] - center) / factor) * \
847
+ (1 - abs(og[1] - center) / factor)
848
+ weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
849
+ dtype=np.float64)
850
+ weight[range(in_channels), range(out_channels), :, :] = filt
851
+ return torch.from_numpy(weight).float()
metrics/criteria/parse_related_loss/unet.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch.nn as nn
2
+ from criteria.parse_related_loss.model_utils import *
3
+
4
+
5
+ class unet(nn.Module):
6
+ def __init__(
7
+ self,
8
+ feature_scale=4,
9
+ n_classes=19,
10
+ is_deconv=True,
11
+ in_channels=3,
12
+ is_batchnorm=True,
13
+ ):
14
+ super(unet, self).__init__()
15
+ self.is_deconv = is_deconv
16
+ self.in_channels = in_channels
17
+ self.is_batchnorm = is_batchnorm
18
+ self.feature_scale = feature_scale
19
+
20
+ filters = [64, 128, 256, 512, 1024]
21
+ filters = [int(x / self.feature_scale) for x in filters]
22
+
23
+ # downsampling
24
+ self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
25
+ self.maxpool1 = nn.MaxPool2d(kernel_size=2)
26
+
27
+ self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
28
+ self.maxpool2 = nn.MaxPool2d(kernel_size=2)
29
+
30
+ self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
31
+ self.maxpool3 = nn.MaxPool2d(kernel_size=2)
32
+
33
+ self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
34
+ self.maxpool4 = nn.MaxPool2d(kernel_size=2)
35
+
36
+ self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
37
+
38
+ # upsampling
39
+ self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv, self.is_batchnorm)
40
+ self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv, self.is_batchnorm)
41
+ self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv, self.is_batchnorm)
42
+ self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv, self.is_batchnorm)
43
+
44
+ # final conv (without any concat)
45
+ self.final = nn.Conv2d(filters[0], n_classes, 1)
46
+
47
+ def forward(self, inputs):
48
+ conv1 = self.conv1(inputs)
49
+ maxpool1 = self.maxpool1(conv1)
50
+
51
+ conv2 = self.conv2(maxpool1)
52
+ maxpool2 = self.maxpool2(conv2)
53
+
54
+ conv3 = self.conv3(maxpool2)
55
+ maxpool3 = self.maxpool3(conv3)
56
+
57
+ conv4 = self.conv4(maxpool3)
58
+ maxpool4 = self.maxpool4(conv4)
59
+
60
+ center = self.center(maxpool4)
61
+ up4 = self.up_concat4(conv4, center)
62
+ up3 = self.up_concat3(conv3, up4)
63
+ up2 = self.up_concat2(conv2, up3)
64
+ up1 = self.up_concat1(conv1, up2)
65
+
66
+ final = self.final(up1)
67
+
68
+ return final
metrics/face_eval.py ADDED
@@ -0,0 +1,103 @@
1
+ from .face_parsing import BiSeNet
2
+ import numpy as np
3
+ from .metrics import LPIPS, MS_SSIM, IdScore, ClipHair
4
+ import torch.nn as nn
5
+ import torch
6
+ from torchvision import transforms
7
+
8
+ class FaceSegmentation(nn.Module):
9
+ def __init__(self, n_classes=19, device='cuda', save_pth='./pretrained_models/79999_iter.pth'):
10
+ super(FaceSegmentation, self).__init__()
11
+ self.net = BiSeNet(n_classes=n_classes).to(device)
12
+ self.net.load_state_dict(torch.load(save_pth))
13
+ self.net.eval()
14
+ self.transform = transforms.Compose([
15
+ transforms.ToTensor(),
16
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
17
+ ])
18
+ self.device=device
19
+
20
+
21
+
22
+ def get_facemask(self, parsing_anno):
23
+ """
24
+ Returns a binary image of the face.
25
+ """
26
+ # face_attr = {1: 'skin', 2: 'l_brow', 3: 'r_brow', 4: 'l_eye', 5: 'r_eye', 6: 'eye_glass', 7: 'l_ear', 8: 'r_ear', 10: 'nose', 11: 'mouth', 12: 'u_lip', 13: 'l_lip', 14: 'neck'}
27
+ face_attr = torch.tensor([1,2,3,4,5,6,7,8,10,11,12,13,14],device=self.device)
28
+ face_mask = torch.isin(parsing_anno, face_attr)
29
+ return(face_mask.int())
30
+
31
+
32
+ def get_hairmask(self, parsing_anno):
33
+ """
34
+ Returns a binary image of the hair.
35
+ """
36
+ hair_mask = parsing_anno == 17
37
+ return(hair_mask.int())
38
+
39
+ def forward(self, img):
40
+ """
41
+ Returns a binary image of the face and hair.
42
+ """
43
+ img = self.transform(img).to(self.device)
44
+ parsing_anno = self.net(img.unsqueeze(0))[0].squeeze(0).argmax(0)
45
+ face_mask = self.get_facemask(parsing_anno).to(self.device)
46
+ hair_mask = self.get_hairmask(parsing_anno).to(self.device)
47
+ return img, face_mask, hair_mask
48
+
49
+
50
+ class FaceMetric(nn.Module):
51
+ def __init__(self, metric_type, eval_face=True, eval_hair=True, device='cuda', seg_save_pth='./pretrained_models/79999_iter.pth'):
52
+ super(FaceMetric, self).__init__()
53
+ if metric_type == 'ms-ssim':
54
+ self.metric = MS_SSIM()
55
+ self.eval_hair= eval_hair
56
+ self.eval_face= eval_face
57
+ elif metric_type == 'lpips':
58
+ self.metric = LPIPS(device=device)
59
+ self.eval_hair= eval_hair
60
+ self.eval_face= eval_face
61
+ elif metric_type == 'id':
62
+ self.metric = IdScore(device=device)
63
+ self.eval_hair = False
64
+ self.eval_face = eval_face
65
+ elif metric_type == 'cliphair':
66
+ self.metric = ClipHair(device=device)
67
+ self.eval_face = False
68
+ self.eval_hair = eval_hair
69
+ else:
70
+ raise NotImplementedError
71
+ self.parser = FaceSegmentation(device=device, save_pth=seg_save_pth)
72
+ self.device=device
73
+
74
+
75
+
76
+
77
+ def forward(self, x, y):
78
+ face_score, hair_score = None, None
79
+ x_tensor, x_face_seg, x_hair_seg = self.parser(x)
80
+ y_tensor, y_face_seg, y_hair_seg = self.parser(y)
81
+ if self.eval_hair == True:
82
+
83
+
84
+ ## Get union of two hair masks
85
+ #hair_mask = (x_hair_seg + y_hair_seg) > 0
86
+
87
+ x_hair = x_tensor * x_hair_seg
88
+ y_hair = y_tensor * y_hair_seg
89
+
90
+ hair_score = self.metric(x_hair, y_hair).item()
91
+ if self.eval_face == True:
92
+
93
+ ## Get intersection of two face masks
94
+ face_mask = (x_face_seg + y_face_seg) > 1
95
+
96
+ x_face = x_tensor * face_mask
97
+ y_face = y_tensor * face_mask
98
+
99
+ face_score = self.metric(x_face, y_face).item()
100
+
101
+ return face_score, hair_score
102
+
103
+
metrics/metrics.py ADDED
@@ -0,0 +1,205 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, PReLU, Dropout, Flatten, Linear, BatchNorm1d, MaxPool2d, AdaptiveAvgPool2d, ReLU, Sigmoid
4
+ from collections import namedtuple
5
+ from pytorch_msssim import ms_ssim
6
+ import lpips
7
+ import clip
8
+ from torchvision import transforms
9
+
10
+ class LPIPS(nn.Module):
11
+ def __init__(self, net='alex', device='cuda'):
12
+ super(LPIPS, self).__init__()
13
+ self.lpips = lpips.LPIPS(net='alex').to(device)
14
+
15
+ def forward(self, x, y):
16
+ return 1- self.lpips(x, y).squeeze()
17
+
18
+
19
+ class MS_SSIM(nn.Module):
20
+ def __init__(self, avg=False):
21
+ super(MS_SSIM, self).__init__()
22
+ self.ssim = ms_ssim
23
+ self.avg = avg
24
+
25
+ def forward(self, x, y):
26
+ ## normalize images to [0, 1]
27
+ x = (x+1)/2
28
+ y = (y+1)/2
29
+ return self.ssim(x.unsqueeze(0), y.unsqueeze(0), data_range=1, size_average=self.avg)
30
+
31
+
32
+ class IdScore(nn.Module):
33
+ # def __init__(self, opts):
34
+ def __init__(self, device='cuda'):
35
+ super(IdScore, self).__init__()
36
+ # print('Loading ResNet ArcFace')
37
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6).to(device)
38
+ self.facenet.load_state_dict(torch.load('./pretrained_models/model_ir_se50.pth'))
39
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
40
+ self.facenet.eval()
41
+ self.cosine_sim = nn.CosineSimilarity(dim=1)
42
+
43
+
44
+ def extract_feats(self, x):
45
+ x = self.face_pool(x)
46
+ x_feats = self.facenet(x)
47
+ return x_feats
48
+
49
+ def forward(self, y, x):
50
+ x = x.unsqueeze(0)
51
+ y = y.unsqueeze(0)
52
+ x_feats = self.extract_feats(x)
53
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
54
+ y_feats = y_feats.detach()
55
+
56
+ # diff_views = y_feats[0].dot(x_feats[0])
57
+ cosine_sim = self.cosine_sim(y_feats, x_feats)
58
+
59
+ return cosine_sim
60
+
61
+ class ClipHair(nn.Module):
62
+ def __init__(self, device='cuda'):
63
+ super(ClipHair, self).__init__()
64
+ self.model, self.preprocessing = clip.load("ViT-B/32", device=device)
65
+ self.cosine_sim = nn.CosineSimilarity(dim=1)
66
+ self.device = device
67
+ # self.model, self.preprocessing = model, preprocessing
68
+
69
+ def extract_feats(self, x):
70
+
71
+ x = transforms.ToPILImage()(x.squeeze())
72
+ x = self.preprocessing(x).unsqueeze(0).to(self.device)
73
+ x = self.model.encode_image(x)
74
+ return x
75
+
76
+ def forward(self, y, x):
77
+ x = x.unsqueeze(0)
78
+ y = y.unsqueeze(0)
79
+ x_feats = self.extract_feats(x)
80
+ y_feats = self.extract_feats(y)
81
+ y_feats = y_feats.detach()
82
+
83
+ cosine_sim = self.cosine_sim(x_feats, y_feats)
84
+
85
+ # diff_views = y_feats[0].dot(x_feats[0])/ (y_feats[0].norm() * x_feats[0].norm())
86
+ return cosine_sim
87
+
88
+
89
+ class bottleneck_IR_SE(Module):
90
+ def __init__(self, in_channel, depth, stride):
91
+ super(bottleneck_IR_SE, self).__init__()
92
+ if in_channel == depth:
93
+ self.shortcut_layer = MaxPool2d(1, stride)
94
+ else:
95
+ self.shortcut_layer = Sequential(
96
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
97
+ BatchNorm2d(depth)
98
+ )
99
+ self.res_layer = Sequential(
100
+ BatchNorm2d(in_channel),
101
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
102
+ PReLU(depth),
103
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
104
+ BatchNorm2d(depth),
105
+ SEModule(depth, 16)
106
+ )
107
+
108
+ def forward(self, x):
109
+ shortcut = self.shortcut_layer(x)
110
+ res = self.res_layer(x)
111
+ return res + shortcut
112
+
113
+
114
+ class Backbone(Module):
115
+ def __init__(self, input_size, num_layers, drop_ratio=0.4, affine=True):
116
+ super(Backbone, self).__init__()
117
+ assert input_size in [112, 224], "input_size should be 112 or 224"
118
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
119
+ blocks = get_blocks(num_layers)
120
+
121
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
122
+ BatchNorm2d(64),
123
+ PReLU(64))
124
+ if input_size == 112:
125
+ self.output_layer = Sequential(BatchNorm2d(512),
126
+ Dropout(drop_ratio),
127
+ Flatten(),
128
+ Linear(512 * 7 * 7, 512),
129
+ BatchNorm1d(512, affine=affine))
130
+ else:
131
+ self.output_layer = Sequential(BatchNorm2d(512),
132
+ Dropout(drop_ratio),
133
+ Flatten(),
134
+ Linear(512 * 14 * 14, 512),
135
+ BatchNorm1d(512, affine=affine))
136
+
137
+ modules = []
138
+ for block in blocks:
139
+ for bottleneck in block:
140
+ modules.append(bottleneck_IR_SE(bottleneck.in_channel,
141
+ bottleneck.depth,
142
+ bottleneck.stride))
143
+ self.body = Sequential(*modules)
144
+
145
+ def forward(self, x):
146
+ x = self.input_layer(x)
147
+ x = self.body(x)
148
+ x = self.output_layer(x)
149
+ return l2_norm(x)
150
+
151
+ def get_blocks(num_layers):
152
+ if num_layers == 50:
153
+ blocks = [
154
+ get_block(in_channel=64, depth=64, num_units=3),
155
+ get_block(in_channel=64, depth=128, num_units=4),
156
+ get_block(in_channel=128, depth=256, num_units=14),
157
+ get_block(in_channel=256, depth=512, num_units=3)
158
+ ]
159
+ elif num_layers == 100:
160
+ blocks = [
161
+ get_block(in_channel=64, depth=64, num_units=3),
162
+ get_block(in_channel=64, depth=128, num_units=13),
163
+ get_block(in_channel=128, depth=256, num_units=30),
164
+ get_block(in_channel=256, depth=512, num_units=3)
165
+ ]
166
+ elif num_layers == 152:
167
+ blocks = [
168
+ get_block(in_channel=64, depth=64, num_units=3),
169
+ get_block(in_channel=64, depth=128, num_units=8),
170
+ get_block(in_channel=128, depth=256, num_units=36),
171
+ get_block(in_channel=256, depth=512, num_units=3)
172
+ ]
173
+ else:
174
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
175
+ return blocks
176
+
177
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
178
+ """ A named tuple describing a ResNet block. """
179
+
180
+
181
+ def get_block(in_channel, depth, num_units, stride=2):
182
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
183
+
184
+ def l2_norm(input, axis=1):
185
+ norm = torch.norm(input, 2, axis, True)
186
+ output = torch.div(input, norm)
187
+ return output
188
+
189
+ class SEModule(Module):
190
+ def __init__(self, channels, reduction):
191
+ super(SEModule, self).__init__()
192
+ self.avg_pool = AdaptiveAvgPool2d(1)
193
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
194
+ self.relu = ReLU(inplace=True)
195
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
196
+ self.sigmoid = Sigmoid()
197
+
198
+ def forward(self, x):
199
+ module_input = x
200
+ x = self.avg_pool(x)
201
+ x = self.fc1(x)
202
+ x = self.relu(x)
203
+ x = self.fc2(x)
204
+ x = self.sigmoid(x)
205
+ return module_input * x
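A sketch of how the CLIPLoss added under metrics/criteria/ is turned into a text-match score, mirroring clip_text_metric() in app.py (not part of the commit; the opts object, the prompt, and stylegan_size=1024 are placeholder assumptions, and a CUDA device is assumed as elsewhere in the repo):

# Sketch only: score an edited image tensor against a text prompt with the new CLIPLoss.
import clip
import torch
from argparse import Namespace
from metrics.criteria.clip_loss import CLIPLoss

opts = Namespace(stylegan_size=1024)               # placeholder options object
clip_text = CLIPLoss(opts)                         # loads ViT-B/32 on CUDA

# Stand-in for a single generator output in [-1, 1], shape (3, 1024, 1024).
image_tensor = torch.randn(3, 1024, 1024).cuda()
tokens = torch.cat([clip.tokenize("a face with curly hair")]).cuda()

# CLIPLoss returns a distance-like value; app.py uses 1 - loss as the score.
clip_score = 1 - clip_text(image_tensor.unsqueeze(0), tokens).item()
print(clip_score)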
requirements.txt CHANGED
@@ -1,5 +1,6 @@
torch
torchvision
+ cudatoolkit
dlib
pillow
numpy
ris/model.py CHANGED
@@ -508,12 +508,7 @@ class Generator(nn.Module):
output.append(self.to_rgb1.get_latent(latent[:, 1]))

i = 1
- # print("Get latent dimensions:")
for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2], self.to_rgbs):
- # print(f'{i}: {conv1.get_latent(latent[:, i]).shape}')
- # print(f'{i+1}: {conv2.get_latent(latent[:, i+1]).shape}')
- # print(f'{i+2}: {to_rgb.get_latent(latent[:, i+2]).shape}')
- # print("")
output.append(conv1.get_latent(latent[:, i]))
output.append(conv2.get_latent(latent[:, i+1]))
output.append(to_rgb.get_latent(latent[:, i+2]))