# Hugging Face Spaces page header captured with this file (not part of the program):
# Space status: Runtime error
import sys | |
import os | |
import torch | |
sys.path.append(".") | |
from gradio_wrapper.gradio_options import GradioTestOptions | |
from models.hyperstyle.utils.model_utils import load_model | |
from models.hyperstyle.utils.common import tensor2im | |
from models.hyperstyle.utils.inference_utils import run_inversion | |
from hyperstyle_global_directions.edit import load_direction_calculator, edit_image | |
from torchvision import transforms | |
import gradio as gr | |
from utils.alignment import align_face | |
import dlib | |
from argparse import Namespace | |
from mapper.styleclip_mapper import StyleCLIPMapper | |
from PIL import Image | |
# Options for the demo are parsed from a fixed argument list that passes the
# `--no_fine_mapper` flag.
opts_args = ['--no_fine_mapper']
opts = GradioTestOptions().parse(opts_args)

# Text description for each supported hairstyle key.  The description of the
# selected hairstyle is used as the "Target Text" of the global-direction edit.
mapper_descs = {
    'afro': 'A face with an afro',
    'bob': 'A face with a bob-cut hairstyle',
    'bowl': 'A face with a bowl cut hairstyle',
    'buzz': 'A face with a buzz cut hairstyle',
    'caesar': 'A face with a caesar cut hairstyle',
    'crew': 'A face with a crew cut hairstyle',
    'pixie': 'A face with a pixie cut hairstyle',
    'straight': 'A face with a straight hair hairstyle',
    'undercut': 'A face with a undercut hairstyle',
    'wavy': 'A face with a wavy hair hairstyle',
}

# Checkpoint path of the StyleCLIP mapper for each hairstyle key.  Derived from
# the description table so both tables always share the same keys and order.
mapper_dict = {
    style: f'./pretrained_models/styleCLIP_mappers/{style}_hairstyle.pt'
    for style in mapper_descs
}

# dlib landmark predictor handed to `align_face` when alignment is requested.
predictor = dlib.shape_predictor(
    "./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.lfs.dat"
)

# Hyperstyle network plus the options returned alongside it by `load_model`.
hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, update_opts=opts)

# Output size: 256x256 when `resize_outputs` is set, otherwise the square
# `output_size` from the loaded options.
if hyperstyle_args.resize_outputs:
    resize_amount = (256, 256)
else:
    resize_amount = (hyperstyle_args.output_size, hyperstyle_args.output_size)

# PIL image -> 256x256 tensor, normalised per channel with mean 0.5 / std 0.5.
im2tensor_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Helper object passed to `edit_image` for the global-direction edits.
direction_calculator = load_direction_calculator(opts)
# UI layout: one row holding an input column (image, options, submit button)
# and an output column (one row of result images per inverter).
# NOTE(review): the nesting below was reconstructed from statement order
# because the original indentation was lost; confirm against the deployed app.
with gr.Blocks() as demo:
    with gr.Row() as row:
        with gr.Column() as inputs:
            source = gr.Image(label="Image to Map", type='filepath')
            align = gr.Checkbox(True, label='Align Image')
            inverter_bools = gr.CheckboxGroup(
                ["Hyperstyle", "E4E"], value=['Hyperstyle'], label='Inverter Choices'
            )
            n_hyperstyle_iterations = gr.Number(
                5, label='Number of Iterations For Hyperstyle', precision=0
            )

            # Mapper-based edit: on by default, with its options visible.
            with gr.Box():
                mapper_bool = gr.Checkbox(True, label='Output Mapper Result')
                with gr.Box() as mapper_opts:
                    mapper_choice = gr.Dropdown(
                        list(mapper_dict.keys()),
                        value='afro',
                        label='What Hairstyle Mapper to Use?',
                    )
                    mapper_alpha = gr.Slider(
                        minimum=-0.5, maximum=0.5, value=0.1, step=0.01,
                        label='Strength of Mapper Alpha',
                    )

            # Global-direction edit: off by default, so its options start hidden.
            with gr.Box():
                gd_bool = gr.Checkbox(False, label='Output Global Direction Result')
                with gr.Box(visible=False) as gd_opts:
                    neutral_text = gr.Text(value='A face with hair', label='Neutral Text')
                    target_text = gr.Text(value=mapper_descs['afro'], label='Target Text')
                    alpha = gr.Slider(
                        minimum=-10.0, maximum=10.0, value=4.1, step=0.1,
                        label="Alpha for Global Direction",
                    )
                    beta = gr.Slider(
                        minimum=0.0, maximum=0.30, value=0.15, step=0.01,
                        label="Beta for Global Direction",
                    )

            submit_button = gr.Button("Edit Image")

        with gr.Column() as outputs:
            with gr.Row() as hyperstyle_images:
                output_hyperstyle_mapper = gr.Image(type='pil', label="Hyperstyle Mapper")
                output_hyperstyle_gd = gr.Image(
                    type='pil', label="Hyperstyle Global Directions", visible=False
                )
            # The E4E row starts hidden because only Hyperstyle is selected by default.
            with gr.Row(visible=False) as e4e_images:
                output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
                output_e4e_gd = gr.Image(
                    type='pil', label="E4E Global Directions", visible=False
                )
