import sys
import os

sys.path.append(".")

import torch
import clip
import dlib
import gradio as gr
from argparse import Namespace
from PIL import Image
from torchvision import transforms

from gradio_wrapper.gradio_options import GradioTestOptions
from models.hyperstyle.utils.model_utils import load_model
from models.hyperstyle.utils.common import tensor2im
from models.hyperstyle.utils.inference_utils import run_inversion
from hyperstyle_global_directions.edit import load_direction_calculator, edit_image
from utils.alignment import align_face
from mapper.styleclip_mapper import StyleCLIPMapper
import ris.spherical_kmeans as spherical_kmeans
from ris.blend import blend_latents
from ris.model import Generator as RIS_Generator
from metrics import FaceMetric
from metrics.metrics import ClipHair
from metrics.criteria.clip_loss import CLIPLoss
# from models.pti.manipulator import Manipulator
# from models.pti.wrapper import Generator as Generator_wrapper
# from models.pti.e4e_projection import projection

opts_args = ['--no_fine_mapper']
opts = GradioTestOptions().parse(opts_args)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
opts.device = device

# Pretrained StyleCLIP latent mappers, one checkpoint per hairstyle.
mapper_dict = {
    'afro': './pretrained_models/styleCLIP_mappers/afro_hairstyle.pt',
    'bob': './pretrained_models/styleCLIP_mappers/bob_hairstyle.pt',
    'bowl': './pretrained_models/styleCLIP_mappers/bowl_hairstyle.pt',
    'buzz': './pretrained_models/styleCLIP_mappers/buzz_hairstyle.pt',
    'caesar': './pretrained_models/styleCLIP_mappers/caesar_hairstyle.pt',
    'crew': './pretrained_models/styleCLIP_mappers/crew_hairstyle.pt',
    'pixie': './pretrained_models/styleCLIP_mappers/pixie_hairstyle.pt',
    'straight': './pretrained_models/styleCLIP_mappers/straight_hairstyle.pt',
    'undercut': './pretrained_models/styleCLIP_mappers/undercut_hairstyle.pt',
    'wavy': './pretrained_models/styleCLIP_mappers/wavy_hairstyle.pt',
}

# Default target-text prompts for the global-direction editor, keyed by mapper.
mapper_descs = {
    'afro': 'A face with an afro',
    'bob': 'A face with a bob-cut hairstyle',
    'bowl': 'A face with a bowl cut hairstyle',
    'buzz': 'A face with a buzz cut hairstyle',
    'caesar': 'A face with a caesar cut hairstyle',
    'crew': 'A face with a crew cut hairstyle',
    'pixie': 'A face with a pixie cut hairstyle',
    'straight': 'A face with straight hair',
    'undercut': 'A face with an undercut hairstyle',
    'wavy': 'A face with wavy hair',
}

# Face-landmark predictor for alignment, plus the HyperStyle inverter.
predictor = dlib.shape_predictor("./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.lfs.dat")
hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, device=device, update_opts=opts)
resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)

im2tensor_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

direction_calculator = load_direction_calculator(opts)

# StyleGAN2 generator used by RIS (Retrieve in Style) for hair blending.
ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)

# Metrics reported alongside each edit.
lpips_metric = FaceMetric(metric_type='lpips', device=device)
ssim_metric = FaceMetric(metric_type='ms-ssim', device=device)
id_metric = FaceMetric(metric_type='id', device=device)
clip_hair = FaceMetric(metric_type='cliphair', device=device)
clip_text = CLIPLoss(hyperstyle_args)

# G = Generator_wrapper('./pretrained_models/pti/ffhq.pkl', device)
# manipulator = Manipulator(G, device)

with gr.Blocks() as demo:
    with gr.Row() as row:
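        # Input column: source image, inverter choice, and one option box per edit type.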
        with gr.Column() as inputs:
            source = gr.Image(label="Image to Map", type='filepath')
            align = gr.Checkbox(True, label='Align Image')
            inverter_bools = gr.CheckboxGroup(["Hyperstyle", "E4E"], value=['Hyperstyle'], label='Inverter Choices')
            n_hyperstyle_iterations = gr.Number(5, label='Number of Iterations For Hyperstyle', precision=0)
            with gr.Box():
                invert_bool = gr.Checkbox(False, label='Output Inverter Result')
            with gr.Box():
                mapper_bool = gr.Checkbox(True, label='Output Mapper Result')
                with gr.Box() as mapper_opts:
                    mapper_choice = gr.Dropdown(list(mapper_dict.keys()), value='afro', label='What Hairstyle Mapper to Use?')
                    mapper_alpha = gr.Slider(minimum=-0.5, maximum=0.5, value=0.1, step=0.01, label='Strength of Mapper Alpha')
            with gr.Box():
                gd_bool = gr.Checkbox(False, label='Output Global Direction Result')
                with gr.Box(visible=False) as gd_opts:
                    neutral_text = gr.Text(value='A face with hair', label='Neutral Text')
                    target_text = gr.Text(value=mapper_descs['afro'], label='Target Text')
                    alpha = gr.Slider(minimum=-10.0, maximum=10.0, value=4.1, step=0.1, label="Alpha for Global Direction")
                    beta = gr.Slider(minimum=0.0, maximum=0.30, value=0.15, step=0.01, label="Beta for Global Direction")
            with gr.Box():
                ris_bool = gr.Checkbox(False, label='Output RIS Result')
                with gr.Box(visible=False) as ris_opts:
                    ref_img = gr.Image(label='Reference Image for Hair', type='filepath')
            submit_button = gr.Button("Edit Image")
        with gr.Column() as outputs:
            with gr.Row() as hyperstyle_images:
                output_hyperstyle_invert = gr.Image(type='pil', label="Hyperstyle Inverted", visible=False)
                output_hyperstyle_mapper = gr.Image(type='pil', label="Hyperstyle Mapper")
                output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
                output_hyperstyle_ris = gr.Image(type='pil', label='Hyperstyle RIS', visible=False)
            with gr.Row() as hyperstyle_metrics:
                output_hyperstyle_metrics = gr.Text(label='Hyperstyle Metrics')
            with gr.Row(visible=False) as e4e_images:
                output_e4e_invert = gr.Image(type='pil', label="E4E Inverted", visible=False)
                output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
                output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)
                output_e4e_ris = gr.Image(type='pil', label='E4E RIS', visible=False)
            with gr.Row(visible=False) as e4e_metrics:
                output_e4e_metrics = gr.Text(label='E4E Metrics')
            with gr.Row(visible=False) as pti_images:
                output_pti_invert = gr.Image(type='pil', label="PTI Inverted", visible=False)
                output_pti_mapper = gr.Image(type='pil', label="PTI Mapper")
                output_pti_gd = gr.Image(type='pil', label="PTI Global Directions", visible=False)
                output_pti_ris = gr.Image(type='pil', label='PTI RIS', visible=False)
            with gr.Row(visible=False) as pti_metrics:
                output_pti_metrics = gr.Text(label='PTI Metrics')
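    # UI callbacks: keep the option boxes and output widgets in sync with the
    # checkboxes, so each of the mapper, global-direction, and RIS outputs is
    # only shown when its edit is enabled.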
    def n_iter_change(number):
        # Clamp the HyperStyle iteration count to be non-negative.
        return max(number, 0)

    def mapper_change(new_mapper):
        # Keep the global-direction target text in sync with the selected mapper.
        return mapper_descs[new_mapper]

    def inverter_toggles(bools):
        e4e_bool = 'E4E' in bools
        hyperstyle_bool = 'Hyperstyle' in bools
        return {
            hyperstyle_images: gr.update(visible=hyperstyle_bool),
            hyperstyle_metrics: gr.update(visible=hyperstyle_bool),
            e4e_images: gr.update(visible=e4e_bool),
            e4e_metrics: gr.update(visible=e4e_bool),
            n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool),
        }

    def outp_toggles(visible):
        return {
            output_hyperstyle_invert: gr.update(visible=visible),
            output_e4e_invert: gr.update(visible=visible),
        }

    def mapper_toggles(visible):
        return {
            mapper_opts: gr.update(visible=visible),
            output_hyperstyle_mapper: gr.update(visible=visible),
            output_e4e_mapper: gr.update(visible=visible),
        }

    def gd_toggles(visible):
        return {
            gd_opts: gr.update(visible=visible),
            output_hyperstyle_gd: gr.update(visible=visible),
            output_e4e_gd: gr.update(visible=visible),
        }

    def ris_toggles(visible):
        return {
            ris_opts: gr.update(visible=visible),
            output_hyperstyle_ris: gr.update(visible=visible),
            output_e4e_ris: gr.update(visible=visible),
        }

    n_hyperstyle_iterations.change(n_iter_change, n_hyperstyle_iterations, n_hyperstyle_iterations)
    mapper_choice.change(mapper_change, mapper_choice, [target_text])
    inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, hyperstyle_metrics, e4e_images, e4e_metrics, n_hyperstyle_iterations])
    invert_bool.change(outp_toggles, invert_bool, [output_hyperstyle_invert, output_e4e_invert])
    mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
    gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])
    ris_bool.change(ris_toggles, ris_bool, [ris_opts, output_hyperstyle_ris, output_e4e_ris])

    def map_latent(mapper, inputs, stylespace=False, weight_deltas=None, strength=0.1):
        # Apply a StyleCLIP mapper step: w_hat = w + strength * mapper(w).
        w = inputs.to(device)
        with torch.no_grad():
            if stylespace:
                delta = mapper.mapper(w)
                w_hat = [c + strength * delta_c for (c, delta_c) in zip(w, delta)]
                x_hat, _, w_hat = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                                 randomize_noise=False, truncation=1,
                                                 input_is_stylespace=True, weights_deltas=weight_deltas)
            else:
                delta = mapper.mapper(w)
                w_hat = w + strength * delta
                x_hat, w_hat, _ = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                                 randomize_noise=False, truncation=1,
                                                 weights_deltas=weight_deltas)
            result_batch = (x_hat, w_hat)
        return result_batch

    def run_metrics(base_img, edited_img):
        # Similarity metrics between the inversion and the edited result.
        lpips_score = lpips_metric(base_img, edited_img)[0]
        ssim_score = ssim_metric(base_img, edited_img)[0]
        id_score = id_metric(base_img, edited_img)[0]
        return lpips_score, ssim_score, id_score

    def clip_text_metric(tensor, text):
        # CLIP similarity between the edited image and the target text.
        clip_embed = torch.cat([clip.tokenize(text)]).cpu()
        clip_score = 1 - clip_text(tensor.unsqueeze(0), clip_embed).item()
        return clip_score

    def submit(
        src, align_img, inverter_bools, n_iterations,
        invert_bool, mapper_bool, mapper_choice, mapper_alpha,
        gd_bool, neutral_text, target_text, alpha, beta,
        ris_bool, ref_img,
    ):
        if device == 'cuda':
            torch.cuda.empty_cache()

        # Load the StyleCLIP mapper checkpoint for the requested hairstyle.
        opts.checkpoint_path = mapper_dict[mapper_choice]
        ckpt = torch.load(mapper_dict[mapper_choice], map_location='cpu')
        mapper_args = ckpt['opts']
        mapper_args.update(vars(opts))
        mapper_args = Namespace(**mapper_args)
        mapper = StyleCLIPMapper(mapper_args)
        mapper.eval()
        mapper.to(device)

        resize_to = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)

        with torch.no_grad():
            output_imgs = []
            if align_img:
                input_img = align_face(src, predictor)
            else:
                input_img = Image.open(src).convert('RGB')
            input_img = im2tensor_transforms(input_img).to(device)

            if gd_bool:
                opts.neutral_text = neutral_text
                opts.target_text = target_text
                opts.alpha = alpha
                opts.beta = beta

            if ris_bool:
                if align_img:
                    ref_input = align_face(ref_img, predictor)
                else:
                    ref_input = Image.open(ref_img).convert('RGB')  # load the hair reference image
                ref_input = im2tensor_transforms(ref_input).to(device)

            hyperstyle_metrics_text = ''
            if 'Hyperstyle' in inverter_bools:
                # run_inversion reads the refinement iteration count from
                # n_iters_per_batch (assumed attribute name from the HyperStyle repo).
                hyperstyle_args.n_iters_per_batch = n_iterations
                hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, hyperstyle_codes = run_inversion(
                    input_img.unsqueeze(0), hyperstyle, hyperstyle_args,
                    return_intermediate_results=False)
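                # run_inversion yields the reconstructed batch, the W+ latents, the
                # per-layer generator weight deltas predicted by HyperStyle, and the
                # intermediate codes.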
                invert_hyperstyle = tensor2im(hyperstyle_batch[0])

                if mapper_bool:
                    mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False,
                                                      weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
                    clip_score = clip_text_metric(mapped_hyperstyle[0], mapper_args.description)
                    mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
                    lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), mapped_hyperstyle.resize(resize_to))
                    hyperstyle_metrics_text += f'\nMapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
                else:
                    mapped_hyperstyle = None

                if gd_bool:
                    gd_hyperstyle = edit_image(hyperstyle_codes, hyperstyle_latents[0], hyperstyle.decoder, direction_calculator, opts, hyperstyle_deltas)
                    clip_score = clip_text_metric(gd_hyperstyle[0], opts.target_text)
                    gd_hyperstyle = tensor2im(gd_hyperstyle[0])
                    lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), gd_hyperstyle.resize(resize_to))
                    hyperstyle_metrics_text += f'\nGlobal Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
                else:
                    gd_hyperstyle = None

                if ris_bool:
                    ref_hyperstyle_batch, ref_hyperstyle_latents, ref_hyperstyle_deltas, _ = run_inversion(
                        ref_input.unsqueeze(0), hyperstyle, hyperstyle_args,
                        return_intermediate_results=False)
                    blend_hyperstyle, blend_hyperstyle_latents = blend_latents(
                        hyperstyle_latents, ref_hyperstyle_latents,
                        src_deltas=hyperstyle_deltas, ref_deltas=ref_hyperstyle_deltas,
                        generator=ris_gen, device=device)
                    ris_hyperstyle = tensor2im(blend_hyperstyle[0])
                    lpips_score, ssim_score, id_score = run_metrics(invert_hyperstyle.resize(resize_to), ris_hyperstyle.resize(resize_to))
                    clip_score = clip_hair(invert_hyperstyle.resize(resize_to), ris_hyperstyle.resize(resize_to))[1]
                    hyperstyle_metrics_text += f'\nRIS Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Hair Score: \t{clip_score}'
                else:
                    ris_hyperstyle = None

                hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle, gd_hyperstyle, ris_hyperstyle, hyperstyle_metrics_text]
            else:
                hyperstyle_output = [None, None, None, None, hyperstyle_metrics_text]
            output_imgs.extend(hyperstyle_output)

            e4e_metrics_text = ''
            if 'E4E' in inverter_bools:
                e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
                e4e_deltas = None  # e4e provides W+ latents only, no per-layer weight deltas
                invert_e4e = tensor2im(e4e_batch[0])

                if mapper_bool:
                    mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False,
                                               weight_deltas=e4e_deltas, strength=mapper_alpha)
                    clip_score = clip_text_metric(mapped_e4e[0], mapper_args.description)
                    mapped_e4e = tensor2im(mapped_e4e[0])
                    lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), mapped_e4e.resize(resize_to))
                    e4e_metrics_text += f'\nMapper Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
                else:
                    mapped_e4e = None

                if gd_bool:
                    # edit_image's first argument appears unused in this call path;
                    # pass None rather than a leftover placeholder variable.
                    gd_e4e = edit_image(None, e4e_latents[0], hyperstyle.decoder, direction_calculator, opts, e4e_deltas)
                    clip_score = clip_text_metric(gd_e4e[0], opts.target_text)
                    gd_e4e = tensor2im(gd_e4e[0])
                    lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), gd_e4e.resize(resize_to))
                    e4e_metrics_text += f'\nGlobal Direction Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Text Score: \t{clip_score}'
                else:
                    gd_e4e = None

                if ris_bool:
                    ref_e4e_batch, ref_e4e_latents = hyperstyle.w_invert(ref_input.unsqueeze(0))
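                    # As with the source image, the e4e inversion has no hypernetwork,
                    # so there are no weight deltas to pass into the blend.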
                    ref_e4e_deltas = None
                    blend_e4e, blend_e4e_latents = blend_latents(
                        e4e_latents, ref_e4e_latents,
                        src_deltas=None, ref_deltas=None,
                        generator=ris_gen, device=device)
                    ris_e4e = tensor2im(blend_e4e[0])
                    lpips_score, ssim_score, id_score = run_metrics(invert_e4e.resize(resize_to), ris_e4e.resize(resize_to))
                    clip_score = clip_hair(invert_e4e.resize(resize_to), ris_e4e.resize(resize_to))[1]
                    e4e_metrics_text += f'\nRIS Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}\n\tID Score: \t{id_score}\n\tCLIP Hair Score: \t{clip_score}'
                else:
                    ris_e4e = None

                e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e, e4e_metrics_text]
            else:
                e4e_output = [None, None, None, None, e4e_metrics_text]
            output_imgs.extend(e4e_output)

        return output_imgs

    submit_button.click(
        submit,
        [
            source, align, inverter_bools, n_hyperstyle_iterations,
            invert_bool, mapper_bool, mapper_choice, mapper_alpha,
            gd_bool, neutral_text, target_text, alpha, beta,
            ris_bool, ref_img,
        ],
        [
            output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd,
            output_hyperstyle_ris, output_hyperstyle_metrics,
            output_e4e_invert, output_e4e_mapper, output_e4e_gd,
            output_e4e_ris, output_e4e_metrics,
        ],
    )

demo.launch()
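# launch() serves the demo on a local URL; Gradio also supports
# demo.launch(share=True) for a temporary public link.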