import sys

# Local modules are imported relative to the repository root.
sys.path.append(".")

import os
from argparse import Namespace

import clip
import dlib
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from gradio_wrapper.gradio_options import GradioTestOptions
from models.hyperstyle.utils.model_utils import load_model
from models.hyperstyle.utils.common import tensor2im
from models.hyperstyle.utils.inference_utils import run_inversion
from hyperstyle_global_directions.edit import load_direction_calculator, edit_image
from utils.alignment import align_face
from mapper.styleclip_mapper import StyleCLIPMapper
import ris.spherical_kmeans as spherical_kmeans
from ris.blend import blend_latents
from ris.model import Generator as RIS_Generator
from models.pti.manipulator import Manipulator
from models.pti.wrapper import Generator_wrapper
from metrics import FaceMetric
from metrics.metrics import ClipHair
from metrics.criteria.clip_loss import CLIPLoss

opts_args = ['--no_fine_mapper']
opts = GradioTestOptions().parse(opts_args)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
opts.device = device

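# Checkpoint paths for the pretrained StyleCLIP hairstyle mappers.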
mapper_dict = {
    'afro':'./pretrained_models/styleCLIP_mappers/afro_hairstyle.pt',
    'bob':'./pretrained_models/styleCLIP_mappers/bob_hairstyle.pt',
    'bowl':'./pretrained_models/styleCLIP_mappers/bowl_hairstyle.pt',
    'buzz':'./pretrained_models/styleCLIP_mappers/buzz_hairstyle.pt',
    'caesar':'./pretrained_models/styleCLIP_mappers/caesar_hairstyle.pt',
    'crew':'./pretrained_models/styleCLIP_mappers/crew_hairstyle.pt',
    'pixie':'./pretrained_models/styleCLIP_mappers/pixie_hairstyle.pt',
    'straight':'./pretrained_models/styleCLIP_mappers/straight_hairstyle.pt',
    'undercut':'./pretrained_models/styleCLIP_mappers/undercut_hairstyle.pt',
    'wavy':'./pretrained_models/styleCLIP_mappers/wavy_hairstyle.pt'
}

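# Human-readable description of each hairstyle mapper.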
mapper_descs = {
    'afro':'A face with an afro',
    'bob':'A face with a bob-cut hairstyle',
    'bowl':'A face with a bowl cut hairstyle',
    'buzz':'A face with a buzz cut hairstyle',
    'caesar':'A face with a caesar cut hairstyle',
    'crew':'A face with a crew cut hairstyle',
    'pixie':'A face with a pixie cut hairstyle',
    'straight':'A face with a straight hair hairstyle',
    'undercut':'A face with a undercut hairstyle',
    'wavy':'A face with a wavy hair hairstyle',
}


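# dlib landmark model for face alignment, the Hyperstyle inverter, and the
# StyleCLIP global-direction editor.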
predictor = dlib.shape_predictor("./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.lfs.dat")
hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, device=device, update_opts=opts)
resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
im2tensor_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
direction_calculator = load_direction_calculator(opts)

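# StyleGAN2 generator used by RIS to blend hair from a reference image.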
ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)

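# Metrics used to score each edit against its inversion.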
lpips_metric = FaceMetric(metric_type='lpips', device=device)
ssim_metric = FaceMetric(metric_type='ms-ssim', device=device)
id_metric = FaceMetric(metric_type='id', device=device)
clip_hair = FaceMetric(metric_type='cliphair', device=device)
clip_text = CLIPLoss(hyperstyle_args)

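# PTI: StyleGAN wrapper and latent-space manipulator.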
G = Generator_wrapper('./pretrained_models/pti/ffhq.pkl', device)
manipulator = Manipulator(G, device)


def map_latent(mapper, inputs, stylespace=False, weight_deltas=None, strength=0.1):
    """Apply a StyleCLIP mapper edit of the given strength to a latent code."""
    w = inputs.to(device)
    with torch.no_grad():
        if stylespace:
            # Stylespace editing expects `inputs` to be a list of per-layer style codes.
            delta = mapper.mapper(w)
            w_hat = [c + strength * delta_c for (c, delta_c) in zip(w, delta)]
            x_hat, _, w_hat = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                             randomize_noise=False, truncation=1,
                                             input_is_stylespace=True, weights_deltas=weight_deltas)
        else:
            delta = mapper.mapper(w)
            w_hat = w + strength * delta
            x_hat, w_hat, _ = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                             randomize_noise=False, truncation=1,
                                             weights_deltas=weight_deltas)
    return x_hat, w_hat


def run_metrics(base_img, edited_img):
    """Score an edited image against the inversion it started from."""
    lpips_score = lpips_metric(base_img, edited_img)[0]
    ssim_score = ssim_metric(base_img, edited_img)[0]
    id_score = id_metric(base_img, edited_img)[0]
    return lpips_score, ssim_score, id_score


def clip_text_metric(tensor, text):
    """CLIP similarity between an image tensor and a text prompt (higher is closer)."""
    clip_tokens = clip.tokenize(text).to(device)
    clip_score = 1 - clip_text(tensor.unsqueeze(0), clip_tokens).item()
    return clip_score


def submit(
    src, align_img, inverter_bools, n_iterations, invert_bool,
    mapper_bool, mapper_choice, mapper_alpha,
    gd_bool, neutral_text, target_text, alpha, beta,
    ris_bool, ref_img,
):
    """Invert `src` with each selected inverter, apply the selected edits, and
    return a flat list of output images and metrics text, grouped by inverter."""
    if device == 'cuda':
        torch.cuda.empty_cache()
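    # Build the StyleCLIP mapper for the chosen hairstyle from its checkpoint's saved options.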
    opts.checkpoint_path = mapper_dict[mapper_choice]
    ckpt = torch.load(mapper_dict[mapper_choice], map_location='cpu')
    mapper_args = ckpt['opts']
    mapper_args.update(vars(opts))
    mapper_args = Namespace(**mapper_args)
    mapper = StyleCLIPMapper(mapper_args)
    mapper.eval()
    mapper.to(device)
    resize_to = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
    with torch.no_grad():
        output_imgs = []
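        # Optionally align the face with the dlib landmark model, then convert to a normalized tensor.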
        if align_img:
            input_img = align_face(src, predictor)
        else:
            input_img = Image.open(src).convert('RGB')
        input_img = im2tensor_transforms(input_img).to(device)

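        # Copy the global-direction controls from the UI into the shared options.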
        if gd_bool:
            opts.neutral_text = neutral_text
            opts.target_text = target_text
            opts.alpha = alpha
            opts.beta = beta
        
        if ris_bool:
            # Reference image that supplies the target hairstyle for RIS blending.
            if align_img:
                ref_input = align_face(ref_img, predictor)
            else:
                ref_input = Image.open(ref_img).convert('RGB')
            ref_input = im2tensor_transforms(ref_input).to(device)
        hyperstyle_metrics_text = ''
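        # Hyperstyle: invert the input, then apply the requested mapper / global-direction / RIS edits.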
        if 'Hyperstyle' in inverter_bools:
            hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
            invert_hyperstyle = tensor2im(hyperstyle_batch[0])
            if mapper_bool:
                mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
                clip_score = clip_text_metric(mapped_hyperstyle[0], mapper_args.description)
                mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
                lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), mapped_hyperstyle.resize(resize_to))
                hyperstyle_metrics_text += f'\nMapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
            else:
                mapped_hyperstyle = None
            
            if gd_bool:
                # NOTE: edit_image's first argument is assumed unused here; pass None explicitly.
                gd_hyperstyle = edit_image(None, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
                clip_score = clip_text_metric(gd_hyperstyle[0], opts.target_text)
                gd_hyperstyle = tensor2im(gd_hyperstyle[0])
                lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), gd_hyperstyle.resize(resize_to))
                hyperstyle_metrics_text += f'\nGlobal Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
            else:
                gd_hyperstyle = None
            
            if ris_bool:
                # Invert the reference image and blend its hairstyle into the source latents.
                ref_hyperstyle_batch, ref_hyperstyle_latents, ref_hyperstyle_deltas, _ = run_inversion(ref_input.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
                blend_hyperstyle, blend_hyperstyle_latents = blend_latents(hyperstyle_latents, ref_hyperstyle_latents,
                                                src_deltas=hyperstyle_deltas, ref_deltas=ref_hyperstyle_deltas,
                                                generator=ris_gen, device=device)
                ris_hyperstyle = tensor2im(blend_hyperstyle[0])

                lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), ris_hyperstyle.resize(resize_to))
                clip_score = clip_hair(invert_hyperstyle.resize(resize_to), ris_hyperstyle.resize(resize_to))[1]
                hyperstyle_metrics_text += f'\nRIS Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Hair Score: \t{clip_score}'
            else:
                ris_hyperstyle = None

            hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle, gd_hyperstyle, ris_hyperstyle, hyperstyle_metrics_text]
        else:
            hyperstyle_output = [None, None, None, None, hyperstyle_metrics_text]
        output_imgs.extend(hyperstyle_output)
        e4e_metrics_text = ''
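        # E4E: repeat the same edits on a plain W+ inversion (no hypernetwork weight deltas).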
        if 'E4E' in inverter_bools:
            e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
            e4e_deltas = None
            invert_e4e = tensor2im(e4e_batch[0])
            if mapper_bool:
                mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
                clip_score = clip_text_metric(mapped_e4e[0], mapper_args.description)
                mapped_e4e = tensor2im(mapped_e4e[0])
                lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), mapped_e4e.resize(resize_to))
                e4e_metrics_text += f'\nMapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
            else:
                mapped_e4e = None
            
            if gd_bool:
                gd_e4e = edit_image(None, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
                clip_score = clip_text_metric(gd_e4e[0], opts.target_text)
                gd_e4e = tensor2im(gd_e4e[0])
                lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), gd_e4e.resize(resize_to))
                e4e_metrics_text += f'\nGlobal Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
            else:
                gd_e4e = None
            
            if ris_bool:
                ref_e4e_batch, ref_e4e_latents = hyperstyle.w_invert(ref_input.unsqueeze(0))
                ref_e4e_deltas = None  # W+ inversion carries no hypernetwork weight deltas.
                blend_e4e, blend_e4e_latents = blend_latents(e4e_latents, ref_e4e_latents,
                                                src_deltas=e4e_deltas, ref_deltas=ref_e4e_deltas,
                                                generator=ris_gen, device=device)
                ris_e4e = tensor2im(blend_e4e[0])

                lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), ris_e4e.resize(resize_to))
                clip_score = clip_hair(invert_e4e.resize(resize_to), ris_e4e.resize(resize_to))[1]
                e4e_metrics_text += f'\nRIS Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Hair Score: \t{clip_score}'
            else:
                ris_e4e = None
            
            e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e, e4e_metrics_text]
        else:
            e4e_output = [None, None, None, None, e4e_metrics_text]
        output_imgs.extend(e4e_output)
        # PTI projects the source image, but its edited outputs are not wired into the UI yet.
        pti_output = None, None, None, None
        if 'PTI' in inverter_bools:
            manipulator.set_real_img_projection(src, inv_mode='w+', pti_mode='s')
        output_imgs.extend(pti_output)
    return output_imgs