def n_iter_change(number):
    """Clamp the Hyperstyle iteration count so it is never negative.

    Args:
        number: Current value of the iteration field, or ``None`` when the
            field holds no value.

    Returns:
        ``0`` for negative input, otherwise ``number`` unchanged.  ``None`` is
        passed through so an empty field is left as it is instead of raising
        ``TypeError`` on the comparison.
    """
    if number is None:
        return number
    return max(number, 0)
def mapper_change(new_mapper):
    """Return the text description registered for the hairstyle key ``new_mapper``.

    Raises:
        KeyError: If ``new_mapper`` is not a key of ``mapper_descs``.
    """
    return mapper_descs[new_mapper]
def inverter_toggles(bools):
    """Show or hide inverter-specific widgets to match the selected inverters.

    Args:
        bools: Collection of the selected inverter names; membership of
            ``'Hyperstyle'`` and ``'E4E'`` is tested.

    Returns:
        A dict mapping each affected component to a visibility update: the
        Hyperstyle image row and iteration field follow ``'Hyperstyle'``, the
        E4E image row follows ``'E4E'``.
    """
    show_hyperstyle = 'Hyperstyle' in bools
    show_e4e = 'E4E' in bools
    return {
        hyperstyle_images: gr.update(visible=show_hyperstyle),
        e4e_images: gr.update(visible=show_e4e),
        n_hyperstyle_iterations: gr.update(visible=show_hyperstyle),
    }
def mapper_toggles(bool):
    """Show or hide the mapper options and both mapper output images.

    Args:
        bool: Desired visibility.  The name shadows the builtin inside this
            function; it is kept so the signature stays unchanged.

    Returns:
        A dict mapping the mapper options box and the two mapper output images
        to a visibility update carrying ``bool``.
    """
    visibility = gr.update(visible=bool)
    return {
        mapper_opts: visibility,
        output_hyperstyle_mapper: gr.update(visible=bool),
        output_e4e_mapper: gr.update(visible=bool),
    }
def gd_toggles(bool):
    """Show or hide the global-direction options and both of its output images.

    Args:
        bool: Desired visibility.  The name shadows the builtin inside this
            function; it is kept so the signature stays unchanged.

    Returns:
        A dict mapping the global-direction options box and the two
        global-direction output images to a visibility update carrying ``bool``.
    """
    return {
        gd_opts: gr.update(visible=bool),
        output_hyperstyle_gd: gr.update(visible=bool),
        output_e4e_gd: gr.update(visible=bool),
    }
