import sys
import os
import torch
sys.path.append(".")
from gradio_wrapper.gradio_options import GradioTestOptions
from models.hyperstyle.utils.model_utils import load_model
from models.hyperstyle.utils.common import tensor2im
from models.hyperstyle.utils.inference_utils import run_inversion
from hyperstyle_global_directions.edit import load_direction_calculator, edit_image
from torchvision import transforms
import gradio as gr
from utils.alignment import align_face
import dlib
from argparse import Namespace
from mapper.styleclip_mapper import StyleCLIPMapper
from PIL import Image
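
# Global setup: parse the demo options, load the dlib landmark predictor and
# the hyperstyle inverter, and pre-load the default 'afro' StyleCLIP mapper so
# the app can serve its initial configuration immediately.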
opts_args = ['--no_fine_mapper']
opts = GradioTestOptions().parse(opts_args)
mapper_dict = {
    'afro': './pretrained_models/styleCLIP_mappers/afro_hairstyle.pt',
    'bob': './pretrained_models/styleCLIP_mappers/bob_hairstyle.pt',
    'bowl': './pretrained_models/styleCLIP_mappers/bowl_hairstyle.pt',
    'buzz': './pretrained_models/styleCLIP_mappers/buzz_hairstyle.pt',
    'caesar': './pretrained_models/styleCLIP_mappers/caesar_hairstyle.pt',
    'crew': './pretrained_models/styleCLIP_mappers/crew_hairstyle.pt',
    'pixie': './pretrained_models/styleCLIP_mappers/pixie_hairstyle.pt',
    'straight': './pretrained_models/styleCLIP_mappers/straight_hairstyle.pt',
    'undercut': './pretrained_models/styleCLIP_mappers/undercut_hairstyle.pt',
    'wavy': './pretrained_models/styleCLIP_mappers/wavy_hairstyle.pt',
}
predictor = dlib.shape_predictor("./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.dat")
hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, update_opts=opts)
resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
im2tensor_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
direction_calculator = load_direction_calculator(opts)
ckpt = torch.load(mapper_dict['afro'], map_location='cpu')
opts.checkpoint_path = mapper_dict['afro']
mapper_args = ckpt['opts']
mapper_args.update(vars(opts))
mapper_args = Namespace(**mapper_args)
mapper = StyleCLIPMapper(mapper_args)
mapper.eval()
mapper.cuda()
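
# Swap the active StyleCLIP mapper checkpoint. The old mapper is released and
# the CUDA cache emptied first, so only one mapper's weights occupy GPU memory
# at a time.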
def change_mapper(desc):
    global mapper
    global mapper_args
    mapper = None
    ckpt = None
    mapper_args = None
    torch.cuda.empty_cache()
    opts.checkpoint_path = mapper_dict[desc]
    ckpt = torch.load(mapper_dict[desc], map_location='cpu')
    mapper_args = ckpt['opts']
    mapper_args.update(vars(opts))
    mapper_args = Namespace(**mapper_args)
    mapper = StyleCLIPMapper(mapper_args)
    mapper.eval()
    mapper.cuda()
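
# Gradio UI: input controls in the left column, one output row per inverter on
# the right; the checkboxes below toggle the visibility of their option boxes
# and output images.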
with gr.Blocks() as demo:
    with gr.Row() as row:
        with gr.Column() as inputs:
            source = gr.Image(label="Image to Map", type='filepath')
            align = gr.Checkbox(True, label='Align Image')
            inverter_bools = gr.CheckboxGroup(["Hyperstyle", "E4E"], value=['Hyperstyle'], label='Inverter Choices')
            n_hyperstyle_iterations = gr.Number(3, label='Number of Iterations For Hyperstyle', precision=0)
            with gr.Box():
                mapper_bool = gr.Checkbox(True, label='Output Mapper Result')
                with gr.Box() as mapper_opts:
                    mapper_choice = gr.Dropdown(['afro', 'bob', 'bowl', 'buzz', 'caesar', 'crew', 'pixie', 'straight', 'undercut', 'wavy'], value='afro', label='What Hairstyle Mapper to Use?')
                    # default snapped to the slider's 0.1 step grid
                    mapper_alpha = gr.Slider(minimum=-0.5, maximum=0.5, value=0.1, step=0.1, label='Strength of Mapper Alpha')
            with gr.Box():
                gd_bool = gr.Checkbox(False, label='Output Global Direction Result')
                with gr.Box(visible=False) as gd_opts:
                    neutral_text = gr.Text(value='A face with hair', label='Neutral Text')
                    target_text = gr.Text(value=mapper_args.description, label='Target Text')
                    alpha = gr.Slider(minimum=-10.0, maximum=10.0, value=4.1, step=0.1, label="Alpha for Global Direction")
                    beta = gr.Slider(minimum=0.0, maximum=0.30, value=0.15, step=0.01, label="Beta for Global Direction")
            submit_button = gr.Button("Edit Image")
        with gr.Column() as outputs:
            with gr.Row() as hyperstyle_images:
                output_hyperstyle_mapper = gr.Image(type='pil', label="Hyperstyle Mapper")
                output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
            with gr.Row(visible=False) as e4e_images:
                output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
                output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)
    def mapper_change(new_mapper):
        change_mapper(new_mapper)
        return mapper_args.description

    def inverter_toggles(bools):
        e4e_bool = 'E4E' in bools
        hyperstyle_bool = 'Hyperstyle' in bools
        return {
            hyperstyle_images: gr.update(visible=hyperstyle_bool),
            e4e_images: gr.update(visible=e4e_bool),
            n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool),
        }
    def mapper_toggles(visible):
        return {
            mapper_opts: gr.update(visible=visible),
            output_hyperstyle_mapper: gr.update(visible=visible),
            output_e4e_mapper: gr.update(visible=visible),
        }

    def gd_toggles(visible):
        return {
            gd_opts: gr.update(visible=visible),
            output_hyperstyle_gd: gr.update(visible=visible),
            output_e4e_gd: gr.update(visible=visible),
        }
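
    # Wire the controls: changing the mapper dropdown updates the target text,
    # and each checkbox toggles the visibility of its option box and outputs.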
    mapper_choice.change(mapper_change, mapper_choice, [target_text])
    inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, e4e_images, n_hyperstyle_iterations])
    mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
    gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])
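
    # map_latent: run an inverted latent through the StyleCLIP mapper and
    # decode. The mapper predicts a residual that is scaled and added to the
    # latent before decoding with hyperstyle's per-image weight deltas:
    #
    #     w_hat = w + strength * mapper(w)
    #     x_hat = decoder(w_hat; weights_deltas)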
    def map_latent(inputs, stylespace=False, weight_deltas=None, strength=0.1):
        w = inputs.cuda()
        with torch.no_grad():
            if stylespace:
                delta = mapper.mapper(w)
                w_hat = [c + strength * delta_c for (c, delta_c) in zip(w, delta)]
                x_hat, _, w_hat = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                                 randomize_noise=False, truncation=1, input_is_stylespace=True,
                                                 weights_deltas=weight_deltas)
            else:
                delta = mapper.mapper(w)
                w_hat = w + strength * delta
                x_hat, w_hat, _ = mapper.decoder([w_hat], input_is_latent=True, return_latents=True,
                                                 randomize_noise=False, truncation=1, weights_deltas=weight_deltas)
        result_batch = (x_hat, w_hat)
        return result_batch
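
    # submit: the main inference callback. Optionally aligns the face, inverts
    # it with each selected inverter, then produces the mapper and/or
    # global-direction edits. The return order matches the outputs list in
    # submit_button.click below: [hyperstyle mapper, hyperstyle GD, e4e mapper,
    # e4e GD]; unselected paths return None so those images are cleared.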
    def submit(
        src, align_img, inverter_bools, n_iterations,
        mapper_bool, mapper_choice, mapper_alpha,
        gd_bool, neutral_text, target_text, alpha, beta,
    ):
        torch.cuda.empty_cache()
        with torch.no_grad():
            output_imgs = []
            if align_img:
                input_img = align_face(src, predictor)
            else:
                input_img = Image.open(src).convert('RGB')
            input_img = im2tensor_transforms(input_img).cuda()
            if gd_bool:
                opts.neutral_text = neutral_text
                opts.target_text = target_text
                opts.alpha = alpha
                opts.beta = beta
            if 'Hyperstyle' in inverter_bools:
                # Honor the requested iteration count; run_inversion reads
                # n_iters_per_batch from its opts (assumption based on
                # hyperstyle's option names -- the original never applied
                # n_iterations).
                hyperstyle_args.n_iters_per_batch = n_iterations
                hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(
                    input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
                if mapper_bool:
                    mapped_hyperstyle, _ = map_latent(hyperstyle_latents, stylespace=False,
                                                      weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
                    mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
                else:
                    mapped_hyperstyle = None
                if gd_bool:
                    # the leftover `_` placeholder is forwarded as edit_image's
                    # first argument, which this app never uses
                    gd_hyperstyle = edit_image(_, hyperstyle_latents[0], hyperstyle.decoder,
                                               direction_calculator, opts, hyperstyle_deltas)[0]
                    gd_hyperstyle = tensor2im(gd_hyperstyle)
                else:
                    gd_hyperstyle = None
                hyperstyle_output = [mapped_hyperstyle, gd_hyperstyle]
            else:
                hyperstyle_output = [None, None]
            output_imgs.extend(hyperstyle_output)
            if 'E4E' in inverter_bools:
                e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
                e4e_deltas = None
                if mapper_bool:
                    mapped_e4e, _ = map_latent(e4e_latents, stylespace=False,
                                               weight_deltas=e4e_deltas, strength=mapper_alpha)
                    mapped_e4e = tensor2im(mapped_e4e[0])
                else:
                    mapped_e4e = None
                if gd_bool:
                    gd_e4e = edit_image(_, e4e_latents[0], hyperstyle.decoder,
                                        direction_calculator, opts, e4e_deltas)[0]
                    gd_e4e = tensor2im(gd_e4e)
                else:
                    gd_e4e = None
                e4e_output = [mapped_e4e, gd_e4e]
            else:
                e4e_output = [None, None]
            output_imgs.extend(e4e_output)
            return output_imgs
    submit_button.click(
        submit,
        [
            source, align, inverter_bools, n_hyperstyle_iterations,
            mapper_bool, mapper_choice, mapper_alpha,
            gd_bool, neutral_text, target_text, alpha, beta,
        ],
        [output_hyperstyle_mapper, output_hyperstyle_gd, output_e4e_mapper, output_e4e_gd],
    )
demo.launch()
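
# Note: launch() here uses Gradio's defaults (localhost only). To serve the
# demo beyond the local machine one could pass, for example,
#   demo.launch(server_name="0.0.0.0", share=True)
# both of which are standard gradio launch() parameters.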