ethanNeuralImage committed · Commit 07ecdd9 · Parent: 7deca27
initial commit -- gradio

Files changed:
- app.py +215 -0
- gradio_wrapper/demo.py +549 -0
- gradio_wrapper/gradio_options.py +53 -0
app.py
ADDED
@@ -0,0 +1,215 @@
import sys
import os
import torch

sys.path.append(".")

from gradio_wrapper.gradio_options import GradioTestOptions
from models.hyperstyle.utils.model_utils import load_model
from models.hyperstyle.utils.common import tensor2im
from models.hyperstyle.utils.inference_utils import run_inversion

from hyperstyle_global_directions.edit import load_direction_calculator, edit_image

from torchvision import transforms

import gradio as gr

from utils.alignment import align_face
import dlib

from argparse import Namespace

from mapper.styleclip_mapper import StyleCLIPMapper

from PIL import Image

opts_args = ['--no_fine_mapper']
opts = GradioTestOptions().parse(opts_args)

# Checkpoints for the per-hairstyle StyleCLIP mappers.
mapper_dict = {
    'afro': './pretrained_models/styleCLIP_mappers/afro_hairstyle.pt',
    'bob': './pretrained_models/styleCLIP_mappers/bob_hairstyle.pt',
    'bowl': './pretrained_models/styleCLIP_mappers/bowl_hairstyle.pt',
    'buzz': './pretrained_models/styleCLIP_mappers/buzz_hairstyle.pt',
    'caesar': './pretrained_models/styleCLIP_mappers/caesar_hairstyle.pt',
    'crew': './pretrained_models/styleCLIP_mappers/crew_hairstyle.pt',
    'pixie': './pretrained_models/styleCLIP_mappers/pixie_hairstyle.pt',
    'straight': './pretrained_models/styleCLIP_mappers/straight_hairstyle.pt',
    'undercut': './pretrained_models/styleCLIP_mappers/undercut_hairstyle.pt',
    'wavy': './pretrained_models/styleCLIP_mappers/wavy_hairstyle.pt',
}

predictor = dlib.shape_predictor("./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.dat")
hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, update_opts=opts)
resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
im2tensor_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
direction_calculator = load_direction_calculator(opts)

# Load the default ('afro') StyleCLIP mapper at startup.
ckpt = torch.load(mapper_dict['afro'], map_location='cpu')
opts.checkpoint_path = mapper_dict['afro']
mapper_args = ckpt['opts']
mapper_args.update(vars(opts))
mapper_args = Namespace(**mapper_args)
mapper = StyleCLIPMapper(mapper_args)
mapper.eval()
mapper.cuda()


def change_mapper(desc):
    # Swap in the StyleCLIP mapper for the selected hairstyle, freeing the old one first.
    global mapper
    global mapper_args
    mapper = None
    ckpt = None
    mapper_args = None
    torch.cuda.empty_cache()
    opts.checkpoint_path = mapper_dict[desc]
    ckpt = torch.load(mapper_dict[desc], map_location='cpu')
    mapper_args = ckpt['opts']
    mapper_args.update(vars(opts))
    mapper_args = Namespace(**mapper_args)
    mapper = StyleCLIPMapper(mapper_args)
    mapper.eval()
    mapper.cuda()


with gr.Blocks() as demo:
    with gr.Row() as row:
        with gr.Column() as inputs:
            source = gr.Image(label="Image to Map", type='filepath')
            align = gr.Checkbox(True, label='Align Image')
            inverter_bools = gr.CheckboxGroup(["Hyperstyle", "E4E"], value=['Hyperstyle'], label='Inverter Choices')
            n_hyperstyle_iterations = gr.Number(3, label='Number of Iterations For Hyperstyle', precision=0)
            with gr.Box():
                mapper_bool = gr.Checkbox(True, label='Output Mapper Result')
                with gr.Box() as mapper_opts:
                    mapper_choice = gr.Dropdown(['afro', 'bob', 'bowl', 'buzz', 'caesar', 'crew', 'pixie', 'straight', 'undercut', 'wavy'], value='afro', label='Which Hairstyle Mapper to Use?')
                    mapper_alpha = gr.Slider(minimum=-0.5, maximum=0.5, value=0.01, step=0.1, label='Strength of Mapper Alpha')
            with gr.Box():
                gd_bool = gr.Checkbox(False, label='Output Global Direction Result')
                with gr.Box(visible=False) as gd_opts:
                    neutral_text = gr.Text(value='A face with hair', label='Neutral Text')
                    target_text = gr.Text(value=mapper_args.description, label='Target Text')
                    alpha = gr.Slider(minimum=-10.0, maximum=10.0, value=4.1, step=0.1, label="Alpha for Global Direction")
                    beta = gr.Slider(minimum=0.0, maximum=0.30, value=0.15, step=0.01, label="Beta for Global Direction")
            submit_button = gr.Button("Edit Image")
        with gr.Column() as outputs:
            with gr.Row() as hyperstyle_images:
                output_hyperstyle_mapper = gr.Image(type='pil', label="Hyperstyle Mapper")
                output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
            with gr.Row(visible=False) as e4e_images:
                output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
                output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)

    def mapper_change(new_mapper):
        change_mapper(new_mapper)
        return mapper_args.description

    def inverter_toggles(bools):
        e4e_bool = 'E4E' in bools
        hyperstyle_bool = 'Hyperstyle' in bools
        return {
            hyperstyle_images: gr.update(visible=hyperstyle_bool),
            e4e_images: gr.update(visible=e4e_bool),
            n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool),
        }

    def mapper_toggles(bool):
        return {
            mapper_opts: gr.update(visible=bool),
            output_hyperstyle_mapper: gr.update(visible=bool),
            output_e4e_mapper: gr.update(visible=bool),
        }

    def gd_toggles(bool):
        return {
            gd_opts: gr.update(visible=bool),
            output_hyperstyle_gd: gr.update(visible=bool),
            output_e4e_gd: gr.update(visible=bool),
        }

    mapper_choice.change(mapper_change, mapper_choice, [target_text])
    inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, e4e_images, n_hyperstyle_iterations])
    mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
    gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])

    def map_latent(inputs, stylespace=False, weight_deltas=None, strength=0.1):
        w = inputs.cuda()
        with torch.no_grad():
            if stylespace:
                delta = mapper.mapper(w)
                w_hat = [c + strength * delta_c for (c, delta_c) in zip(w, delta)]
                x_hat, _, w_hat = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                                 randomize_noise=False, truncation=1, input_is_stylespace=True,
                                                 weights_deltas=weight_deltas)
            else:
                delta = mapper.mapper(w)
                w_hat = w + strength * delta
                x_hat, w_hat, _ = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                                 randomize_noise=False, truncation=1, weights_deltas=weight_deltas)
            result_batch = (x_hat, w_hat)
        return result_batch

    def submit(
        src, align_img, inverter_bools, n_iterations,
        mapper_bool, mapper_choice, mapper_alpha,
        gd_bool, neutral_text, target_text, alpha, beta,
    ):
        torch.cuda.empty_cache()
        with torch.no_grad():
            output_imgs = []
            if align_img:
                input_img = align_face(src, predictor)
            else:
                input_img = Image.open(src).convert('RGB')
            input_img = im2tensor_transforms(input_img).cuda()

            if gd_bool:
                opts.neutral_text = neutral_text
                opts.target_text = target_text
                opts.alpha = alpha
                opts.beta = beta

            if 'Hyperstyle' in inverter_bools:
                # Honor the iteration count chosen in the UI before running the inversion.
                hyperstyle_args.n_iters_per_batch = int(n_iterations)
                hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(
                    input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
                if mapper_bool:
                    mapped_hyperstyle, _ = map_latent(hyperstyle_latents, stylespace=False,
                                                      weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
                    mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
                else:
                    mapped_hyperstyle = None

                if gd_bool:
                    gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder,
                                               direction_calculator, opts, hyperstyle_deltas)[0]
                    gd_hyperstyle = tensor2im(gd_hyperstyle)
                else:
                    gd_hyperstyle = None

                hyperstyle_output = [mapped_hyperstyle, gd_hyperstyle]
            else:
                hyperstyle_output = [None, None]
            output_imgs.extend(hyperstyle_output)

            if 'E4E' in inverter_bools:
                e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
                e4e_deltas = None
                if mapper_bool:
                    mapped_e4e, _ = map_latent(e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
                    mapped_e4e = tensor2im(mapped_e4e[0])
                else:
                    mapped_e4e = None

                if gd_bool:
                    gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)[0]
                    gd_e4e = tensor2im(gd_e4e)
                else:
                    gd_e4e = None

                e4e_output = [mapped_e4e, gd_e4e]
            else:
                e4e_output = [None, None]
            output_imgs.extend(e4e_output)
        return output_imgs

    submit_button.click(
        submit,
        [
            source, align, inverter_bools, n_hyperstyle_iterations,
            mapper_bool, mapper_choice, mapper_alpha,
            gd_bool, neutral_text, target_text, alpha, beta,
        ],
        [output_hyperstyle_mapper, output_hyperstyle_gd, output_e4e_mapper, output_e4e_gd],
    )

demo.launch()
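Note on the UI callbacks above: the visibility toggles rely on Gradio's pattern of returning a dict that maps output components to gr.update(...) calls, so only the listed components change. A minimal, self-contained sketch of that pattern (the component names here are illustrative, not part of this Space):

import gradio as gr

with gr.Blocks() as toy:
    show = gr.Checkbox(True, label="Show advanced options")
    with gr.Box(visible=True) as advanced:
        strength = gr.Slider(0.0, 1.0, value=0.5, label="Strength")

    def toggle(flag):
        # Return a dict keyed by components; only these components get updated.
        return {advanced: gr.update(visible=flag)}

    show.change(toggle, show, [advanced])

# toy.launch()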
gradio_wrapper/demo.py
ADDED
@@ -0,0 +1,549 @@
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm as tqdm
import pickle
import warnings
warnings.filterwarnings("ignore")
from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
import glob
from pathlib import Path
import lpips
import argparse
import gc
import cv2
import dlib
from PIL import Image

from e4e_projection import projection


from model import *
from util import *
from e4e.models.psp import pSp
import torchvision.transforms as transforms

from torch.nn import DataParallel
import torchvision.transforms.functional as TF
from FaceQualityMetrics.utils import FaceMetric
from hyperstyle.utils.model_utils import load_model
from configs.paths_config import model_paths


from manipulator import Manipulator
from wrapper import Generator_wrapper


def run_inversion(inputs, net, n_iters_per_batch, return_intermediate_results=False, resize_outputs=False, weights_deltas=None):
    y_hat, latent, weights_deltas, codes = None, None, weights_deltas, None

    if return_intermediate_results:
        results_batch = {idx: [] for idx in range(inputs.shape[0])}
        results_latent = {idx: [] for idx in range(inputs.shape[0])}
        results_deltas = {idx: [] for idx in range(inputs.shape[0])}
    else:
        results_batch, results_latent, results_deltas = None, None, None

    if weights_deltas is None:
        # Let the hypernetwork refine the weight deltas over several passes.
        for iter in range(n_iters_per_batch):
            y_hat, latent, weights_deltas, codes, _ = net.forward(inputs,
                                                                  y_hat=y_hat,
                                                                  codes=codes,
                                                                  weights_deltas=weights_deltas,
                                                                  return_latents=True,
                                                                  resize=resize_outputs,
                                                                  randomize_noise=False,
                                                                  return_weight_deltas_and_codes=True)
            # weights_deltas[14] = None
            # weights_deltas[20] = None
            # weights_deltas[21] = None
            # weights_deltas[23] = None
            # weights_deltas[24] = None

            if return_intermediate_results:
                store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)

            # resize input to 256 before feeding into next iteration
            y_hat = net.face_pool(y_hat)

    else:
        # Weight deltas were supplied by the caller, so keep them fixed across iterations.
        for iter in range(n_iters_per_batch):
            y_hat, latent, _, codes, _ = net.forward(inputs,
                                                     y_hat=y_hat,
                                                     codes=codes,
                                                     weights_deltas=weights_deltas,
                                                     return_latents=True,
                                                     resize=resize_outputs,
                                                     randomize_noise=False,
                                                     return_weight_deltas_and_codes=True)

            if return_intermediate_results:
                store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)

            # resize input to 256 before feeding into next iteration
            y_hat = net.face_pool(y_hat)

    if return_intermediate_results:
        return results_batch, results_latent, results_deltas
    return y_hat, latent, weights_deltas, codes


def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
    for idx in range(y_hat.shape[0]):
        results_batch[idx].append(y_hat[idx])
        results_latent[idx].append(latent[idx].cpu().numpy())
        results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None for w in weights_deltas])


# compute M given a style code.
@torch.no_grad()
def compute_M(w, weights_deltas=None, device='cuda'):
    M = []

    # get segmentation
    # _, outputs = generator(w, is_cluster=1)
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    b, c, h, w = cluster_layer.size()  # note: w is reused here as the spatial width

    # create masks for each feature
    all_seg_mask = []
    seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w, 1).to(device)

    for key in range(n_class):
        # combine masks for all indices for a particular segmentation class
        indices = labels_map[key].view(1, 1, 1, 1, -1)
        key_mask = (seg_mask == indices.to(device)).any(-1)  # [b,1,h,w]
        all_seg_mask.append(key_mask)

    all_seg_mask = torch.stack(all_seg_mask, 1)

    # go through each activation layer and compute M
    for layer_idx in range(len(outputs)):
        layer = outputs[layer_idx][1].to(device)
        b, c, h, w = layer.size()
        layer = F.instance_norm(layer)
        layer = layer.pow(2)

        # resize the segmentation masks to the current activations' resolution
        layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), align_corners=False,
                                       size=(h, w), mode='bilinear').view(b, -1, 1, h, w)

        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
        masked_layer = masked_layer.sum([3, 4]) / (h * w)  # [b,k,c]

        M.append(masked_layer.to(device))

    M = torch.cat(M, -1)  # [b, k, c]

    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / .1, 1)
    # simple thresholding
    M = (M > .8).float()

    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j in range(len(part_M)):
            if j in rgb_layer_idx:
                part_M[j].zero_()
        part_M = list2style(part_M)
        M[:, i] = part_M

    return M


# ====
# for i in range(len(blend_deltas)):
#     if blend_deltas[i] is not None:
#         print(f'{i}: {part_M_mask[i].sum()}/{sum(part_M_mask[i].shape)}')
#         if part_M_mask[i].sum() >= sum(part_M_mask[i].shape)/2:
#             print(i)
#             blend_deltas[i] = ref_deltas[i]


def tensor2img(tensor):
    tensor = tensor.cpu().clamp(-1, 1)
    img = topil(tensor.squeeze())
    return img


def hair_transfer_hyperstyle(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))

        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)

        source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)

        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)

        source_out, _ = generator(source, weights_deltas=source_deltas, randomize_noise=False)
        ref_out, _ = generator(ref, weights_deltas=ref_deltas, randomize_noise=False)

        source_M = compute_M(source, weights_deltas=source_deltas, device='cpu')
        ref_M = compute_M(ref, weights_deltas=ref_deltas, device='cpu')

        blend_deltas = source_deltas

        max_M = torch.max(source_M.expand_as(ref_M), ref_M)
        max_M = add_pose(max_M, labels2idx)
        idx = labels2idx['hair']
        part_M = max_M[:, idx].to(device)
        part_M_mask = style2list(part_M)
        blend = style2list(add_direction(source, ref, part_M, 1.3))
        blend_out, _ = generator(blend, weights_deltas=blend_deltas)

        source_out = tensor2img(source_out)
        ref_out = tensor2img(ref_out)
        blend_out = tensor2img(blend_out)

        lpips_face, _ = lpips(blend_out, source_out)
        ssim_face, _ = ssim(blend_out, source_out)
        id_face, _ = id_score(blend_out, source_out)
        _, lpips_hair = lpips(blend_out, ref_out)
        _, ssim_hair = ssim(blend_out, ref_out)
        _, clip_hair = clip(blend_out, source_out)
        out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'

        e4e_blend_out, _ = generator(blend)
        e4e_blend_out = tensor2img(e4e_blend_out)
        _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
        source_out_np = np.array(source_out)
        blend_np = np.array(e4e_blend_out).astype(np.uint8)

        e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8) * 255
        mask_dilate = cv2.dilate(e4e_blend_hair_mask,
                                 kernel=np.ones((50, 50), np.uint8))
        mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
        mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
        face_mask = 255 - mask_dilate_blur

        index = np.where(face_mask > 0)
        cy = (np.min(index[0]) + np.max(index[0])) // 2
        cx = (np.min(index[1]) + np.max(index[1])) // 2
        center = (cx, cy)

        clone_out = cv2.seamlessClone(source_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)

        return source_out, ref_out, blend_out, out_str, clone_out


def hair_transfer_e4e(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))

        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)

        source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)

        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)

        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)

        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')

        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']

        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)

        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
        e4e_blend_out, _ = generator(e4e_blend)

        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        e4e_blend_out = tensor2img(e4e_blend_out)

        e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)

        e4e_out_str = f'e4e_lpips_face: {e4e_lpips_face}\ne4e_lpips_hair: {e4e_lpips_hair}\ne4e_ssim_face: {e4e_ssim_face}\ne4e_ssim_hair: {e4e_ssim_hair}\ne4e_id_face: {e4e_id_face}\ne4e_clip_hair: {e4e_clip_hair}'

        return e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str


def hair_transfer_PTI(source_img_path, ref_img_path):
    ckpt = 'pretrained/ffhq.pkl'
    G = Generator_wrapper(ckpt, device)
    manipulator = Manipulator(G, device)
    manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')

    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))

        projection(source_img, 'source', generator, device)
        projection(ref_img, 'ref', generator, device)
        source = load_source('source', generator, device)
        ref = load_source('ref', generator, device)

        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)

        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')

        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']

        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)

        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))

        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        # e4e_blend_out = tensor2img(e4e_blend_out)

        # e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        # e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        # e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        # _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        # _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        # _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)

        keys = (['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine', 'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine', 'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine', 'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine', 'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine', 'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine', 'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine', 'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine', 'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine'])
        test_dict = dict(zip(keys, e4e_blend))
        manipulator_list = []
        manipulator_list.append(test_dict)
        all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
        PTI_outstr = 'PTI_outstr'
        blend_out = tensor2img(all_imgs[0])
        return e4e_source_out, e4e_ref_out, blend_out, PTI_outstr


# _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
# blend_out_np = np.array(blend_out)
# blend_np = np.array(e4e_blend_out).astype(np.uint8)

# e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8)*255
# mask_dilate = cv2.dilate(e4e_blend_hair_mask,
#                          kernel=np.ones((50, 50), np.uint8))
# mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
# mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
# face_mask = 255 - mask_dilate_blur

# index = np.where(face_mask > 0)
# cy = (np.min(index[0]) + np.max(index[0])) // 2
# cx = (np.min(index[1]) + np.max(index[1])) // 2
# center = (cx, cy)

# clone_out = cv2.seamlessClone(blend_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)

# out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'

# seg_out = torch.tensor(face_mask).float().unsqueeze(-1).repeat(1,1,3)
# seg_out = seg_out.cpu().numpy().astype(np.uint8)
# # seg_out*=255

# seg_out = Image.fromarray(seg_out)

# # return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, seg_out

# ## Set source_tensor requires_grad=True
# source_tensor.requires_grad = True
# ref_tensor.requires_grad = True

# ckpt = 'pretrained/ffhq.pkl'
# G = Generator_wrapper(ckpt, device)
# manipulator = Manipulator(G, device)
# manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
# blend = style2list((add_direction(source_tensor, ref_tensor, part_M, 1.3)))
# keys = (['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine', 'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine', 'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine', 'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine', 'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine', 'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine', 'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine', 'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine', 'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine'])
# test_dict = dict(zip(keys, blend))
# manipulator_list = []
# manipulator_list.append(test_dict)
# all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)

# return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, all_imgs


## argument for choosing encoder between e4e and hyperstyle
args = argparse.ArgumentParser()
args.add_argument('--encoder', type=str, default='hyperstyle')

opt = args.parse_args()


device = 'cuda' if torch.cuda.is_available() else 'cpu'

lpips = FaceMetric(metric_type='lpips', device=device)
ssim = FaceMetric(metric_type='ms-ssim', device=device)
id_score = FaceMetric(metric_type='id', device=device)
clip = FaceMetric(metric_type='cliphair', device=device)

# generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
# ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
# generator.load_state_dict(ckpt['g_ema'], strict=False)

generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt['g_ema'], strict=False)


ckpt = 'pretrained/ffhq.pkl'
G = Generator_wrapper(ckpt, device)
manipulator = Manipulator(G, device)


if opt.encoder == 'e4e':
    from util import align_face
    model_path = 'e4e_ffhq_encode.pt'
    ensure_checkpoint_exists(model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    net = pSp(opts, device).eval().to(device)

elif opt.encoder == 'hyperstyle':
    from hyperstyle.scripts.align_faces_parallel import align_face
    model_path = 'pretrained_models/hyperstyle_ffhq.pt'
    predictor = dlib.shape_predictor('pretrained_models/shape_predictor_68_face_landmarks.dat')
    net, _ = load_model(model_path)

else:
    raise ValueError('invalid encoder')


truncation = 0.5
stop_idx = 11
n_clusters = 18

clusterer = pickle.load(open('catalog.pkl', 'rb'))

labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}

labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}

idx2labels = dict((v, k) for k, v in labels2idx.items())
n_class = len(idx2labels)

segid_map = dict.fromkeys(labels_map[0].tolist(), 0)
segid_map.update(dict.fromkeys(labels_map[1].tolist(), 1))
segid_map.update(dict.fromkeys(labels_map[2].tolist(), 2))
segid_map.update(dict.fromkeys(labels_map[3].tolist(), 3))
segid_map.update(dict.fromkeys(labels_map[4].tolist(), 4))
segid_map.update(dict.fromkeys(labels_map[5].tolist(), 5))
segid_map.update(dict.fromkeys(labels_map[6].tolist(), 6))
segid_map.update(dict.fromkeys(labels_map[7].tolist(), 7))

torch.manual_seed(0)


transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

topil = transforms.Compose(
    [
        transforms.Normalize([-1, -1, -1], [2, 2, 2]),
        transforms.ToPILImage(),
        transforms.Resize(1024),
    ]
)


e4e_ris_demo = gr.Interface(hair_transfer_e4e, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
hyperstyle_ris_demo = gr.Interface(hair_transfer_hyperstyle, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text", "image"])
PTI_ris_demo = gr.Interface(hair_transfer_PTI, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])

ris_demo = gr.TabbedInterface(interface_list=[hyperstyle_ris_demo, e4e_ris_demo, PTI_ris_demo], tab_names=["hyperstyle", "e4e", "PTI"])

ris_demo.launch(share=True)
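Note on compute_M above: each style channel is assigned to a segmentation class by taking the per-class mean squared activation, softmaxing sharply over the class axis, and thresholding. A small standalone illustration of just that assignment step (shapes, temperature, and threshold mirror the code above; the tensor values are made up):

import torch
import torch.nn.functional as F

b, k, c = 1, 8, 6                      # batch, segmentation classes, style channels
energies = torch.rand(b, k, c)          # per-class mean squared activations, as built in compute_M
M = F.softmax(energies / 0.1, dim=1)    # sharp softmax over classes for each channel
M = (M > 0.8).float()                   # keep only channels dominated by a single class
print(M.sum(dim=1))                     # each channel is now assigned to at most one class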
gradio_wrapper/gradio_options.py
ADDED
@@ -0,0 +1,53 @@
import sys
import os
sys.path.append(".")
sys.path.append("..")
from argparse import ArgumentParser


class GradioTestOptions:

    def __init__(self):
        self.parser = ArgumentParser()
        self.initialize()

    def initialize(self):
        # arguments for inference script
        self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint')

        self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
        self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
        self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
        self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
        self.parser.add_argument('--use_weight_delta_mapper', default=False, action="store_true")
        self.parser.add_argument('--stylegan_size', default=1024, type=int)

        self.parser.add_argument('--alpha', default=4.1, type=float, help='Alpha to use for global direction editing')
        self.parser.add_argument('--beta', default=0.14, type=float, help='Beta to use for global direction editing')
        self.parser.add_argument('--edit_weight_delta', default=False, action='store_true', help='Also edit the weight deltas')
        self.parser.add_argument('--weight_delta_alpha', default=4.1, type=float, help='Alpha to use for weight delta')
        self.parser.add_argument('--weight_delta_beta', default=0.14, type=float, help='Beta to use for weight delta')
        self.parser.add_argument("--delta_i_c", type=str, default='./hyperstyle_global_directions/global_directions/ffhq/fs3.npy', help="path to file containing delta_i_c")
        self.parser.add_argument("--s_statistics", type=str, default='./hyperstyle_global_directions/global_directions/ffhq/S_mean_std', help="path to file containing s statistics")
        self.parser.add_argument("--text_prompt_templates", default='./hyperstyle_global_directions/global_directions/templates.txt')

        self.parser.add_argument("--neutral_text", type=str, default="A face with hair")
        self.parser.add_argument("--target_text", type=str, default=None)

        # arguments for hyperstyle
        self.parser.add_argument('--hyperstyle_checkpoint_path', default='./pretrained_models/hyperstyle/hyperstyle_ffhq.pt', type=str, help='Path to HyperStyle model checkpoint')
        self.parser.add_argument('--resize_outputs', action='store_true', help='Whether to resize outputs to 256x256 or keep at original output resolution')

        # arguments for loading pre-trained encoder
        self.parser.add_argument('--load_w_encoder', action='store_true', help='Whether to load the w e4e encoder.')
        self.parser.add_argument('--w_encoder_checkpoint_path', default='./pretrained_models/hyperstyle/faces_w_encoder.pt', type=str, help='Path to pre-trained W-encoder.')
        self.parser.add_argument('--w_encoder_type', default='WEncoder', help='Encoder type for the encoder used to get the initial inversion')

        # arguments for iterative inference
        self.parser.add_argument('--n_iters_per_batch', default=5, type=int, help='Number of forward passes per batch during inference.')

        # arguments for the test dataset
        self.parser.add_argument('--work_in_stylespace', default=False, action='store_true')

    def parse(self, args=None):
        opts = self.parser.parse_args(args)
        return opts
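Note on using these options programmatically: because parse() forwards its argument list to argparse, the options can be built from a Python list instead of the command line, which is how app.py invokes them. A small usage sketch (not part of the commit; flag values chosen for illustration):

from gradio_wrapper.gradio_options import GradioTestOptions

# Build options without touching sys.argv, overriding a couple of defaults.
opts = GradioTestOptions().parse(['--no_fine_mapper', '--alpha', '3.0', '--beta', '0.1'])
print(opts.alpha, opts.beta, opts.no_fine_mapper)  # 3.0 0.1 True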