# Register the option callbacks.  Event listeners have to be attached while a
# Blocks context is active, so the existing `demo` context is entered here.
# NOTE(review): the original indentation was lost; confirm these calls are
# meant to live inside the `gr.Blocks()` context of `demo`.
with demo:
    # Keep the iteration count non-negative.
    n_hyperstyle_iterations.change(n_iter_change, n_hyperstyle_iterations, n_hyperstyle_iterations)
    # Picking a hairstyle mapper refreshes the global-direction target text.
    mapper_choice.change(mapper_change, mapper_choice, [target_text])
    # Checkbox changes show/hide the widgets that belong to each option.
    inverter_bools.change(
        inverter_toggles,
        inverter_bools,
        [hyperstyle_images, e4e_images, n_hyperstyle_iterations],
    )
    mapper_bool.change(
        mapper_toggles,
        mapper_bool,
        [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper],
    )
    gd_bool.change(
        gd_toggles,
        gd_bool,
        [gd_opts, output_hyperstyle_gd, output_e4e_gd],
    )
def map_latent(mapper, inputs, stylespace=False, weight_deltas=None, strength=0.1):
    """Shift a latent by the mapper's predicted offset and decode the result.

    Args:
        mapper: Object exposing ``mapper(latent)`` (predicts the offset) and
            ``decoder(...)`` (renders an image from latents).
        inputs: Latent to edit; moved to the GPU with ``.cuda()``.
        stylespace: When true, the latent and offset are iterated pairwise and
            the decoder is called with ``input_is_stylespace=True``.
        weight_deltas: Forwarded to the decoder as ``weights_deltas``.
        strength: Multiplier applied to the predicted offset.

    Returns:
        ``(x_hat, w_hat)``: the decoded output and the latent reported by the
        decoder (its second return value normally, its third in stylespace mode).
    """
    latent = inputs.cuda()
    # Decoder arguments shared by both modes.
    decoder_kwargs = dict(
        input_is_latent=True,
        return_latents=True,
        randomize_noise=False,
        truncation=1,
        weights_deltas=weight_deltas,
    )
    with torch.no_grad():
        offset = mapper.mapper(latent)
        if stylespace:
            shifted = [c + strength * delta_c for c, delta_c in zip(latent, offset)]
            x_hat, _, w_hat = mapper.decoder(
                [shifted], input_is_stylespace=True, **decoder_kwargs
            )
        else:
            shifted = latent + strength * offset
            x_hat, w_hat, _ = mapper.decoder([shifted], **decoder_kwargs)
    return (x_hat, w_hat)
def _apply_edits(mapper, latents, deltas, edit_arg, mapper_bool, gd_bool, mapper_alpha):
    """Run the requested edits for one inverter's latents.

    Returns ``([mapper image or None, global-direction image or None], edit_arg)``
    where ``edit_arg`` is the value to keep passing as ``edit_image``'s first
    argument (replaced by ``map_latent``'s second result when the mapper runs).
    """
    mapped = None
    if mapper_bool:
        mapped, edit_arg = map_latent(
            mapper, latents, stylespace=False, weight_deltas=deltas, strength=mapper_alpha
        )
        mapped = tensor2im(mapped[0])

    directed = None
    if gd_bool:
        directed = edit_image(
            edit_arg, latents[0], hyperstyle.decoder, direction_calculator, opts, deltas
        )[0]
        directed = tensor2im(directed)

    return [mapped, directed], edit_arg


