ethanNeuralImage committed
Commit 07ecdd9
1 Parent(s): 7deca27

initial commit -- gradio

Files changed (3):
  1. app.py +215 -0
  2. gradio_wrapper/demo.py +549 -0
  3. gradio_wrapper/gradio_options.py +53 -0
app.py ADDED
@@ -0,0 +1,215 @@
+ import sys
+ import os
+ import torch
+
+ sys.path.append(".")
+
+ from gradio_wrapper.gradio_options import GradioTestOptions
+ from models.hyperstyle.utils.model_utils import load_model
+ from models.hyperstyle.utils.common import tensor2im
+ from models.hyperstyle.utils.inference_utils import run_inversion
+
+ from hyperstyle_global_directions.edit import load_direction_calculator, edit_image
+
+ from torchvision import transforms
+
+ import gradio as gr
+
+ from utils.alignment import align_face
+ import dlib
+
+ from argparse import Namespace
+
+ from mapper.styleclip_mapper import StyleCLIPMapper
+
+ from PIL import Image
+
+ opts_args = ['--no_fine_mapper']
+ opts = GradioTestOptions().parse(opts_args)
+
+ # StyleCLIP mapper checkpoints, one per hairstyle.
+ mapper_dict = {
+     'afro': './pretrained_models/styleCLIP_mappers/afro_hairstyle.pt',
+     'bob': './pretrained_models/styleCLIP_mappers/bob_hairstyle.pt',
+     'bowl': './pretrained_models/styleCLIP_mappers/bowl_hairstyle.pt',
+     'buzz': './pretrained_models/styleCLIP_mappers/buzz_hairstyle.pt',
+     'caesar': './pretrained_models/styleCLIP_mappers/caesar_hairstyle.pt',
+     'crew': './pretrained_models/styleCLIP_mappers/crew_hairstyle.pt',
+     'pixie': './pretrained_models/styleCLIP_mappers/pixie_hairstyle.pt',
+     'straight': './pretrained_models/styleCLIP_mappers/straight_hairstyle.pt',
+     'undercut': './pretrained_models/styleCLIP_mappers/undercut_hairstyle.pt',
+     'wavy': './pretrained_models/styleCLIP_mappers/wavy_hairstyle.pt'
+ }
+
+ # Face-landmark predictor, HyperStyle inverter, and global-direction calculator.
+ predictor = dlib.shape_predictor("./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.dat")
+ hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, update_opts=opts)
+ resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
+ im2tensor_transforms = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ ])
+ direction_calculator = load_direction_calculator(opts)
+
+ # Load the default ('afro') StyleCLIP mapper.
+ ckpt = torch.load(mapper_dict['afro'], map_location='cpu')
+ opts.checkpoint_path = mapper_dict['afro']
+ mapper_args = ckpt['opts']
+ mapper_args.update(vars(opts))
+ mapper_args = Namespace(**mapper_args)
+ mapper = StyleCLIPMapper(mapper_args)
+ mapper.eval()
+ mapper.cuda()
+
+
+ def change_mapper(desc):
+     """Swap the global StyleCLIP mapper for the selected hairstyle, freeing the old one first."""
+     global mapper
+     global mapper_args
+     mapper = None
+     ckpt = None
+     mapper_args = None
+     torch.cuda.empty_cache()
+     opts.checkpoint_path = mapper_dict[desc]
+     ckpt = torch.load(mapper_dict[desc], map_location='cpu')
+     mapper_args = ckpt['opts']
+     mapper_args.update(vars(opts))
+     mapper_args = Namespace(**mapper_args)
+     mapper = StyleCLIPMapper(mapper_args)
+     mapper.eval()
+     mapper.cuda()
+
+
+ with gr.Blocks() as demo:
+     with gr.Row() as row:
+         with gr.Column() as inputs:
+             source = gr.Image(label="Image to Map", type='filepath')
+             align = gr.Checkbox(True, label='Align Image')
+             inverter_bools = gr.CheckboxGroup(["Hyperstyle", "E4E"], value=['Hyperstyle'], label='Inverter Choices')
+             n_hyperstyle_iterations = gr.Number(3, label='Number of Iterations For Hyperstyle', precision=0)
+             with gr.Box():
+                 mapper_bool = gr.Checkbox(True, label='Output Mapper Result')
+                 with gr.Box() as mapper_opts:
+                     mapper_choice = gr.Dropdown(['afro', 'bob', 'bowl', 'buzz', 'caesar', 'crew', 'pixie', 'straight', 'undercut', 'wavy'], value='afro', label='Which Hairstyle Mapper to Use?')
+                     mapper_alpha = gr.Slider(minimum=-0.5, maximum=0.5, value=0.01, step=0.1, label='Strength of Mapper Alpha')
+             with gr.Box():
+                 gd_bool = gr.Checkbox(False, label='Output Global Direction Result')
+                 with gr.Box(visible=False) as gd_opts:
+                     neutral_text = gr.Text(value='A face with hair', label='Neutral Text')
+                     target_text = gr.Text(value=mapper_args.description, label='Target Text')
+                     alpha = gr.Slider(minimum=-10.0, maximum=10.0, value=4.1, step=0.1, label="Alpha for Global Direction")
+                     beta = gr.Slider(minimum=0.0, maximum=0.30, value=0.15, step=0.01, label="Beta for Global Direction")
+             submit_button = gr.Button("Edit Image")
+         with gr.Column() as outputs:
+             with gr.Row() as hyperstyle_images:
+                 output_hyperstyle_mapper = gr.Image(type='pil', label="Hyperstyle Mapper")
+                 output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
+             with gr.Row(visible=False) as e4e_images:
+                 output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
+                 output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)
+
+     def mapper_change(new_mapper):
+         change_mapper(new_mapper)
+         return mapper_args.description
+
+     def inverter_toggles(bools):
+         e4e_bool = 'E4E' in bools
+         hyperstyle_bool = 'Hyperstyle' in bools
+         return {
+             hyperstyle_images: gr.update(visible=hyperstyle_bool),
+             e4e_images: gr.update(visible=e4e_bool),
+             n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool)
+         }
+
+     def mapper_toggles(visible):
+         return {
+             mapper_opts: gr.update(visible=visible),
+             output_hyperstyle_mapper: gr.update(visible=visible),
+             output_e4e_mapper: gr.update(visible=visible)
+         }
+
+     def gd_toggles(visible):
+         return {
+             gd_opts: gr.update(visible=visible),
+             output_hyperstyle_gd: gr.update(visible=visible),
+             output_e4e_gd: gr.update(visible=visible)
+         }
+
+     mapper_choice.change(mapper_change, mapper_choice, [target_text])
+     inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, e4e_images, n_hyperstyle_iterations])
+     mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
+     gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])
+
+     def map_latent(inputs, stylespace=False, weight_deltas=None, strength=0.1):
+         """Apply the current StyleCLIP mapper to a latent and decode the edited image."""
+         w = inputs.cuda()
+         with torch.no_grad():
+             if stylespace:
+                 delta = mapper.mapper(w)
+                 w_hat = [c + strength * delta_c for (c, delta_c) in zip(w, delta)]
+                 x_hat, _, w_hat = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
+                                                  randomize_noise=False, truncation=1, input_is_stylespace=True, weights_deltas=weight_deltas)
+             else:
+                 delta = mapper.mapper(w)
+                 w_hat = w + strength * delta
+                 x_hat, w_hat, _ = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
+                                                  randomize_noise=False, truncation=1, weights_deltas=weight_deltas)
+             result_batch = (x_hat, w_hat)
+         return result_batch
+
+     def submit(
+         src, align_img, inverter_bools, n_iterations,
+         mapper_bool, mapper_choice, mapper_alpha,
+         gd_bool, neutral_text, target_text, alpha, beta,
+     ):
+         """Invert the input image with the selected inverter(s) and return mapper / global-direction edits."""
+         torch.cuda.empty_cache()
+         with torch.no_grad():
+             output_imgs = []
+             if align_img:
+                 input_img = align_face(src, predictor)
+             else:
+                 input_img = Image.open(src).convert('RGB')
+             input_img = im2tensor_transforms(input_img).cuda()
+
+             if gd_bool:
+                 opts.neutral_text = neutral_text
+                 opts.target_text = target_text
+                 opts.alpha = alpha
+                 opts.beta = beta
+
+             if 'Hyperstyle' in inverter_bools:
+                 # Honor the iteration count chosen in the UI.
+                 hyperstyle_args.n_iters_per_batch = int(n_iterations)
+                 hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
+                 if mapper_bool:
+                     mapped_hyperstyle, _ = map_latent(hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
+                     mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
+                 else:
+                     mapped_hyperstyle = None
+
+                 if gd_bool:
+                     gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)[0]
+                     gd_hyperstyle = tensor2im(gd_hyperstyle)
+                 else:
+                     gd_hyperstyle = None
+
+                 hyperstyle_output = [mapped_hyperstyle, gd_hyperstyle]
+             else:
+                 hyperstyle_output = [None, None]
+             output_imgs.extend(hyperstyle_output)
+
+             if 'E4E' in inverter_bools:
+                 e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
+                 e4e_deltas = None
+                 if mapper_bool:
+                     mapped_e4e, _ = map_latent(e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
+                     mapped_e4e = tensor2im(mapped_e4e[0])
+                 else:
+                     mapped_e4e = None
+
+                 if gd_bool:
+                     gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)[0]
+                     gd_e4e = tensor2im(gd_e4e)
+                 else:
+                     gd_e4e = None
+
+                 e4e_output = [mapped_e4e, gd_e4e]
+             else:
+                 e4e_output = [None, None]
+             output_imgs.extend(e4e_output)
+         # Order matches the outputs list passed to submit_button.click below.
+         return output_imgs
+
+     submit_button.click(
+         submit,
+         [
+             source, align, inverter_bools, n_hyperstyle_iterations,
+             mapper_bool, mapper_choice, mapper_alpha,
+             gd_bool, neutral_text, target_text, alpha, beta,
+         ],
+         [output_hyperstyle_mapper, output_hyperstyle_gd, output_e4e_mapper, output_e4e_gd]
+     )
+
+ demo.launch()
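Note: app.py assumes the checkpoints and data files referenced above are already downloaded (paths taken from app.py and gradio_wrapper/gradio_options.py in this commit). The snippet below is a minimal pre-flight sketch, not part of the committed files, that checks for them before launching the demo.

import os

hairstyles = ['afro', 'bob', 'bowl', 'buzz', 'caesar', 'crew', 'pixie', 'straight', 'undercut', 'wavy']
required = [
    "./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.dat",
    "./pretrained_models/hyperstyle/hyperstyle_ffhq.pt",
    "./hyperstyle_global_directions/global_directions/ffhq/fs3.npy",
    "./hyperstyle_global_directions/global_directions/ffhq/S_mean_std",
    "./hyperstyle_global_directions/global_directions/templates.txt",
] + [f"./pretrained_models/styleCLIP_mappers/{name}_hairstyle.pt" for name in hairstyles]

# Report every missing file at once instead of failing on the first torch.load.
missing = [path for path in required if not os.path.exists(path)]
if missing:
    raise FileNotFoundError(f"Missing pretrained files: {missing}")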
gradio_wrapper/demo.py ADDED
@@ -0,0 +1,549 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import torch.backends.cudnn as cudnn
+ cudnn.benchmark = True
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import os
+ import sys
+ from tqdm import tqdm as tqdm
+ import pickle
+ import warnings
+ warnings.filterwarnings("ignore")
+ from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
+ from argparse import Namespace
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision.models import vgg19
+ import glob
+ from pathlib import Path
+ import lpips
+ import argparse
+ import gc
+ import cv2
+ import dlib
+ from PIL import Image
+
+ from e4e_projection import projection
+
+ from model import *
+ from util import *
+ from e4e.models.psp import pSp
+ import torchvision.transforms as transforms
+
+ from torch.nn import DataParallel
+ import torchvision.transforms.functional as TF
+ from FaceQualityMetrics.utils import FaceMetric
+ from hyperstyle.utils.model_utils import load_model
+ from configs.paths_config import model_paths
+
+ from manipulator import Manipulator
+ from wrapper import Generator_wrapper
+
+
+ def run_inversion(inputs, net, n_iters_per_batch, return_intermediate_results=False, resize_outputs=False, weights_deltas=None):
+     """Iterative HyperStyle inversion; reuses precomputed weight deltas when they are given."""
+     y_hat, latent, weights_deltas, codes = None, None, weights_deltas, None
+
+     if return_intermediate_results:
+         results_batch = {idx: [] for idx in range(inputs.shape[0])}
+         results_latent = {idx: [] for idx in range(inputs.shape[0])}
+         results_deltas = {idx: [] for idx in range(inputs.shape[0])}
+     else:
+         results_batch, results_latent, results_deltas = None, None, None
+
+     if weights_deltas is None:
+         for iter in range(n_iters_per_batch):
+             y_hat, latent, weights_deltas, codes, _ = net.forward(inputs,
+                                                                   y_hat=y_hat,
+                                                                   codes=codes,
+                                                                   weights_deltas=weights_deltas,
+                                                                   return_latents=True,
+                                                                   resize=resize_outputs,
+                                                                   randomize_noise=False,
+                                                                   return_weight_deltas_and_codes=True)
+             # weights_deltas[14] = None
+             # weights_deltas[20] = None
+             # weights_deltas[21] = None
+             # weights_deltas[23] = None
+             # weights_deltas[24] = None
+
+             if return_intermediate_results:
+                 store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
+
+             # resize input to 256 before feeding into next iteration
+             y_hat = net.face_pool(y_hat)
+
+     else:
+         for iter in range(n_iters_per_batch):
+             y_hat, latent, _, codes, _ = net.forward(inputs,
+                                                      y_hat=y_hat,
+                                                      codes=codes,
+                                                      weights_deltas=weights_deltas,
+                                                      return_latents=True,
+                                                      resize=resize_outputs,
+                                                      randomize_noise=False,
+                                                      return_weight_deltas_and_codes=True)
+
+             if return_intermediate_results:
+                 store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
+
+             # resize input to 256 before feeding into next iteration
+             y_hat = net.face_pool(y_hat)
+
+     if return_intermediate_results:
+         return results_batch, results_latent, results_deltas
+     return y_hat, latent, weights_deltas, codes
+
+
+ def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
+     for idx in range(y_hat.shape[0]):
+         results_batch[idx].append(y_hat[idx])
+         results_latent[idx].append(latent[idx].cpu().numpy())
+         results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None for w in weights_deltas])
+
+
+ # compute M given a style code.
+ @torch.no_grad()
+ def compute_M(w, weights_deltas=None, device='cuda'):
+     M = []
+
+     # get segmentation
+     # _, outputs = generator(w, is_cluster=1)
+     _, outputs = generator(w, weights_deltas=weights_deltas)
+     cluster_layer = outputs[stop_idx][0]
+     activation = flatten_act(cluster_layer)
+     seg_mask = clusterer.predict(activation)
+     b, c, h, w = cluster_layer.size()
+
+     # create masks for each feature
+     all_seg_mask = []
+     seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w, 1).to(device)
+
+     for key in range(n_class):
+         # combine masks for all indices for a particular segmentation class
+         indices = labels_map[key].view(1, 1, 1, 1, -1)
+         key_mask = (seg_mask == indices.to(device)).any(-1)  # [b,1,h,w]
+         all_seg_mask.append(key_mask)
+
+     all_seg_mask = torch.stack(all_seg_mask, 1)
+
+     # go through each activation layer and compute M
+     for layer_idx in range(len(outputs)):
+         layer = outputs[layer_idx][1].to(device)
+         b, c, h, w = layer.size()
+         layer = F.instance_norm(layer)
+         layer = layer.pow(2)
+
+         # resize the segmentation masks to current activations' resolution
+         layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), align_corners=False,
+                                        size=(h, w), mode='bilinear').view(b, -1, 1, h, w)
+
+         masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
+         masked_layer = masked_layer.sum([3, 4]) / (h * w)  # [b,k,c]
+
+         M.append(masked_layer.to(device))
+
+     M = torch.cat(M, -1)  # [b, k, c]
+
+     # softmax to assign each channel to a particular segmentation class
+     M = F.softmax(M / .1, 1)
+     # simple thresholding
+     M = (M > .8).float()
+
+     # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
+     for i in range(n_class):
+         part_M = style2list(M[:, i])
+         for j in range(len(part_M)):
+             if j in rgb_layer_idx:
+                 part_M[j].zero_()
+         part_M = list2style(part_M)
+         M[:, i] = part_M
+
+     return M
+
+
+ # ====
+ # for i in range(len(blend_deltas)):
+ #     if blend_deltas[i] is not None:
+ #         print(f'{i}: {part_M_mask[i].sum()}/{sum(part_M_mask[i].shape)}')
+ #         if part_M_mask[i].sum() >= sum(part_M_mask[i].shape)/2:
+ #             print(i)
+ #             blend_deltas[i] = ref_deltas[i]
+
+
+ def tensor2img(tensor):
+     tensor = tensor.cpu().clamp(-1, 1)
+     img = topil(tensor.squeeze())
+     return img
+
+
+ def hair_transfer_hyperstyle(source_img_path, ref_img_path):
+     """Blend the reference hairstyle into the source image, conditioning the generator on HyperStyle weight deltas."""
+     with torch.no_grad():
+         source_img = align_face(source_img_path, predictor=predictor)
+         ref_img = align_face(ref_img_path, predictor=predictor)
+         source_img = Image.fromarray(np.uint8(source_img))
+         ref_img = Image.fromarray(np.uint8(ref_img))
+
+         source_tensor = transform(source_img).unsqueeze(0).to(device)
+         ref_tensor = transform(ref_img).unsqueeze(0).to(device)
+
+         source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
+         ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
+
+         source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
+         ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
+
+         source_out, _ = generator(source, weights_deltas=source_deltas, randomize_noise=False)
+         ref_out, _ = generator(ref, weights_deltas=ref_deltas, randomize_noise=False)
+
+         source_M = compute_M(source, weights_deltas=source_deltas, device='cpu')
+         ref_M = compute_M(ref, weights_deltas=ref_deltas, device='cpu')
+
+         blend_deltas = source_deltas
+
+         max_M = torch.max(source_M.expand_as(ref_M), ref_M)
+         max_M = add_pose(max_M, labels2idx)
+         idx = labels2idx['hair']
+         part_M = max_M[:, idx].to(device)
+         part_M_mask = style2list(part_M)
+         blend = style2list(add_direction(source, ref, part_M, 1.3))
+         blend_out, _ = generator(blend, weights_deltas=blend_deltas)
+
+         source_out = tensor2img(source_out)
+         ref_out = tensor2img(ref_out)
+         blend_out = tensor2img(blend_out)
+
+         # face-region metrics against the source, hair-region metrics against the reference
+         lpips_face, _ = lpips(blend_out, source_out)
+         ssim_face, _ = ssim(blend_out, source_out)
+         id_face, _ = id_score(blend_out, source_out)
+         _, lpips_hair = lpips(blend_out, ref_out)
+         _, ssim_hair = ssim(blend_out, ref_out)
+         _, clip_hair = clip(blend_out, source_out)
+         out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
+
+         # paste the source face back onto the blended result with seamless cloning
+         e4e_blend_out, _ = generator(blend)
+         e4e_blend_out = tensor2img(e4e_blend_out)
+         _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
+         source_out_np = np.array(source_out)
+         blend_np = np.array(e4e_blend_out).astype(np.uint8)
+
+         e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8) * 255
+         mask_dilate = cv2.dilate(e4e_blend_hair_mask,
+                                  kernel=np.ones((50, 50), np.uint8))
+         mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
+         mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
+         face_mask = 255 - mask_dilate_blur
+
+         index = np.where(face_mask > 0)
+         cy = (np.min(index[0]) + np.max(index[0])) // 2
+         cx = (np.min(index[1]) + np.max(index[1])) // 2
+         center = (cx, cy)
+
+         clone_out = cv2.seamlessClone(source_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
+
+     return source_out, ref_out, blend_out, out_str, clone_out
+
+
+ def hair_transfer_e4e(source_img_path, ref_img_path):
+     """Same transfer, but decodes from the inverted latents alone, without weight deltas."""
+     with torch.no_grad():
+         source_img = align_face(source_img_path, predictor=predictor)
+         ref_img = align_face(ref_img_path, predictor=predictor)
+         source_img = Image.fromarray(np.uint8(source_img))
+         ref_img = Image.fromarray(np.uint8(ref_img))
+
+         source_tensor = transform(source_img).unsqueeze(0).to(device)
+         ref_tensor = transform(ref_img).unsqueeze(0).to(device)
+
+         source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
+         ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
+
+         source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
+         ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
+
+         e4e_source_out, _ = generator(source, randomize_noise=False)
+         e4e_ref_out, _ = generator(ref, randomize_noise=False)
+
+         e4e_source_M = compute_M(source, device='cpu')
+         e4e_ref_M = compute_M(ref, device='cpu')
+
+         e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
+         e4e_max_M = add_pose(e4e_max_M, labels2idx)
+         e4e_idx = labels2idx['hair']
+
+         e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
+         e4e_part_M_mask = style2list(e4e_part_M)
+
+         e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
+         e4e_blend_out, _ = generator(e4e_blend)
+
+         e4e_source_out = tensor2img(e4e_source_out)
+         e4e_ref_out = tensor2img(e4e_ref_out)
+         e4e_blend_out = tensor2img(e4e_blend_out)
+
+         e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
+         e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
+         e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
+         _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
+         _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
+         _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
+
+         e4e_out_str = f'e4e_lpips_face: {e4e_lpips_face}\ne4e_lpips_hair: {e4e_lpips_hair}\ne4e_ssim_face: {e4e_ssim_face}\ne4e_ssim_hair: {e4e_ssim_hair}\ne4e_id_face: {e4e_id_face}\ne4e_clip_hair: {e4e_clip_hair}'
+
+     return e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str
+
+
+ def hair_transfer_PTI(source_img_path, ref_img_path):
+     """Transfer variant that synthesizes the blended styles through the Manipulator / Generator_wrapper (PTI) pipeline."""
+     ckpt = 'pretrained/ffhq.pkl'
+     G = Generator_wrapper(ckpt, device)
+     manipulator = Manipulator(G, device)
+     manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
+
+     with torch.no_grad():
+         source_img = align_face(source_img_path, predictor=predictor)
+         ref_img = align_face(ref_img_path, predictor=predictor)
+         source_img = Image.fromarray(np.uint8(source_img))
+
+         projection(source_img, 'source', generator, device)
+         projection(ref_img, 'ref', generator, device)
+         source = load_source('source', generator, device)
+         ref = load_source('ref', generator, device)
+
+         e4e_source_out, _ = generator(source, randomize_noise=False)
+         e4e_ref_out, _ = generator(ref, randomize_noise=False)
+
+         e4e_source_M = compute_M(source, device='cpu')
+         e4e_ref_M = compute_M(ref, device='cpu')
+
+         e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
+         e4e_max_M = add_pose(e4e_max_M, labels2idx)
+         e4e_idx = labels2idx['hair']
+
+         e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
+         e4e_part_M_mask = style2list(e4e_part_M)
+
+         e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
+
+         e4e_source_out = tensor2img(e4e_source_out)
+         e4e_ref_out = tensor2img(e4e_ref_out)
+         # e4e_blend_out = tensor2img(e4e_blend_out)
+
+         # e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
+         # e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
+         # e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
+         # _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
+         # _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
+         # _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
+
+         # synthesize the blended style code with the manipulator's generator
+         keys = ['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine', 'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine', 'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine', 'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine', 'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine', 'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine', 'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine', 'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine', 'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine']
+         test_dict = dict(zip(keys, e4e_blend))
+         manipulator_list = []
+         manipulator_list.append(test_dict)
+         all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
+         PTI_outstr = 'PTI_outstr'
+         blend_out = tensor2img(all_imgs[0])
+     return e4e_source_out, e4e_ref_out, blend_out, PTI_outstr
+
+
+ # _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
+ # blend_out_np = np.array(blend_out)
+ # blend_np = np.array(e4e_blend_out).astype(np.uint8)
+
+ # e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8)*255
+ # mask_dilate = cv2.dilate(e4e_blend_hair_mask,
+ #                          kernel=np.ones((50, 50), np.uint8))
+ # mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
+ # mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
+ # face_mask = 255 - mask_dilate_blur
+
+ # index = np.where(face_mask > 0)
+ # cy = (np.min(index[0]) + np.max(index[0])) // 2
+ # cx = (np.min(index[1]) + np.max(index[1])) // 2
+ # center = (cx, cy)
+
+ # clone_out = cv2.seamlessClone(blend_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
+
+ # out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
+
+ # seg_out = torch.tensor(face_mask).float().unsqueeze(-1).repeat(1,1,3)
+ # seg_out = seg_out.cpu().numpy().astype(np.uint8)
+ # # seg_out*=255
+
+ # seg_out = Image.fromarray(seg_out)
+
+ # # return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, seg_out
+
+ # ## Set source_tensor requires_grad=True
+ # source_tensor.requires_grad = True
+ # ref_tensor.requires_grad = True
+
+ # ckpt = 'pretrained/ffhq.pkl'
+ # G = Generator_wrapper(ckpt, device)
+ # manipulator = Manipulator(G, device)
+ # manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
+ # blend = style2list((add_direction(source_tensor, ref_tensor, part_M, 1.3)))
+ # keys = (['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine', 'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine', 'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine', 'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine', 'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine', 'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine', 'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine', 'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine', 'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine'])
+ # test_dict = dict(zip(keys, blend))
+ # manipulator_list = []
+ # manipulator_list.append(test_dict)
+ # all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
+
+ # return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, all_imgs
+
+
+ ## argument for choosing encoder between e4e and hyperstyle
+ args = argparse.ArgumentParser()
+ args.add_argument('--encoder', type=str, default='hyperstyle')
+
+ opt = args.parse_args()
+
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # face / hair similarity metrics
+ lpips = FaceMetric(metric_type='lpips', device=device)
+ ssim = FaceMetric(metric_type='ms-ssim', device=device)
+ id_score = FaceMetric(metric_type='id', device=device)
+ clip = FaceMetric(metric_type='cliphair', device=device)
+
+ # generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
+ # ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
+ # generator.load_state_dict(ckpt['g_ema'], strict=False)
+
+ generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
+ ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
+ generator.load_state_dict(ckpt['g_ema'], strict=False)
+
+ ckpt = 'pretrained/ffhq.pkl'
+ G = Generator_wrapper(ckpt, device)
+ manipulator = Manipulator(G, device)
+
+ # choose the inversion encoder
+ if opt.encoder == 'e4e':
+     from util import align_face
+     model_path = 'e4e_ffhq_encode.pt'
+     ensure_checkpoint_exists(model_path)
+     ckpt = torch.load(model_path, map_location='cpu')
+     opts = ckpt['opts']
+     opts['checkpoint_path'] = model_path
+     opts = Namespace(**opts)
+     net = pSp(opts, device).eval().to(device)
+
+ elif opt.encoder == 'hyperstyle':
+     from hyperstyle.scripts.align_faces_parallel import align_face
+     model_path = 'pretrained_models/hyperstyle_ffhq.pt'
+     predictor = dlib.shape_predictor('pretrained_models/shape_predictor_68_face_landmarks.dat')
+     net, _ = load_model(model_path)
+
+ else:
+     raise ValueError('invalid encoder')
+
+
+ truncation = 0.5
+ stop_idx = 11
+ n_clusters = 18
+
+ # clusterer over generator activations, used as a coarse segmentation
+ clusterer = pickle.load(open('catalog.pkl', 'rb'))
+
+ labels2idx = {
+     'nose': 0,
+     'eyes': 1,
+     'mouth': 2,
+     'hair': 3,
+     'background': 4,
+     'cheek': 5,
+     'neck': 6,
+     'clothes': 7,
+ }
+
+ # cluster indices belonging to each semantic class
+ labels_map = {
+     0: torch.tensor([7]),
+     1: torch.tensor([1, 6]),
+     2: torch.tensor([4]),
+     3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
+     4: torch.tensor([11, 13, 14]),
+     5: torch.tensor([9]),
+     6: torch.tensor([17]),
+     7: torch.tensor([2, 12]),
+ }
+
+ lables2idx = dict((v, k) for k, v in labels2idx.items())
+ n_class = len(lables2idx)
+
+ segid_map = dict.fromkeys(labels_map[0].tolist(), 0)
+ segid_map.update(dict.fromkeys(labels_map[1].tolist(), 1))
+ segid_map.update(dict.fromkeys(labels_map[2].tolist(), 2))
+ segid_map.update(dict.fromkeys(labels_map[3].tolist(), 3))
+ segid_map.update(dict.fromkeys(labels_map[4].tolist(), 4))
+ segid_map.update(dict.fromkeys(labels_map[5].tolist(), 5))
+ segid_map.update(dict.fromkeys(labels_map[6].tolist(), 6))
+ segid_map.update(dict.fromkeys(labels_map[7].tolist(), 7))
+
+ torch.manual_seed(0)
+
+ transform = transforms.Compose(
+     [
+         transforms.Resize(256),
+         transforms.CenterCrop(256),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+     ]
+ )
+
+ topil = transforms.Compose(
+     [
+         transforms.Normalize([-1, -1, -1], [2, 2, 2]),
+         transforms.ToPILImage(),
+         transforms.Resize(1024),
+     ]
+ )
+
+ # one Gradio tab per inversion backend
+ e4e_ris_demo = gr.Interface(hair_transfer_e4e, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
+ hyperstyle_ris_demo = gr.Interface(hair_transfer_hyperstyle, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text", "image"])
+ PTI_ris_demo = gr.Interface(hair_transfer_PTI, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=["image", "image", "image", "text"])
+
+ ris_demo = gr.TabbedInterface(interface_list=[hyperstyle_ris_demo, e4e_ris_demo, PTI_ris_demo], tab_names=["hyperstyle", "e4e", "PTI"])
+
+ ris_demo.launch(share=True)
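For reference, the three transfer functions can also be called without the Gradio UI once the module-level setup above has run (checkpoints downloaded, encoder selected via --encoder). A rough sketch, assuming 'source.jpg' and 'ref.jpg' are placeholder image paths:

source_out, ref_out, blend_out, metrics, clone_out = hair_transfer_hyperstyle('source.jpg', 'ref.jpg')
blend_out.save('blend.png')   # blend_out is a PIL image of the hair-transferred result
print(metrics)                # lpips / ssim / id / clip scores for the face and hair regions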
gradio_wrapper/gradio_options.py ADDED
@@ -0,0 +1,53 @@
+ import sys
+ import os
+ sys.path.append(".")
+ sys.path.append("..")
+ from argparse import ArgumentParser
+
+
+ class GradioTestOptions:
+
+     def __init__(self):
+         self.parser = ArgumentParser()
+         self.initialize()
+
+     def initialize(self):
+         # arguments for inference script
+         self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint')
+
+         self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
+         self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
+         self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
+         self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
+         self.parser.add_argument('--use_weight_delta_mapper', default=False, action="store_true")
+         self.parser.add_argument('--stylegan_size', default=1024, type=int)
+
+         # arguments for global-direction editing
+         self.parser.add_argument('--alpha', default=4.1, type=float, help='Alpha to use for the global-direction edit')
+         self.parser.add_argument('--beta', default=0.14, type=float, help='Beta to use for the global-direction edit')
+         self.parser.add_argument('--edit_weight_delta', default=False, action='store_true', help='Also edit the weight deltas')
+         self.parser.add_argument('--weight_delta_alpha', default=4.1, type=float, help='Alpha to use for weight delta')
+         self.parser.add_argument('--weight_delta_beta', default=0.14, type=float, help='Beta to use for weight delta')
+         self.parser.add_argument("--delta_i_c", type=str, default='./hyperstyle_global_directions/global_directions/ffhq/fs3.npy', help="path to file containing delta_i_c")
+         self.parser.add_argument("--s_statistics", type=str, default='./hyperstyle_global_directions/global_directions/ffhq/S_mean_std', help="path to file containing s statistics")
+         self.parser.add_argument("--text_prompt_templates", default='./hyperstyle_global_directions/global_directions/templates.txt')
+
+         self.parser.add_argument("--neutral_text", type=str, default="A face with hair")
+         self.parser.add_argument("--target_text", type=str, default=None)
+
+         # arguments for HyperStyle
+         self.parser.add_argument('--hyperstyle_checkpoint_path', default='./pretrained_models/hyperstyle/hyperstyle_ffhq.pt', type=str, help='Path to HyperStyle model checkpoint')
+         self.parser.add_argument('--resize_outputs', action='store_true', help='Whether to resize outputs to 256x256 or keep at original output resolution')
+
+         # arguments for loading pre-trained encoder
+         self.parser.add_argument('--load_w_encoder', action='store_true', help='Whether to load the w e4e encoder.')
+         self.parser.add_argument('--w_encoder_checkpoint_path', default='./pretrained_models/hyperstyle/faces_w_encoder.pt', type=str, help='Path to pre-trained W-encoder.')
+         self.parser.add_argument('--w_encoder_type', default='WEncoder', help='Encoder type for the encoder used to get the initial inversion')
+
+         # arguments for iterative inference
+         self.parser.add_argument('--n_iters_per_batch', default=5, type=int, help='Number of forward passes per batch during inference.')
+
+         # arguments for the test dataset
+         self.parser.add_argument('--work_in_stylespace', default=False, action='store_true')
+
+     def parse(self, args=None):
+         opts = self.parser.parse_args(args)
+         return opts
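For context, app.py in this commit consumes these options as shown in the sketch below; the printed values are the defaults defined above.

from gradio_wrapper.gradio_options import GradioTestOptions

opts = GradioTestOptions().parse(['--no_fine_mapper'])
print(opts.no_fine_mapper)              # True
print(opts.hyperstyle_checkpoint_path)  # ./pretrained_models/hyperstyle/hyperstyle_ffhq.pt
print(opts.alpha, opts.beta)            # 4.1 0.14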