Commit c85e4eb
ethanNeuralImage committed "Adding in metrics"
Parent(s): 5238ef9
Files changed:
- app.py +53 -15
- metrics/__init__.py +1 -0
- metrics/criteria/__init__.py +0 -0
- metrics/criteria/clip_loss.py +17 -0
- metrics/criteria/id_loss.py +40 -0
- metrics/criteria/parse_related_loss/average_lab_color_loss.py +78 -0
- metrics/criteria/parse_related_loss/bg_loss.py +29 -0
- metrics/criteria/parse_related_loss/model_utils.py +851 -0
- metrics/criteria/parse_related_loss/unet.py +68 -0
- metrics/face_eval.py +103 -0
- metrics/metrics.py +205 -0
- requirements.txt +1 -0
- ris/model.py +0 -5
app.py
CHANGED
@@ -26,6 +26,10 @@ import ris.spherical_kmeans as spherical_kmeans
 from ris.blend import blend_latents
 from ris.model import Generator as RIS_Generator
 
+from metrics import FaceMetric
+from metrics.criteria.clip_loss import CLIPLoss
+import clip
+
 from PIL import Image
 
 opts_args = ['--no_fine_mapper']
@@ -70,6 +74,11 @@ ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
 ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
 ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)
 
+lpips_metric = FaceMetric(metric_type='lpips', device=device)
+ssim_metric = FaceMetric(metric_type='ms-ssim', device=device)
+id_metric = FaceMetric(metric_type='id', device=device)
+clip_hair = FaceMetric(metric_type='cliphair', device=device)
+clip_text = CLIPLoss(hyperstyle_args)
 
 with gr.Blocks() as demo:
     with gr.Row() as row:
@@ -104,14 +113,14 @@ with gr.Blocks() as demo:
             output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
             output_hyperstyle_ris = gr.Image(type='pil', label='Hyperstyle RIS', visible=False)
         with gr.Row() as hyperstyle_metrics:
-            output_hypersyle_metrics = gr.Text()
+            output_hypersyle_metrics = gr.Text(label='Hyperstyle Metrics')
         with gr.Row(visible=False) as e4e_images:
             output_e4e_invert = gr.Image(type='pil', label="E4E Inverted", visible=False)
             output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
             output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)
             output_e4e_ris = gr.Image(type='pil', label='E4E RIS', visible=False)
-        with gr.Row() as e4e_metrics:
-            output_e4e_metrics = gr.Text()
+        with gr.Row(visible=False) as e4e_metrics:
+            output_e4e_metrics = gr.Text(label='E4E Metrics')
     def n_iter_change(number):
         if number < 0:
             return 0
@@ -124,7 +133,9 @@ with gr.Blocks() as demo:
         hyperstyle_bool = 'Hyperstyle' in bools
         return {
            hyperstyle_images: gr.update(visible=hyperstyle_bool),
+            hyperstyle_metrics: gr.update(visible=hyperstyle_bool),
            e4e_images: gr.update(visible=e4e_bool),
+            e4e_metrics: gr.update(visible=e4e_bool),
            n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool)
         }
     def outp_toggles(bool):
@@ -153,7 +164,7 @@ with gr.Blocks() as demo:
 
     n_hyperstyle_iterations.change(n_iter_change, n_hyperstyle_iterations, n_hyperstyle_iterations)
     mapper_choice.change(mapper_change, mapper_choice, [target_text])
-    inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, e4e_images, n_hyperstyle_iterations])
+    inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, hyperstyle_metrics, e4e_images, e4e_metrics, n_hyperstyle_iterations])
     invert_bool.change(outp_toggles, invert_bool, [output_hyperstyle_invert, output_e4e_invert])
     mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
     gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])
@@ -173,6 +184,17 @@ with gr.Blocks() as demo:
             randomize_noise=False, truncation=1, weights_deltas=weight_deltas)
         result_batch = (x_hat, w_hat)
         return result_batch
+    def run_metrics(base_img, edited_img):
+        lpips_score = lpips_metric(base_img, edited_img)[0]
+        ssim_score = ssim_metric(base_img, edited_img)[0]
+        id_score = id_metric(base_img, edited_img)[0]
+
+        return lpips_score, ssim_score, id_score
+    def clip_text_metric(tensor, text):
+        clip_embed = torch.cat([clip.tokenize(text)]).cuda()
+        clip_score = 1-clip_text(tensor.unsqueeze(0), clip_embed).item()
+        return clip_score
+
     def submit(
         src, align_img, inverter_bools, n_iterations, invert_bool,
         mapper_bool, mapper_choice, mapper_alpha,
@@ -188,6 +210,7 @@ with gr.Blocks() as demo:
         mapper = StyleCLIPMapper(mapper_args)
         mapper.eval()
        mapper.to(device)
+        resize_to = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
         with torch.no_grad():
             output_imgs = []
             if align_img:
@@ -208,7 +231,7 @@ with gr.Blocks() as demo:
             else:
                 ref_input = Image.open(src).convert('RGB')
                 ref_input = im2tensor_transforms(ref_input).to(device)
-
+            hyperstyle_metrics_text = ''
             if 'Hyperstyle' in inverter_bools:
                 hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
                 if invert_bool:
@@ -217,13 +240,19 @@ with gr.Blocks() as demo:
                     invert_hyperstyle = None
                 if mapper_bool:
                     mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
+                    #clip_score = clip_text_metric(mapped_hyperstyle[0], mapper_args.description)
                     mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
+                    #lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), mapped_hyperstyle.resize(resize_to))
+                    #hyperstyle_metrics_text += f'Mapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
                 else:
                     mapped_hyperstyle = None
 
                 if gd_bool:
-                    gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
-
+                    gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
+                    #clip_score = clip_text_metric(gd_hyperstyle[0], opts.target_text)
+                    gd_hyperstyle = tensor2im(gd_hyperstyle[0])
+                    #lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), gd_hyperstyle.resize(resize_to))
+                    #hyperstyle_metrics_text += f'Global Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
                 else:
                     gd_hyperstyle = None
 
@@ -237,10 +266,11 @@ with gr.Blocks() as demo:
                 else:
                     ris_hyperstyle=None
 
-                hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle,gd_hyperstyle, ris_hyperstyle]
+                hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle,gd_hyperstyle, ris_hyperstyle, hyperstyle_metrics_text]
             else:
-                hyperstyle_output = [None, None, None, None]
+                hyperstyle_output = [None, None, None, None, hyperstyle_metrics_text]
             output_imgs.extend(hyperstyle_output)
+            e4e_metrics_text = ''
             if 'E4E' in inverter_bools:
                 e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
                 e4e_deltas = None
@@ -250,13 +280,21 @@ with gr.Blocks() as demo:
                     invert_e4e = None
                 if mapper_bool:
                     mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
+                    #clip_score = clip_text_metric(mapped_e4e[0], mapper_args.description)
                     mapped_e4e = tensor2im(mapped_e4e[0])
+                    #lpips_score, ssim_score, id_score = run_metrics(invert_e4e, mapped_e4e)
+                    #e4e_metrics_text += f'Mapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
+
                 else:
                     mapped_e4e = None
 
                 if gd_bool:
-                    gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
-
+                    gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
+                    clip_score = clip_text_metric(gd_e4e[0], opts.target_text)
+                    gd_e4e = tensor2im(gd_e4e[0])
+                    lpips_score, ssim_score, id_score = run_metrics(invert_e4e, gd_e4e)
+                    e4e_metrics_text += f'Global Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
+
                 else:
                     gd_e4e = None
 
@@ -270,9 +308,9 @@ with gr.Blocks() as demo:
                 else:
                     ris_e4e=None
 
-                e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e]
+                e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e, e4e_metrics_text]
             else:
-                e4e_output = [None, None, None, None]
+                e4e_output = [None, None, None, None, e4e_metrics_text]
             output_imgs.extend(e4e_output)
             return output_imgs
     submit_button.click(
@@ -283,8 +321,8 @@ with gr.Blocks() as demo:
             gd_bool, neutral_text, target_text, alpha, beta,
            ris_bool, ref_img
         ],
-        [output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris,
-         output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris]
+        [output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris, output_hypersyle_metrics,
+         output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris, output_e4e_metrics]
     )
 
 demo.launch()
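Note on the new metrics path: `run_metrics` calls each `FaceMetric` once per image pair and keeps index 0 of the returned tuple, the face-region score. A minimal standalone sketch of one such call, assuming the Space's pretrained parsing checkpoint is in place; the image file names are hypothetical stand-ins for the PIL images app.py passes in:

```python
# Sketch of one run_metrics step: FaceMetric.forward returns a
# (face_score, hair_score) tuple; run_metrics keeps [0], the face component.
from PIL import Image
from metrics import FaceMetric

lpips_metric = FaceMetric(metric_type='lpips', device='cuda')
base = Image.open('inverted.png').convert('RGB')    # hypothetical path
edited = Image.open('mapped.png').convert('RGB')    # must match base's size
face_lpips = lpips_metric(base, edited)[0]          # [1] would be the hair score
```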
metrics/__init__.py
ADDED
@@ -0,0 +1 @@
+from .face_eval import FaceMetric
metrics/criteria/__init__.py
ADDED
File without changes
metrics/criteria/clip_loss.py
ADDED
@@ -0,0 +1,17 @@
+
+import torch
+import clip
+
+
+class CLIPLoss(torch.nn.Module):
+
+    def __init__(self, opts):
+        super(CLIPLoss, self).__init__()
+        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+        self.upsample = torch.nn.Upsample(scale_factor=7)
+        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
+
+    def forward(self, image, text):
+        image = self.avg_pool(self.upsample(image))
+        similarity = 1 - self.model(image, text)[0] / 100
+        return similarity
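This is the loss that app.py's `clip_text_metric` wraps: the 7x upsample followed by the average pool turns a `stylegan_size` output into the 224-px input CLIP expects (1024 * 7 / 32 = 224). A minimal sketch, assuming a CUDA device and an opts namespace with `stylegan_size=1024`; the prompt and image tensor are stand-ins:

```python
import argparse

import clip
import torch

from metrics.criteria.clip_loss import CLIPLoss

opts = argparse.Namespace(stylegan_size=1024)   # assumed opts shape
loss_fn = CLIPLoss(opts)
image = torch.randn(1, 3, 1024, 1024).cuda()    # stand-in generator output
tokens = torch.cat([clip.tokenize("a face with blond hair")]).cuda()
score = 1 - loss_fn(image, tokens).item()       # mirrors clip_text_metric: higher = closer to the text
```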
metrics/criteria/id_loss.py
ADDED
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+from models.facial_recognition.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+    def __init__(self, opts):
+        super(IDLoss, self).__init__()
+        print('Loading ResNet ArcFace')
+        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+        self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
+        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+        self.facenet.eval()
+        self.facenet.cuda()
+        self.opts = opts
+
+    def extract_feats(self, x):
+        if x.shape[2] != 256:
+            x = self.pool(x)
+        x = x[:, :, 35:223, 32:220]  # Crop interesting region
+        x = self.face_pool(x)
+        x_feats = self.facenet(x)
+        return x_feats
+
+    def forward(self, y_hat, y):
+        n_samples = y.shape[0]
+        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            loss += 1 - diff_target
+            count += 1
+
+        return loss / count, sim_improvement / count
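A usage sketch: the loss is one minus the ArcFace embedding dot product, averaged over the batch. This assumes the repo's `models.facial_recognition` package is importable and that `opts.ir_se50_weights` points at the IR-SE50 checkpoint; the path below is hypothetical:

```python
import argparse

import torch

from metrics.criteria.id_loss import IDLoss

opts = argparse.Namespace(ir_se50_weights='./pretrained_models/model_ir_se50.pth')  # hypothetical path
id_loss = IDLoss(opts)
y = torch.randn(2, 3, 256, 256).cuda()       # stand-in source faces
y_hat = torch.randn(2, 3, 256, 256).cuda()   # stand-in edited faces
loss, _ = id_loss(y_hat, y)                  # second value (sim_improvement) is always 0 here
```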
metrics/criteria/parse_related_loss/average_lab_color_loss.py
ADDED
@@ -0,0 +1,78 @@
+import torch
+from torch import nn
+from criteria.parse_related_loss.unet import unet
+
+class AvgLabLoss(nn.Module):
+    def __init__(self, opts):
+        super(AvgLabLoss, self).__init__()
+        self.criterion = nn.L1Loss()
+        self.M = torch.tensor([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169], [0.019334, 0.119193, 0.950227]])
+        print('Loading UNet for AvgLabLoss')
+        self.parsenet = unet()
+        self.parsenet.load_state_dict(torch.load(opts.parsenet_weights))
+        self.parsenet.eval()
+        self.shrink = torch.nn.AdaptiveAvgPool2d((512, 512))
+        self.magnify = torch.nn.AdaptiveAvgPool2d((1024, 1024))
+
+    def gen_hair_mask(self, input_image):
+        labels_predict = self.parsenet(self.shrink(input_image)).detach()
+        mask_512 = (torch.unsqueeze(torch.max(labels_predict, 1)[1], 1)==13).float()
+        mask_1024 = self.magnify(mask_512)
+        return mask_1024
+
+    # cal lab written by liuqk
+    def f(self, input):
+        output = input * 1
+        mask = input > 0.008856
+        output[mask] = torch.pow(input[mask], 1 / 3)
+        output[~mask] = 7.787 * input[~mask] + 0.137931
+        return output
+
+    def rgb2xyz(self, input):
+        assert input.size(1) == 3
+        M_tmp = self.M.to(input.device).unsqueeze(0)
+        M_tmp = M_tmp.repeat(input.size(0), 1, 1)  # BxCxC
+        output = torch.einsum('bnc,bchw->bnhw', M_tmp, input)  # BxCxHxW
+        M_tmp = M_tmp.sum(dim=2, keepdim=True)  # BxCx1
+        M_tmp = M_tmp.unsqueeze(3)  # BxCx1x1
+        return output / M_tmp
+
+    def xyz2lab(self, input):
+        assert input.size(1) == 3
+        output = input * 1
+        xyz_f = self.f(input)
+        # compute l
+        mask = input[:, 1, :, :] > 0.008856
+        output[:, 0, :, :][mask] = 116 * xyz_f[:, 1, :, :][mask] - 16
+        output[:, 0, :, :][~mask] = 903.3 * input[:, 1, :, :][~mask]
+        # compute a
+        output[:, 1, :, :] = 500 * (xyz_f[:, 0, :, :] - xyz_f[:, 1, :, :])
+        # compute b
+        output[:, 2, :, :] = 200 * (xyz_f[:, 1, :, :] - xyz_f[:, 2, :, :])
+        return output
+    def cal_hair_avg(self, input, mask):
+        x = input * mask
+        sum = torch.sum(torch.sum(x, dim=2, keepdim=True), dim=3, keepdim=True)  # [n,3,1,1]
+        mask_sum = torch.sum(torch.sum(mask, dim=2, keepdim=True), dim=3, keepdim=True)  # [n,1,1,1]
+        mask_sum[mask_sum == 0] = 1
+        avg = sum / mask_sum
+        return avg
+
+    def forward(self, fake, real):
+        # the mask is [n,1,h,w]
+        # normalize to 0~1
+        mask_fake = self.gen_hair_mask(fake)
+        mask_real = self.gen_hair_mask(real)
+        fake_RGB = (fake + 1) / 2.0
+        real_RGB = (real + 1) / 2.0
+        # from RGB to Lab by liuqk
+        fake_xyz = self.rgb2xyz(fake_RGB)
+        fake_Lab = self.xyz2lab(fake_xyz)
+        real_xyz = self.rgb2xyz(real_RGB)
+        real_Lab = self.xyz2lab(real_xyz)
+        # cal average value
+        fake_Lab_avg = self.cal_hair_avg(fake_Lab, mask_fake)
+        real_Lab_avg = self.cal_hair_avg(real_Lab, mask_real)
+
+        loss = self.criterion(fake_Lab_avg, real_Lab_avg)
+        return loss
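A usage sketch: hair pixels (parsing class 13 in this model) are masked out of both images, converted RGB to XYZ to Lab, and the L1 distance between the two mean Lab colors is returned. This assumes `opts.parsenet_weights` points at the parsing-UNet checkpoint; the path is hypothetical:

```python
import argparse

import torch

from criteria.parse_related_loss.average_lab_color_loss import AvgLabLoss

opts = argparse.Namespace(parsenet_weights='./pretrained_models/parsenet.pth')  # hypothetical path
avg_lab = AvgLabLoss(opts)
fake = torch.rand(1, 3, 1024, 1024) * 2 - 1   # stand-ins in [-1, 1], as forward expects
real = torch.rand(1, 3, 1024, 1024) * 2 - 1
color_loss = avg_lab(fake, real)              # L1 between mean Lab hair colors
```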
metrics/criteria/parse_related_loss/bg_loss.py
ADDED
@@ -0,0 +1,29 @@
+import torch
+from torch import nn
+from criteria.parse_related_loss.unet import unet
+
+class BackgroundLoss(nn.Module):
+    def __init__(self, opts):
+        super(BackgroundLoss, self).__init__()
+        print('Loading UNet for Background Loss')
+        self.parsenet = unet()
+        self.parsenet.load_state_dict(torch.load(opts.parsenet_weights))
+        self.parsenet.eval()
+        self.bg_mask_l2_loss = torch.nn.MSELoss()
+        self.shrink = torch.nn.AdaptiveAvgPool2d((512, 512))
+        self.magnify = torch.nn.AdaptiveAvgPool2d((1024, 1024))
+
+
+    def gen_bg_mask(self, input_image):
+        labels_predict = self.parsenet(self.shrink(input_image)).detach()
+        mask_512 = (torch.unsqueeze(torch.max(labels_predict, 1)[1], 1)!=13).float()
+        mask_1024 = self.magnify(mask_512)
+        return mask_1024
+
+    def forward(self, x, x_hat):
+        x_bg_mask = self.gen_bg_mask(x)
+        x_hat_bg_mask = self.gen_bg_mask(x_hat)
+        bg_mask = ((x_bg_mask+x_hat_bg_mask)==2).float()
+        loss = self.bg_mask_l2_loss(x * bg_mask, x_hat * bg_mask) / self.bg_mask_l2_loss(bg_mask, torch.zeros_like(bg_mask))
+        return loss
+
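Same pattern for the background term: the mask keeps pixels that both parses label non-hair (class != 13), and the MSE is normalized by the mask's area. A sketch under the same checkpoint assumption as above:

```python
import argparse

import torch

from criteria.parse_related_loss.bg_loss import BackgroundLoss

opts = argparse.Namespace(parsenet_weights='./pretrained_models/parsenet.pth')  # hypothetical path
bg_loss = BackgroundLoss(opts)
x = torch.rand(1, 3, 1024, 1024) * 2 - 1       # stand-in input
x_hat = torch.rand(1, 3, 1024, 1024) * 2 - 1   # stand-in reconstruction
loss = bg_loss(x, x_hat)                       # masked MSE, normalized by the mask fraction
```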
metrics/criteria/parse_related_loss/model_utils.py
ADDED
@@ -0,0 +1,851 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+class conv2DBatchNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        n_filters,
+        k_size,
+        stride,
+        padding,
+        bias=True,
+        dilation=1,
+        is_batchnorm=True,
+    ):
+        super(conv2DBatchNorm, self).__init__()
+
+        conv_mod = nn.Conv2d(int(in_channels),
+                             int(n_filters),
+                             kernel_size=k_size,
+                             padding=padding,
+                             stride=stride,
+                             bias=bias,
+                             dilation=dilation,)
+
+        if is_batchnorm:
+            self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)))
+        else:
+            self.cb_unit = nn.Sequential(conv_mod)
+
+    def forward(self, inputs):
+        outputs = self.cb_unit(inputs)
+        return outputs
+
+
+class conv2DGroupNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        n_filters,
+        k_size,
+        stride,
+        padding,
+        bias=True,
+        dilation=1,
+        n_groups=16,
+    ):
+        super(conv2DGroupNorm, self).__init__()
+
+        conv_mod = nn.Conv2d(int(in_channels),
+                             int(n_filters),
+                             kernel_size=k_size,
+                             padding=padding,
+                             stride=stride,
+                             bias=bias,
+                             dilation=dilation,)
+
+        self.cg_unit = nn.Sequential(conv_mod,
+                                     nn.GroupNorm(n_groups, int(n_filters)))
+
+    def forward(self, inputs):
+        outputs = self.cg_unit(inputs)
+        return outputs
+
+
+class deconv2DBatchNorm(nn.Module):
+    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
+        super(deconv2DBatchNorm, self).__init__()
+
+        self.dcb_unit = nn.Sequential(
+            nn.ConvTranspose2d(
+                int(in_channels),
+                int(n_filters),
+                kernel_size=k_size,
+                padding=padding,
+                stride=stride,
+                bias=bias,
+            ),
+            nn.BatchNorm2d(int(n_filters)),
+        )
+
+    def forward(self, inputs):
+        outputs = self.dcb_unit(inputs)
+        return outputs
+
+
+class conv2DBatchNormRelu(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        n_filters,
+        k_size,
+        stride,
+        padding,
+        bias=True,
+        dilation=1,
+        is_batchnorm=True,
+    ):
+        super(conv2DBatchNormRelu, self).__init__()
+
+        conv_mod = nn.Conv2d(int(in_channels),
+                             int(n_filters),
+                             kernel_size=k_size,
+                             padding=padding,
+                             stride=stride,
+                             bias=bias,
+                             dilation=dilation,)
+
+        if is_batchnorm:
+            self.cbr_unit = nn.Sequential(conv_mod,
+                                          nn.BatchNorm2d(int(n_filters)),
+                                          nn.ReLU(inplace=True))
+        else:
+            self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))
+
+    def forward(self, inputs):
+        outputs = self.cbr_unit(inputs)
+        return outputs
+
+
+class conv2DGroupNormRelu(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        n_filters,
+        k_size,
+        stride,
+        padding,
+        bias=True,
+        dilation=1,
+        n_groups=16,
+    ):
+        super(conv2DGroupNormRelu, self).__init__()
+
+        conv_mod = nn.Conv2d(int(in_channels),
+                             int(n_filters),
+                             kernel_size=k_size,
+                             padding=padding,
+                             stride=stride,
+                             bias=bias,
+                             dilation=dilation,)
+
+        self.cgr_unit = nn.Sequential(conv_mod,
+                                      nn.GroupNorm(n_groups, int(n_filters)),
+                                      nn.ReLU(inplace=True))
+
+    def forward(self, inputs):
+        outputs = self.cgr_unit(inputs)
+        return outputs
+
+
+
+class deconv2DBatchNormRelu(nn.Module):
+    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
+        super(deconv2DBatchNormRelu, self).__init__()
+
+        self.dcbr_unit = nn.Sequential(
+            nn.ConvTranspose2d(
+                int(in_channels),
+                int(n_filters),
+                kernel_size=k_size,
+                padding=padding,
+                stride=stride,
+                bias=bias,
+            ),
+            nn.BatchNorm2d(int(n_filters)),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, inputs):
+        outputs = self.dcbr_unit(inputs)
+        return outputs
+
+
+class unetConv2(nn.Module):
+    def __init__(self, in_size, out_size, is_batchnorm):
+        super(unetConv2, self).__init__()
+
+        if is_batchnorm:
+            self.conv1 = nn.Sequential(
+                nn.Conv2d(in_size, out_size, 3, 1, 1),
+                nn.BatchNorm2d(out_size),
+                nn.ReLU(),
+            )
+            self.conv2 = nn.Sequential(
+                nn.Conv2d(out_size, out_size, 3, 1, 1),
+                nn.BatchNorm2d(out_size),
+                nn.ReLU(),
+            )
+        else:
+            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), nn.ReLU())
+            self.conv2 = nn.Sequential(
+                nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU()
+            )
+
+    def forward(self, inputs):
+        outputs = self.conv1(inputs)
+        #print (outputs.shape)
+        outputs = self.conv2(outputs)
+        #print (outputs.shape)
+        return outputs
+
+
+class unetUp(nn.Module):
+    def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
+        super(unetUp, self).__init__()
+        self.conv = unetConv2(in_size, out_size, is_batchnorm)
+        if is_deconv:
+            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
+        else:
+            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
+
+    def forward(self, inputs1, inputs2):
+        outputs2 = self.up(inputs2)
+        offset = outputs2.size()[2] - inputs1.size()[2]
+        padding = 2 * [offset // 2, offset // 2]
+        outputs1 = F.pad(inputs1, padding)
+
+        return self.conv(torch.cat([outputs1, outputs2], 1))
+
+
+class segnetDown2(nn.Module):
+    def __init__(self, in_size, out_size):
+        super(segnetDown2, self).__init__()
+        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
+        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
+        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
+
+    def forward(self, inputs):
+        outputs = self.conv1(inputs)
+        outputs = self.conv2(outputs)
+        unpooled_shape = outputs.size()
+        outputs, indices = self.maxpool_with_argmax(outputs)
+        return outputs, indices, unpooled_shape
+
+
+class segnetDown3(nn.Module):
+    def __init__(self, in_size, out_size):
+        super(segnetDown3, self).__init__()
+        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
+        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
+        self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
+        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
+
+    def forward(self, inputs):
+        outputs = self.conv1(inputs)
+        outputs = self.conv2(outputs)
+        outputs = self.conv3(outputs)
+        unpooled_shape = outputs.size()
+        outputs, indices = self.maxpool_with_argmax(outputs)
+        return outputs, indices, unpooled_shape
+
+
+class segnetUp2(nn.Module):
+    def __init__(self, in_size, out_size):
+        super(segnetUp2, self).__init__()
+        self.unpool = nn.MaxUnpool2d(2, 2)
+        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
+        self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
+
+    def forward(self, inputs, indices, output_shape):
+        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
+        outputs = self.conv1(outputs)
+        outputs = self.conv2(outputs)
+        return outputs
+
+
+class segnetUp3(nn.Module):
+    def __init__(self, in_size, out_size):
+        super(segnetUp3, self).__init__()
+        self.unpool = nn.MaxUnpool2d(2, 2)
+        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
+        self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
+        self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
+
+    def forward(self, inputs, indices, output_shape):
+        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
+        outputs = self.conv1(outputs)
+        outputs = self.conv2(outputs)
+        outputs = self.conv3(outputs)
+        return outputs
+
+
+class residualBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
+        super(residualBlock, self).__init__()
+
+        self.convbnrelu1 = conv2DBatchNormRelu(
+            in_channels, n_filters, 3, stride, 1, bias=False
+        )
+        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
+        self.downsample = downsample
+        self.stride = stride
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        residual = x
+
+        out = self.convbnrelu1(x)
+        out = self.convbn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class residualBottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
+        super(residualBottleneck, self).__init__()
+        self.convbn1 = nn.Conv2DBatchNorm(in_channels, n_filters, k_size=1, bias=False)
+        self.convbn2 = nn.Conv2DBatchNorm(
+            n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False
+        )
+        self.convbn3 = nn.Conv2DBatchNorm(
+            n_filters, n_filters * 4, k_size=1, bias=False
+        )
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.convbn1(x)
+        out = self.convbn2(out)
+        out = self.convbn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class linknetUp(nn.Module):
+    def __init__(self, in_channels, n_filters):
+        super(linknetUp, self).__init__()
+
+        # B, 2C, H, W -> B, C/2, H, W
+        self.convbnrelu1 = conv2DBatchNormRelu(
+            in_channels, n_filters / 2, k_size=1, stride=1, padding=1
+        )
+
+        # B, C/2, H, W -> B, C/2, H, W
+        self.deconvbnrelu2 = nn.deconv2DBatchNormRelu(
+            n_filters / 2, n_filters / 2, k_size=3, stride=2, padding=0
+        )
+
+        # B, C/2, H, W -> B, C, H, W
+        self.convbnrelu3 = conv2DBatchNormRelu(
+            n_filters / 2, n_filters, k_size=1, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        x = self.convbnrelu1(x)
+        x = self.deconvbnrelu2(x)
+        x = self.convbnrelu3(x)
+        return x
+
+
+class FRRU(nn.Module):
+    """
+    Full Resolution Residual Unit for FRRN
+    """
+
+    def __init__(self,
+                 prev_channels,
+                 out_channels,
+                 scale,
+                 group_norm=False,
+                 n_groups=None):
+        super(FRRU, self).__init__()
+        self.scale = scale
+        self.prev_channels = prev_channels
+        self.out_channels = out_channels
+        self.group_norm = group_norm
+        self.n_groups = n_groups
+
+
+        if self.group_norm:
+            conv_unit = conv2DGroupNormRelu
+            self.conv1 = conv_unit(
+                prev_channels + 32, out_channels, k_size=3,
+                stride=1, padding=1, bias=False, n_groups=self.n_groups
+            )
+            self.conv2 = conv_unit(
+                out_channels, out_channels, k_size=3,
+                stride=1, padding=1, bias=False, n_groups=self.n_groups
+            )
+
+        else:
+            conv_unit = conv2DBatchNormRelu
+            self.conv1 = conv_unit(prev_channels + 32, out_channels, k_size=3,
+                                   stride=1, padding=1, bias=False,)
+            self.conv2 = conv_unit(out_channels, out_channels, k_size=3,
+                                   stride=1, padding=1, bias=False,)
+
+        self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, y, z):
+        x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1)
+        y_prime = self.conv1(x)
+        y_prime = self.conv2(y_prime)
+
+        x = self.conv_res(y_prime)
+        upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
+        x = F.upsample(x, size=upsample_size, mode="nearest")
+        z_prime = z + x
+
+        return y_prime, z_prime
+
+
+class RU(nn.Module):
+    """
+    Residual Unit for FRRN
+    """
+
+    def __init__(self,
+                 channels,
+                 kernel_size=3,
+                 strides=1,
+                 group_norm=False,
+                 n_groups=None):
+        super(RU, self).__init__()
+        self.group_norm = group_norm
+        self.n_groups = n_groups
+
+        if self.group_norm:
+            self.conv1 = conv2DGroupNormRelu(
+                channels, channels, k_size=kernel_size,
+                stride=strides, padding=1, bias=False,n_groups=self.n_groups)
+            self.conv2 = conv2DGroupNorm(
+                channels, channels, k_size=kernel_size,
+                stride=strides, padding=1, bias=False,n_groups=self.n_groups)
+
+        else:
+            self.conv1 = conv2DBatchNormRelu(
+                channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False,)
+            self.conv2 = conv2DBatchNorm(
+                channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False,)
+
+    def forward(self, x):
+        incoming = x
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x + incoming
+
+
+class residualConvUnit(nn.Module):
+    def __init__(self, channels, kernel_size=3):
+        super(residualConvUnit, self).__init__()
+
+        self.residual_conv_unit = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.Conv2d(channels, channels, kernel_size=kernel_size),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(channels, channels, kernel_size=kernel_size),
+        )
+
+    def forward(self, x):
+        input = x
+        x = self.residual_conv_unit(x)
+        return x + input
+
+
+class multiResolutionFusion(nn.Module):
+    def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
+        super(multiResolutionFusion, self).__init__()
+
+        self.up_scale_high = up_scale_high
+        self.up_scale_low = up_scale_low
+
+        self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3)
+
+        if low_shape is not None:
+            self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3)
+
+    def forward(self, x_high, x_low):
+        high_upsampled = F.upsample(
+            self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear"
+        )
+
+        if x_low is None:
+            return high_upsampled
+
+        low_upsampled = F.upsample(
+            self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear"
+        )
+
+        return low_upsampled + high_upsampled
+
+
+class chainedResidualPooling(nn.Module):
+    def __init__(self, channels, input_shape):
+        super(chainedResidualPooling, self).__init__()
+
+        self.chained_residual_pooling = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(5, 1, 2),
+            nn.Conv2d(input_shape[1], channels, kernel_size=3),
+        )
+
+    def forward(self, x):
+        input = x
+        x = self.chained_residual_pooling(x)
+        return x + input
+
+
+class pyramidPooling(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        pool_sizes,
+        model_name="pspnet",
+        fusion_mode="cat",
+        is_batchnorm=True,
+    ):
+        super(pyramidPooling, self).__init__()
+
+        bias = not is_batchnorm
+
+        self.paths = []
+        for i in range(len(pool_sizes)):
+            self.paths.append(
+                conv2DBatchNormRelu(
+                    in_channels,
+                    int(in_channels / len(pool_sizes)),
+                    1,
+                    1,
+                    0,
+                    bias=bias,
+                    is_batchnorm=is_batchnorm,
+                )
+            )
+
+        self.path_module_list = nn.ModuleList(self.paths)
+        self.pool_sizes = pool_sizes
+        self.model_name = model_name
+        self.fusion_mode = fusion_mode
+
+    def forward(self, x):
+        h, w = x.shape[2:]
+
+        if self.training or self.model_name != "icnet":  # general settings or pspnet
+            k_sizes = []
+            strides = []
+            for pool_size in self.pool_sizes:
+                k_sizes.append((int(h / pool_size), int(w / pool_size)))
+                strides.append((int(h / pool_size), int(w / pool_size)))
+        else:  # eval mode and icnet: pre-trained for 1025 x 2049
+            k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
+            strides = [(5, 10), (10, 20), (16, 32), (33, 65)]
+
+        if self.fusion_mode == "cat":  # pspnet: concat (including x)
+            output_slices = [x]
+
+            for i, (module, pool_size) in enumerate(
+                zip(self.path_module_list, self.pool_sizes)
+            ):
+                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
+                # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
+                if self.model_name != "icnet":
+                    out = module(out)
+                out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
+                output_slices.append(out)
+
+            return torch.cat(output_slices, dim=1)
+        else:  # icnet: element-wise sum (including x)
+            pp_sum = x
+
+            for i, (module, pool_size) in enumerate(
+                zip(self.path_module_list, self.pool_sizes)
+            ):
+                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
+                # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
+                if self.model_name != "icnet":
+                    out = module(out)
+                out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
+                pp_sum = pp_sum + out
+
+            return pp_sum
+
+
+class bottleNeckPSP(nn.Module):
+    def __init__(
+        self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True
+    ):
+        super(bottleNeckPSP, self).__init__()
+
+        bias = not is_batchnorm
+
+        self.cbr1 = conv2DBatchNormRelu(
+            in_channels,
+            mid_channels,
+            1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            is_batchnorm=is_batchnorm,
+        )
+        if dilation > 1:
+            self.cbr2 = conv2DBatchNormRelu(
+                mid_channels,
+                mid_channels,
+                3,
+                stride=stride,
+                padding=dilation,
+                bias=bias,
+                dilation=dilation,
+                is_batchnorm=is_batchnorm,
+            )
+        else:
+            self.cbr2 = conv2DBatchNormRelu(
+                mid_channels,
+                mid_channels,
+                3,
+                stride=stride,
+                padding=1,
+                bias=bias,
+                dilation=1,
+                is_batchnorm=is_batchnorm,
+            )
+        self.cb3 = conv2DBatchNorm(
+            mid_channels,
+            out_channels,
+            1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            is_batchnorm=is_batchnorm,
+        )
+        self.cb4 = conv2DBatchNorm(
+            in_channels,
+            out_channels,
+            1,
+            stride=stride,
+            padding=0,
+            bias=bias,
+            is_batchnorm=is_batchnorm,
+        )
+
+    def forward(self, x):
+        conv = self.cb3(self.cbr2(self.cbr1(x)))
+        residual = self.cb4(x)
+        return F.relu(conv + residual, inplace=True)
+
+
+class bottleNeckIdentifyPSP(nn.Module):
+    def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True):
+        super(bottleNeckIdentifyPSP, self).__init__()
+
+        bias = not is_batchnorm
+
+        self.cbr1 = conv2DBatchNormRelu(
+            in_channels,
+            mid_channels,
+            1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            is_batchnorm=is_batchnorm,
+        )
+        if dilation > 1:
+            self.cbr2 = conv2DBatchNormRelu(
+                mid_channels,
+                mid_channels,
+                3,
+                stride=1,
+                padding=dilation,
+                bias=bias,
+                dilation=dilation,
+                is_batchnorm=is_batchnorm,
+            )
+        else:
+            self.cbr2 = conv2DBatchNormRelu(
+                mid_channels,
+                mid_channels,
+                3,
+                stride=1,
+                padding=1,
+                bias=bias,
+                dilation=1,
+                is_batchnorm=is_batchnorm,
+            )
+        self.cb3 = conv2DBatchNorm(
+            mid_channels,
+            in_channels,
+            1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            is_batchnorm=is_batchnorm,
+        )
+
+    def forward(self, x):
+        residual = x
+        x = self.cb3(self.cbr2(self.cbr1(x)))
+        return F.relu(x + residual, inplace=True)
+
+
+class residualBlockPSP(nn.Module):
+    def __init__(
+        self,
+        n_blocks,
+        in_channels,
+        mid_channels,
+        out_channels,
+        stride,
+        dilation=1,
+        include_range="all",
+        is_batchnorm=True,
+    ):
+        super(residualBlockPSP, self).__init__()
+
+        if dilation > 1:
+            stride = 1
+
+        # residualBlockPSP = convBlockPSP + identityBlockPSPs
+        layers = []
+        if include_range in ["all", "conv"]:
+            layers.append(
+                bottleNeckPSP(
+                    in_channels,
+                    mid_channels,
+                    out_channels,
+                    stride,
+                    dilation,
+                    is_batchnorm=is_batchnorm,
+                )
+            )
+        if include_range in ["all", "identity"]:
+            for i in range(n_blocks - 1):
+                layers.append(
+                    bottleNeckIdentifyPSP(
+                        out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm
+                    )
+                )
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class cascadeFeatureFusion(nn.Module):
+    def __init__(
+        self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True
+    ):
+        super(cascadeFeatureFusion, self).__init__()
+
+        bias = not is_batchnorm
+
+        self.low_dilated_conv_bn = conv2DBatchNorm(
+            low_in_channels,
+            out_channels,
+            3,
+            stride=1,
+            padding=2,
+            bias=bias,
+            dilation=2,
+            is_batchnorm=is_batchnorm,
+        )
+        self.low_classifier_conv = nn.Conv2d(
+            int(low_in_channels),
+            int(n_classes),
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            bias=True,
+            dilation=1,
+        )  # Train only
+        self.high_proj_conv_bn = conv2DBatchNorm(
+            high_in_channels,
+            out_channels,
+            1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            is_batchnorm=is_batchnorm,
+        )
+
+    def forward(self, x_low, x_high):
+        x_low_upsampled = F.interpolate(
+            x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
+        )
+
+        low_cls = self.low_classifier_conv(x_low_upsampled)
+
+        low_fm = self.low_dilated_conv_bn(x_low_upsampled)
+        high_fm = self.high_proj_conv_bn(x_high)
+        high_fused_fm = F.relu(low_fm + high_fm, inplace=True)
+
+        return high_fused_fm, low_cls
+
+
+def get_interp_size(input, s_factor=1, z_factor=1):  # for caffe
+    ori_h, ori_w = input.shape[2:]
+
+    # shrink (s_factor >= 1)
+    ori_h = (ori_h - 1) / s_factor + 1
+    ori_w = (ori_w - 1) / s_factor + 1
+
+    # zoom (z_factor >= 1)
+    ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
+    ori_w = ori_w + (ori_w - 1) * (z_factor - 1)
+
+    resize_shape = (int(ori_h), int(ori_w))
+    return resize_shape
+
+
+def interp(input, output_size, mode="bilinear"):
+    n, c, ih, iw = input.shape
+    oh, ow = output_size
+
+    # normalize to [-1, 1]
+    h = torch.arange(0, oh, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (oh - 1) * 2 - 1
+    w = torch.arange(0, ow, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu') / (ow - 1) * 2 - 1
+
+    grid = torch.zeros(oh, ow, 2, dtype=torch.float, device='cuda' if input.is_cuda else 'cpu')
+    grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
+    grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
+    grid = grid.unsqueeze(0).repeat(n, 1, 1, 1)  # grid.shape: [n, oh, ow, 2]
+
+    return F.grid_sample(input, grid, mode=mode)
+
+
+def get_upsampling_weight(in_channels, out_channels, kernel_size):
+    """Make a 2D bilinear kernel suitable for upsampling"""
+    factor = (kernel_size + 1) // 2
+    if kernel_size % 2 == 1:
+        center = factor - 1
+    else:
+        center = factor - 0.5
+    og = np.ogrid[:kernel_size, :kernel_size]
+    filt = (1 - abs(og[0] - center) / factor) * \
+           (1 - abs(og[1] - center) / factor)
+    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
+                      dtype=np.float64)
+    weight[range(in_channels), range(out_channels), :, :] = filt
+    return torch.from_numpy(weight).float()
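These helpers largely mirror the pytorch-semseg building blocks and compose like ordinary `nn.Module`s. A quick shape-check sketch (sizes are illustrative; only the UNet pieces are actually used by the losses in this commit):

```python
import torch

from criteria.parse_related_loss.model_utils import conv2DBatchNormRelu, residualBlock

stem = conv2DBatchNormRelu(in_channels=3, n_filters=64, k_size=3, stride=1, padding=1)
block = residualBlock(in_channels=64, n_filters=64)   # stride=1, no downsample needed
out = block(stem(torch.randn(2, 3, 128, 128)))        # -> [2, 64, 128, 128]
```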
metrics/criteria/parse_related_loss/unet.py
ADDED
@@ -0,0 +1,68 @@
+import torch.nn as nn
+from criteria.parse_related_loss.model_utils import *
+
+
+class unet(nn.Module):
+    def __init__(
+        self,
+        feature_scale=4,
+        n_classes=19,
+        is_deconv=True,
+        in_channels=3,
+        is_batchnorm=True,
+    ):
+        super(unet, self).__init__()
+        self.is_deconv = is_deconv
+        self.in_channels = in_channels
+        self.is_batchnorm = is_batchnorm
+        self.feature_scale = feature_scale
+
+        filters = [64, 128, 256, 512, 1024]
+        filters = [int(x / self.feature_scale) for x in filters]
+
+        # downsampling
+        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
+        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
+        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
+
+        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
+
+        # upsampling
+        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv, self.is_batchnorm)
+        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv, self.is_batchnorm)
+        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv, self.is_batchnorm)
+        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv, self.is_batchnorm)
+
+        # final conv (without any concat)
+        self.final = nn.Conv2d(filters[0], n_classes, 1)
+
+    def forward(self, inputs):
+        conv1 = self.conv1(inputs)
+        maxpool1 = self.maxpool1(conv1)
+
+        conv2 = self.conv2(maxpool1)
+        maxpool2 = self.maxpool2(conv2)
+
+        conv3 = self.conv3(maxpool2)
+        maxpool3 = self.maxpool3(conv3)
+
+        conv4 = self.conv4(maxpool3)
+        maxpool4 = self.maxpool4(conv4)
+
+        center = self.center(maxpool4)
+        up4 = self.up_concat4(conv4, center)
+        up3 = self.up_concat3(conv3, up4)
+        up2 = self.up_concat2(conv2, up3)
+        up1 = self.up_concat1(conv1, up2)
+
+        final = self.final(up1)
+
+        return final
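A shape sketch: with the default `feature_scale=4` the encoder widths become [16, 32, 64, 128, 256], and the net maps RGB to 19 per-pixel class logits, which is how the parse-related losses above derive their masks:

```python
import torch

from criteria.parse_related_loss.unet import unet

net = unet().eval()                            # defaults: feature_scale=4, n_classes=19
with torch.no_grad():
    logits = net(torch.randn(1, 3, 512, 512))  # -> [1, 19, 512, 512]
labels = logits.argmax(dim=1)                  # per-pixel class ids
```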
metrics/face_eval.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .face_parsing import BiSeNet
import numpy as np
from .metrics import LPIPS, MS_SSIM, IdScore, ClipHair
import torch.nn as nn
import torch
from torchvision import transforms

class FaceSegmentation(nn.Module):
    def __init__(self, n_classes=19, device='cuda', save_pth='./pretrained_models/79999_iter.pth'):
        super(FaceSegmentation, self).__init__()
        self.net = BiSeNet(n_classes=n_classes).to(device)
        self.net.load_state_dict(torch.load(save_pth))
        self.net.eval()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.device = device

    def get_facemask(self, parsing_anno):
        """
        Returns a binary mask covering the face regions.
        """
        # face_attr = {1: 'skin', 2: 'l_brow', 3: 'r_brow', 4: 'l_eye', 5: 'r_eye', 6: 'eye_glass', 7: 'l_ear', 8: 'r_ear', 10: 'nose', 11: 'mouth', 12: 'u_lip', 13: 'l_lip', 14: 'neck'}
        face_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14], device=self.device)
        face_mask = torch.isin(parsing_anno, face_attr)
        return face_mask.int()

    def get_hairmask(self, parsing_anno):
        """
        Returns a binary mask covering the hair region (class 17).
        """
        hair_mask = parsing_anno == 17
        return hair_mask.int()

    def forward(self, img):
        """
        Returns the normalized image tensor together with binary face and hair masks.
        """
        img = self.transform(img).to(self.device)
        parsing_anno = self.net(img.unsqueeze(0))[0].squeeze(0).argmax(0)
        face_mask = self.get_facemask(parsing_anno).to(self.device)
        hair_mask = self.get_hairmask(parsing_anno).to(self.device)
        return img, face_mask, hair_mask


class FaceMetric(nn.Module):
    def __init__(self, metric_type, eval_face=True, eval_hair=True, device='cuda', seg_save_pth='./pretrained_models/79999_iter.pth'):
        super(FaceMetric, self).__init__()
        if metric_type == 'ms-ssim':
            self.metric = MS_SSIM()
            self.eval_hair = eval_hair
            self.eval_face = eval_face
        elif metric_type == 'lpips':
            self.metric = LPIPS(device=device)
            self.eval_hair = eval_hair
            self.eval_face = eval_face
        elif metric_type == 'id':
            self.metric = IdScore(device=device)
            self.eval_hair = False
            self.eval_face = eval_face
        elif metric_type == 'cliphair':
            self.metric = ClipHair(device=device)
            self.eval_face = False
            self.eval_hair = eval_hair
        else:
            raise NotImplementedError
        self.parser = FaceSegmentation(device=device, save_pth=seg_save_pth)
        self.device = device

    def forward(self, x, y):
        face_score, hair_score = None, None
        x_tensor, x_face_seg, x_hair_seg = self.parser(x)
        y_tensor, y_face_seg, y_hair_seg = self.parser(y)
        if self.eval_hair:
            ## Mask each image by its own hair segmentation
            # (alternatively, the union of the two masks: hair_mask = (x_hair_seg + y_hair_seg) > 0)
            x_hair = x_tensor * x_hair_seg
            y_hair = y_tensor * y_hair_seg

            hair_score = self.metric(x_hair, y_hair).item()
        if self.eval_face:
            ## Mask both images by the intersection of the two face masks
            face_mask = (x_face_seg + y_face_seg) > 1

            x_face = x_tensor * face_mask
            y_face = y_tensor * face_mask

            face_score = self.metric(x_face, y_face).item()

        return face_score, hair_score
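
A hypothetical end-to-end call for context (the image paths are placeholders; it assumes two same-size face crops and the BiSeNet checkpoint at the configured path):

    from PIL import Image

    src = Image.open('source.png')     # original face image
    edit = Image.open('edited.png')    # edited face image at the same resolution

    metric = FaceMetric('ms-ssim', device='cuda')
    face_score, hair_score = metric(src, edit)   # MS-SSIM over the masked face and hair regions
    print(f'face: {face_score:.4f}, hair: {hair_score:.4f}')
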
metrics/metrics.py
ADDED
@@ -0,0 +1,205 @@
import torch
import torch.nn as nn
from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, PReLU, Dropout, Flatten, Linear, BatchNorm1d, MaxPool2d, AdaptiveAvgPool2d, ReLU, Sigmoid
from collections import namedtuple
from pytorch_msssim import ms_ssim
import lpips
import clip
from torchvision import transforms

class LPIPS(nn.Module):
    def __init__(self, net='alex', device='cuda'):
        super(LPIPS, self).__init__()
        self.lpips = lpips.LPIPS(net=net).to(device)

    def forward(self, x, y):
        # LPIPS is a distance, so report 1 - distance: higher means more similar
        return 1 - self.lpips(x, y).squeeze()


class MS_SSIM(nn.Module):
    def __init__(self, avg=False):
        super(MS_SSIM, self).__init__()
        self.ssim = ms_ssim
        self.avg = avg

    def forward(self, x, y):
        ## normalize images from [-1, 1] to [0, 1]
        x = (x + 1) / 2
        y = (y + 1) / 2
        return self.ssim(x.unsqueeze(0), y.unsqueeze(0), data_range=1, size_average=self.avg)


class IdScore(nn.Module):
    def __init__(self, device='cuda'):
        super(IdScore, self).__init__()
        # Pretrained ResNet ArcFace backbone for identity features
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6).to(device)
        self.facenet.load_state_dict(torch.load('./pretrained_models/model_ir_se50.pth'))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.cosine_sim = nn.CosineSimilarity(dim=1)

    def extract_feats(self, x):
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y, x):
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_feats = y_feats.detach()

        cosine_sim = self.cosine_sim(y_feats, x_feats)

        return cosine_sim

class ClipHair(nn.Module):
    def __init__(self, device='cuda'):
        super(ClipHair, self).__init__()
        self.model, self.preprocessing = clip.load("ViT-B/32", device=device)
        self.cosine_sim = nn.CosineSimilarity(dim=1)
        self.device = device

    def extract_feats(self, x):
        x = transforms.ToPILImage()(x.squeeze())
        x = self.preprocessing(x).unsqueeze(0).to(self.device)
        x = self.model.encode_image(x)
        return x

    def forward(self, y, x):
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_feats = y_feats.detach()

        cosine_sim = self.cosine_sim(x_feats, y_feats)

        return cosine_sim


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Backbone(Module):
    def __init__(self, input_size, num_layers, drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        blocks = get_blocks(num_layers)

        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 14 * 14, 512),
                                           BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(bottleneck_IR_SE(bottleneck.in_channel,
                                                bottleneck.depth,
                                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)

def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    return blocks

class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]

def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output

class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x
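
And a minimal sketch of using the raw metrics directly (the tensor shapes and the [-1, 1] input range are assumptions matching the normalization in MS_SSIM above; the lpips package may download its backbone weights on first use):

    import torch

    x = torch.rand(3, 256, 256) * 2 - 1    # dummy images in [-1, 1]
    y = torch.rand(3, 256, 256) * 2 - 1

    ssim = MS_SSIM(avg=True)
    print('MS-SSIM:', ssim(x, y).item())    # in [0, 1]; higher means more similar

    lp = LPIPS(device='cpu')
    print('LPIPS similarity:', lp(x.unsqueeze(0), y.unsqueeze(0)).item())
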
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
 torch
 torchvision
+cudatoolkit
 dlib
 pillow
 numpy
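
One caveat worth flagging: cudatoolkit is normally distributed through conda channels rather than PyPI, so a plain pip install against this requirements.txt may fail to resolve it; pip-based setups usually rely on the CUDA runtime bundled with the torch wheels instead.
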
ris/model.py
CHANGED
@@ -508,12 +508,7 @@ class Generator(nn.Module):
         output.append(self.to_rgb1.get_latent(latent[:, 1]))
 
         i = 1
-        # print("Get latent dimensions:")
         for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2], self.to_rgbs):
-            # print(f'{i}: {conv1.get_latent(latent[:, i]).shape}')
-            # print(f'{i+1}: {conv2.get_latent(latent[:, i+1]).shape}')
-            # print(f'{i+2}: {to_rgb.get_latent(latent[:, i+2]).shape}')
-            # print("")
             output.append(conv1.get_latent(latent[:, i]))
             output.append(conv2.get_latent(latent[:, i+1]))
             output.append(to_rgb.get_latent(latent[:, i+2]))