def submit(
    src, align_img, inverter_bools, n_iterations,
    mapper_bool, mapper_choice, mapper_alpha,
    gd_bool, neutral_text, target_text, alpha, beta,
):
    """Invert the input image and produce the requested hairstyle edits.

    Args:
        src: File path of the input image, or ``None`` when nothing was given.
        align_img: Whether to run ``align_face`` on the image first.
        inverter_bools: Selected inverter names (``'Hyperstyle'`` and/or ``'E4E'``).
        n_iterations: Requested Hyperstyle iteration count.
            NOTE(review): this value is not used anywhere below; confirm
            whether it should be applied to the inversion options.
        mapper_bool: Whether to produce the mapper-based edit.
        mapper_choice: Key into ``mapper_dict`` selecting the mapper checkpoint.
        mapper_alpha: Strength passed to ``map_latent``.
        gd_bool: Whether to produce the global-direction edit.
        neutral_text, target_text, alpha, beta: Stored on ``opts`` for the
            global-direction edit when ``gd_bool`` is set.

    Returns:
        A list of four entries in the order
        ``[hyperstyle mapper, hyperstyle global direction, e4e mapper,
        e4e global direction]``; entries that were not requested are ``None``.
    """
    # Without an input image there is nothing to edit; return empty outputs
    # instead of failing inside image loading.
    if src is None:
        return [None, None, None, None]

    torch.cuda.empty_cache()
    opts.checkpoint_path = mapper_dict[mapper_choice]

    # The mapper is only needed for the mapper-based edit, so skip loading the
    # checkpoint onto the GPU when that output is disabled.
    mapper = None
    if mapper_bool:
        ckpt = torch.load(mapper_dict[mapper_choice], map_location='cpu')
        mapper_args = ckpt['opts']
        mapper_args.update(vars(opts))  # demo options override checkpoint options
        mapper = StyleCLIPMapper(Namespace(**mapper_args))
        mapper.eval()
        mapper.cuda()

    # First argument handed to `edit_image`.  The original code passed whatever
    # `_` last held, which was unbound (UnboundLocalError) when only E4E was
    # selected with the mapper disabled; start from None so that path works.
    # NOTE(review): confirm `edit_image` accepts None as its first argument.
    edit_arg = None

    with torch.no_grad():
        if align_img:
            input_img = align_face(src, predictor)
        else:
            input_img = Image.open(src).convert('RGB')
        input_img = im2tensor_transforms(input_img).cuda()

        if gd_bool:
            opts.neutral_text = neutral_text
            opts.target_text = target_text
            opts.alpha = alpha
            opts.beta = beta

        output_imgs = []

        if 'Hyperstyle' in inverter_bools:
            _batch, hyperstyle_latents, hyperstyle_deltas, edit_arg = run_inversion(
                input_img.unsqueeze(0), hyperstyle, hyperstyle_args,
                return_intermediate_results=False,
            )
            hyperstyle_output, edit_arg = _apply_edits(
                mapper, hyperstyle_latents, hyperstyle_deltas, edit_arg,
                mapper_bool, gd_bool, mapper_alpha,
            )
        else:
            hyperstyle_output = [None, None]
        output_imgs.extend(hyperstyle_output)

        if 'E4E' in inverter_bools:
            _batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
            # The E4E path has no weight deltas.
            e4e_output, edit_arg = _apply_edits(
                mapper, e4e_latents, None, edit_arg,
                mapper_bool, gd_bool, mapper_alpha,
            )
        else:
            e4e_output = [None, None]
        output_imgs.extend(e4e_output)

    return output_imgs
# Register the submit handler.  Event listeners have to be attached while a
# Blocks context is active, so the existing `demo` context is entered here.
# NOTE(review): the original indentation was lost; confirm this call is meant
# to live inside the `gr.Blocks()` context of `demo`.
with demo:
    submit_button.click(
        submit,
        [
            source, align, inverter_bools, n_hyperstyle_iterations,
            mapper_bool, mapper_choice, mapper_alpha,
            gd_bool, neutral_text, target_text, alpha, beta,
        ],
        # Order matches the four-element list returned by `submit`.
        [output_hyperstyle_mapper, output_hyperstyle_gd, output_e4e_mapper, output_e4e_gd],
    )

# Start the Gradio server.
demo.launch()