Spaces:
Runtime error
Runtime error
Commit
•
7e0de36
1
Parent(s):
564c410
Got RIS and Metrics up and running
Browse files
- app.py +30 -23
- ris/blend.py +3 -0
- ris/model.py +5 -3
app.py
CHANGED
@@ -2,6 +2,8 @@ import sys
|
|
2 |
import os
|
3 |
import torch
|
4 |
|
|
|
|
|
5 |
sys.path.append(".")
|
6 |
|
7 |
from gradio_wrapper.gradio_options import GradioTestOptions
|
@@ -185,6 +187,9 @@ with gr.Blocks() as demo:
|
|
185 |
result_batch = (x_hat, w_hat)
|
186 |
return result_batch
|
187 |
def run_metrics(base_img, edited_img):
|
|
|
|
|
|
|
188 |
lpips_score = lpips_metric(base_img, edited_img)[0]
|
189 |
ssim_score = ssim_metric(base_img, edited_img)[0]
|
190 |
id_score = id_metric(base_img, edited_img)[0]
|
@@ -234,35 +239,36 @@ with gr.Blocks() as demo:
|
|
234 |
hyperstyle_metrics_text = ''
|
235 |
if 'Hyperstyle' in inverter_bools:
|
236 |
hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
|
237 |
-
|
238 |
-
invert_hyperstyle = tensor2im(hyperstyle_batch[0])
|
239 |
-
else:
|
240 |
-
invert_hyperstyle = None
|
241 |
if mapper_bool:
|
242 |
mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
|
243 |
-
|
244 |
mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
|
245 |
-
|
246 |
-
|
247 |
else:
|
248 |
mapped_hyperstyle = None
|
249 |
|
250 |
if gd_bool:
|
251 |
gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
|
252 |
-
|
253 |
gd_hyperstyle = tensor2im(gd_hyperstyle[0])
|
254 |
-
|
255 |
-
|
256 |
else:
|
257 |
gd_hyperstyle = None
|
258 |
|
259 |
if ris_bool:
|
260 |
|
261 |
ref_hyperstyle_batch, ref_hyperstyle_latents, ref_hyperstyle_deltas, _ = run_inversion(ref_input.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
|
262 |
-
blend_hyperstyle, blend_hyperstyle_latents = blend_latents(hyperstyle_latents,
|
263 |
src_deltas=hyperstyle_deltas, ref_deltas=ref_hyperstyle_deltas,
|
264 |
generator=ris_gen, device=device)
|
265 |
-
ris_hyperstyle = tensor2im(blend_hyperstyle)
|
|
|
|
|
|
|
|
|
266 |
else:
|
267 |
ris_hyperstyle=None
|
268 |
|
@@ -274,16 +280,13 @@ with gr.Blocks() as demo:
|
|
274 |
if 'E4E' in inverter_bools:
|
275 |
e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
|
276 |
e4e_deltas = None
|
277 |
-
|
278 |
-
invert_e4e = tensor2im(e4e_batch[0])
|
279 |
-
else:
|
280 |
-
invert_e4e = None
|
281 |
if mapper_bool:
|
282 |
mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
|
283 |
-
|
284 |
mapped_e4e = tensor2im(mapped_e4e[0])
|
285 |
-
|
286 |
-
|
287 |
|
288 |
else:
|
289 |
mapped_e4e = None
|
@@ -292,8 +295,8 @@ with gr.Blocks() as demo:
|
|
292 |
gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
|
293 |
clip_score = clip_text_metric(gd_e4e[0], opts.target_text)
|
294 |
gd_e4e = tensor2im(gd_e4e[0])
|
295 |
-
lpips_score, ssim_score, id_score = run_metrics(invert_e4e, gd_e4e)
|
296 |
-
e4e_metrics_text += f'
|
297 |
|
298 |
else:
|
299 |
gd_e4e = None
|
@@ -301,10 +304,14 @@ with gr.Blocks() as demo:
|
|
301 |
if ris_bool:
|
302 |
ref_e4e_batch, ref_e4e_latents, = hyperstyle.w_invert(ref_input.unsqueeze(0))
|
303 |
ref_e4e_deltas= None
|
304 |
-
blend_e4e, blend_e4e_latents = blend_latents(e4e_latents,
|
305 |
src_deltas=None, ref_deltas=None,
|
306 |
generator=ris_gen, device=device)
|
307 |
-
ris_e4e = tensor2im(blend_e4e)
|
|
|
|
|
|
|
|
|
308 |
else:
|
309 |
ris_e4e=None
|
310 |
|
|
|
2 |
import os
|
3 |
import torch
|
4 |
|
5 |
+
from metrics.metrics import ClipHair
|
6 |
+
|
7 |
sys.path.append(".")
|
8 |
|
9 |
from gradio_wrapper.gradio_options import GradioTestOptions
|
|
|
187 |
result_batch = (x_hat, w_hat)
|
188 |
return result_batch
|
189 |
def run_metrics(base_img, edited_img):
|
190 |
+
#print(base_img.shape, edited_img.shape)
|
191 |
+
#base_img = base_img.unsqueeze(0)
|
192 |
+
#edited_img = edited_img.unsqueeze(0)
|
193 |
lpips_score = lpips_metric(base_img, edited_img)[0]
|
194 |
ssim_score = ssim_metric(base_img, edited_img)[0]
|
195 |
id_score = id_metric(base_img, edited_img)[0]
|
|
|
239 |
hyperstyle_metrics_text = ''
|
240 |
if 'Hyperstyle' in inverter_bools:
|
241 |
hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
|
242 |
+
invert_hyperstyle = tensor2im(hyperstyle_batch[0])
|
|
|
|
|
|
|
243 |
if mapper_bool:
|
244 |
mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
|
245 |
+
clip_score = clip_text_metric(mapped_hyperstyle[0], mapper_args.description)
|
246 |
mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
|
247 |
+
lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), mapped_hyperstyle.resize(resize_to))
|
248 |
+
hyperstyle_metrics_text += f'\nMapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
|
249 |
else:
|
250 |
mapped_hyperstyle = None
|
251 |
|
252 |
if gd_bool:
|
253 |
gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
|
254 |
+
clip_score = clip_text_metric(gd_hyperstyle[0], opts.target_text)
|
255 |
gd_hyperstyle = tensor2im(gd_hyperstyle[0])
|
256 |
+
lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), gd_hyperstyle.resize(resize_to))
|
257 |
+
hyperstyle_metrics_text += f'\nGlobal Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
|
258 |
else:
|
259 |
gd_hyperstyle = None
|
260 |
|
261 |
if ris_bool:
|
262 |
|
263 |
ref_hyperstyle_batch, ref_hyperstyle_latents, ref_hyperstyle_deltas, _ = run_inversion(ref_input.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
|
264 |
+
blend_hyperstyle, blend_hyperstyle_latents = blend_latents(hyperstyle_latents, ref_hyperstyle_latents,
|
265 |
src_deltas=hyperstyle_deltas, ref_deltas=ref_hyperstyle_deltas,
|
266 |
generator=ris_gen, device=device)
|
267 |
+
ris_hyperstyle = tensor2im(blend_hyperstyle[0])
|
268 |
+
|
269 |
+
lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), ris_hyperstyle.resize(resize_to))
|
270 |
+
clip_score = clip_hair(invert_hyperstyle.resize(resize_to), ris_hyperstyle.resize(resize_to))[1]
|
271 |
+
hyperstyle_metrics_text += f'\nRIS Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Hair Score: \t{clip_score}'
|
272 |
else:
|
273 |
ris_hyperstyle=None
|
274 |
|
|
|
280 |
if 'E4E' in inverter_bools:
|
281 |
e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
|
282 |
e4e_deltas = None
|
283 |
+
invert_e4e = tensor2im(e4e_batch[0])
|
|
|
|
|
|
|
284 |
if mapper_bool:
|
285 |
mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
|
286 |
+
clip_score = clip_text_metric(mapped_e4e[0], mapper_args.description)
|
287 |
mapped_e4e = tensor2im(mapped_e4e[0])
|
288 |
+
lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), mapped_e4e.resize(resize_to))
|
289 |
+
e4e_metrics_text += f'\nMapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
|
290 |
|
291 |
else:
|
292 |
mapped_e4e = None
|
|
|
295 |
gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
|
296 |
clip_score = clip_text_metric(gd_e4e[0], opts.target_text)
|
297 |
gd_e4e = tensor2im(gd_e4e[0])
|
298 |
+
lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), gd_e4e.resize(resize_to))
|
299 |
+
e4e_metrics_text += f'\nGlobal Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
|
300 |
|
301 |
else:
|
302 |
gd_e4e = None
|
|
|
304 |
if ris_bool:
|
305 |
ref_e4e_batch, ref_e4e_latents, = hyperstyle.w_invert(ref_input.unsqueeze(0))
|
306 |
ref_e4e_deltas= None
|
307 |
+
blend_e4e, blend_e4e_latents = blend_latents(e4e_latents, ref_e4e_latents,
|
308 |
src_deltas=None, ref_deltas=None,
|
309 |
generator=ris_gen, device=device)
|
310 |
+
ris_e4e = tensor2im(blend_e4e[0])
|
311 |
+
|
312 |
+
lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), ris_e4e.resize(resize_to))
|
313 |
+
clip_score = clip_hair(invert_e4e.resize(resize_to), ris_e4e.resize(resize_to))[1]
|
314 |
+
e4e_metrics_text += f'\nRIS Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Hair Score: \t{clip_score}'
|
315 |
else:
|
316 |
ris_e4e=None
|
317 |
|
ris/blend.py
CHANGED
@@ -111,7 +111,9 @@ def compute_M(w, generator, weights_deltas=None, device='cuda'):
|
|
111 |
return M
|
112 |
|
113 |
def blend_latents (source_latent, ref_latent, generator, src_deltas=None, ref_deltas=None, device='cuda'):
|
|
|
114 |
source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
|
|
|
115 |
ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
|
116 |
source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
|
117 |
ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')
|
@@ -127,5 +129,6 @@ def blend_latents (source_latent, ref_latent, generator, src_deltas=None, ref_de
|
|
127 |
|
128 |
blend = style2list((add_direction(source, ref, part_M, 1.3)))
|
129 |
blend_out, _ = generator(blend, weights_deltas=blend_deltas)
|
|
|
130 |
|
131 |
return blend_out, blend
|
|
|
111 |
return M
|
112 |
|
113 |
def blend_latents (source_latent, ref_latent, generator, src_deltas=None, ref_deltas=None, device='cuda'):
|
114 |
+
#print(source_latent.shape)
|
115 |
source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
|
116 |
+
#print(ref_latent.shape)
|
117 |
ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
|
118 |
source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
|
119 |
ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')
|
|
|
129 |
|
130 |
blend = style2list((add_direction(source, ref, part_M, 1.3)))
|
131 |
blend_out, _ = generator(blend, weights_deltas=blend_deltas)
|
132 |
+
#print(blend_out.shape)
|
133 |
|
134 |
return blend_out, blend
|
ris/model.py
CHANGED
@@ -160,13 +160,15 @@ class EqualLinear(nn.Module):
|
|
160 |
self.lr_mul = lr_mul
|
161 |
|
162 |
def forward(self, input):
|
|
|
|
|
163 |
if self.activation:
|
164 |
-
out = F.linear(input, self.weight * self.scale)
|
165 |
-
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
166 |
|
167 |
else:
|
168 |
out = F.linear(
|
169 |
-
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
170 |
)
|
171 |
|
172 |
return out
|
|
|
160 |
self.lr_mul = lr_mul
|
161 |
|
162 |
def forward(self, input):
|
163 |
+
weight = self.weight * self.scale
|
164 |
+
bias = self.bias * self.lr_mul
|
165 |
if self.activation:
|
166 |
+
out = F.linear(input, weight)
|
167 |
+
out = fused_leaky_relu(out, bias)
|
168 |
|
169 |
else:
|
170 |
out = F.linear(
|
171 |
+
input, weight, bias=bias
|
172 |
)
|
173 |
|
174 |
return out
|