# inversion_testing / run_on_batch.py
# Author: ethanNeuralImage
# Commit ab189a8: "start implementing PTI, RIS implemented"
import sys
import os
import torch
from metrics.metrics import ClipHair
sys.path.append(".")
from gradio_wrapper.gradio_options import GradioTestOptions
from models.hyperstyle.utils.model_utils import load_model
from models.hyperstyle.utils.common import tensor2im
from models.hyperstyle.utils.inference_utils import run_inversion
from hyperstyle_global_directions.edit import load_direction_calculator, edit_image
from torchvision import transforms
import gradio as gr
from utils.alignment import align_face
import dlib
from argparse import Namespace
from mapper.styleclip_mapper import StyleCLIPMapper
import ris.spherical_kmeans as spherical_kmeans
from ris.blend import blend_latents
from ris.model import Generator as RIS_Generator
from models.pti.manipulator import Manipulator
from models.pti.wrapper import Generator_wrapper
#from models.pti.e4e_projection import projection
from metrics import FaceMetric
from metrics.criteria.clip_loss import CLIPLoss
import clip
from PIL import Image
# Choose the compute device up front; it is recorded on the options and passed
# to most of the models loaded further down.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Parse the Gradio test options with the fine mapper disabled.
opts_args = ['--no_fine_mapper']
opts = GradioTestOptions().parse(opts_args)
opts.device = device
# Hairstyle name -> StyleCLIP mapper checkpoint. Every checkpoint follows the
# same naming scheme, so the table is generated from the list of style names
# (the order of the names is the order of the dict keys).
mapper_dict = {
    style: f'./pretrained_models/styleCLIP_mappers/{style}_hairstyle.pt'
    for style in (
        'afro', 'bob', 'bowl', 'buzz', 'caesar',
        'crew', 'pixie', 'straight', 'undercut', 'wavy',
    )
}
# Hairstyle name -> text description for each mapper in mapper_dict.
# The strings are kept verbatim (including their wording) because they are
# presumably the prompts the mappers were trained/evaluated with — confirm
# before editing them.
mapper_descs = dict(
    afro='A face with an afro',
    bob='A face with a bob-cut hairstyle',
    bowl='A face with a bowl cut hairstyle',
    buzz='A face with a buzz cut hairstyle',
    caesar='A face with a caesar cut hairstyle',
    crew='A face with a crew cut hairstyle',
    pixie='A face with a pixie cut hairstyle',
    straight='A face with a straight hair hairstyle',
    undercut='A face with a undercut hairstyle',
    wavy='A face with a wavy hair hairstyle',
)
# Every model and metric below is loaded once at import time and reused by
# submit() on each request.

# dlib shape predictor (68-face-landmarks checkpoint) passed to align_face().
predictor = dlib.shape_predictor("./pretrained_models/hyperstyle/shape_predictor_68_face_landmarks.lfs.dat")
# HyperStyle network plus the options object returned alongside it.
hyperstyle, hyperstyle_args = load_model(opts.hyperstyle_checkpoint_path, device=device, update_opts=opts)
# NOTE(review): resize_amount is not referenced in the visible code; submit()
# recomputes the same value locally as resize_to.
resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_args.output_size, hyperstyle_args.output_size)
# PIL image -> 256x256 tensor; Normalize(0.5, 0.5) maps channel values from [0, 1] to [-1, 1].
im2tensor_transforms = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# StyleCLIP global-direction editor; submit() later writes neutral_text /
# target_text / alpha / beta onto the same opts object before calling edit_image().
direction_calculator = load_direction_calculator(opts)
# Generator used for RIS blending. The positional args presumably are
# (size=1024, style_dim=512, n_mlp=8) — TODO confirm against ris.model.Generator.
ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
# map_location returns each storage unchanged, i.e. the checkpoint is loaded on CPU.
ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
# strict=False: keys missing from / unexpected in 'g_ema' are ignored silently.
ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)
# Image-pair metrics: the first three feed run_metrics(); clip_hair scores RIS results.
lpips_metric = FaceMetric(metric_type='lpips', device=device)
ssim_metric = FaceMetric(metric_type='ms-ssim', device=device)
id_metric = FaceMetric(metric_type='id', device=device)
clip_hair = FaceMetric(metric_type='cliphair', device=device)
# CLIP image/text loss; clip_text_metric() reports 1 - loss.
clip_text = CLIPLoss(hyperstyle_args)
# PTI generator wrapper and manipulator; submit() calls
# manipulator.set_real_img_projection() when 'PTI' is selected.
G = Generator_wrapper('./pretrained_models/pti/ffhq.pkl', device)
manipulator = Manipulator(G, device)
def map_latent(mapper, inputs, stylespace=False, weight_deltas=None, strength=0.1):
    """Apply a StyleCLIP mapper to latent codes and decode the edited codes.

    The mapper's predicted offset is scaled by ``strength`` and added to the
    input codes, then ``mapper.decoder`` renders the result (without
    randomized noise, truncation 1).

    Args:
        mapper: a loaded StyleCLIPMapper exposing ``.mapper`` and ``.decoder``.
        inputs: latent codes; moved to the module-level ``device``.
        stylespace: when True the codes are edited element-wise (zipped with
            the mapper output) and decoded with ``input_is_stylespace=True``.
        weight_deltas: optional generator weight offsets forwarded to the
            decoder as ``weights_deltas`` (HyperStyle); None for plain e4e.
        strength: scale applied to the mapper offset.

    Returns:
        ``(images, edited_latents)`` as produced by the decoder.
    """
    latent = inputs.to(device)
    with torch.no_grad():
        # Both modes start from the same mapper prediction.
        step = mapper.mapper(latent)
        decoder_kwargs = dict(
            input_is_latent=True,
            return_latents=True,
            randomize_noise=False,
            truncation=1,
            weights_deltas=weight_deltas,
        )
        if not stylespace:
            edited = latent + strength * step
            images, edited_latent, _ = mapper.decoder([edited], **decoder_kwargs)
            return images, edited_latent
        edited = [code + strength * code_step for code, code_step in zip(latent, step)]
        # In style-space mode the decoder returns the edited codes in the third slot.
        images, _, edited_latent = mapper.decoder([edited], input_is_stylespace=True, **decoder_kwargs)
    return images, edited_latent
def run_metrics(base_img, edited_img):
    """Compare an edited image against its unedited base.

    Each metric is called with ``(base_img, edited_img)`` and the first element
    of its result is kept.

    Returns:
        ``(lpips_score, ssim_score, id_score)``.
    """
    scores = [
        metric(base_img, edited_img)[0]
        for metric in (lpips_metric, ssim_metric, id_metric)
    ]
    return tuple(scores)
def clip_text_metric(tensor, text):
    """Score how well a generated image matches a text prompt under CLIP.

    Args:
        tensor: a single generated image tensor; a batch dimension is added
            before it is passed to the CLIP loss.
        text: prompt string, tokenized with ``clip.tokenize``.

    Returns:
        ``1 - CLIPLoss(image, tokens)`` as a Python float (higher = closer match).
    """
    # Put the tokens next to the image instead of hard-coding ``.cuda()``:
    # the module explicitly falls back to device == 'cpu', where ``.cuda()``
    # would raise. On CUDA runs this is the same device as before.
    tokens = clip.tokenize(text).to(tensor.device)
    return 1 - clip_text(tensor.unsqueeze(0), tokens).item()
def _format_metrics(title, lpips_score, ssim_score, id_score, clip_label, clip_score):
    """Render one metrics section in the text format shown in the UI."""
    return (
        f'\n{title} Metrics:\n\tLPIPS: \t{lpips_score} \n\tSSIM: \t{ssim_score}'
        f'\n\tID Score: \t{id_score}\n\t{clip_label} Score: \t{clip_score}'
    )


def _load_mapper(mapper_choice):
    """Load the StyleCLIP mapper registered under ``mapper_choice`` in mapper_dict.

    Sets ``opts.checkpoint_path`` to the chosen checkpoint (the merged options
    handed to StyleCLIPMapper include it).

    Returns:
        ``(mapper, mapper_args)`` where mapper_args merges the checkpoint's
        saved options with the current ``opts`` (opts wins on conflicts).
    """
    opts.checkpoint_path = mapper_dict[mapper_choice]
    ckpt = torch.load(opts.checkpoint_path, map_location='cpu')
    merged = ckpt['opts']
    merged.update(vars(opts))
    mapper_args = Namespace(**merged)
    mapper = StyleCLIPMapper(mapper_args)
    mapper.eval()
    mapper.to(device)
    return mapper, mapper_args


def _load_input(path, align_img):
    """Open an image (face-aligned with dlib when ``align_img``) as a normalized tensor on ``device``."""
    img = align_face(path, predictor) if align_img else Image.open(path).convert('RGB')
    return im2tensor_transforms(img).to(device)


def _invert_hyperstyle(img):
    """Invert ``img`` with HyperStyle; returns ``(images, latents, weight_deltas)``."""
    batch, latents, deltas, _ = run_inversion(
        img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
    return batch, latents, deltas


def _invert_e4e(img):
    """Invert ``img`` with the e4e W encoder only; there are no weight deltas."""
    batch, latents = hyperstyle.w_invert(img.unsqueeze(0))
    return batch, latents, None


def _run_inverter(invert, input_img, ref_input, resize_to, mapper, mapper_args,
                  mapper_bool, mapper_alpha, gd_bool, ris_bool):
    """Invert one image with ``invert``, apply each selected edit and score it.

    Returns:
        ``[inverted, mapped, global_direction, ris, metrics_text]``; edits that
        were not selected are None and contribute nothing to metrics_text.
    """
    batch, latents, deltas = invert(input_img)
    inverted = tensor2im(batch[0])
    # Every edit is scored against the same resized inversion.
    base = inverted.resize(resize_to)
    metrics_text = ''

    mapped = None
    if mapper_bool:
        mapped_batch, _ = map_latent(mapper, latents, stylespace=False,
                                     weight_deltas=deltas, strength=mapper_alpha)
        clip_score = clip_text_metric(mapped_batch[0], mapper_args.description)
        mapped = tensor2im(mapped_batch[0])
        lpips_score, ssim_score, id_score = run_metrics(base, mapped.resize(resize_to))
        metrics_text += _format_metrics('Mapper', lpips_score, ssim_score, id_score,
                                        'CLIP Text', clip_score)

    gd = None
    if gd_bool:
        # The first argument used to be whatever ``_`` happened to hold (a
        # tensor, or unbound -> UnboundLocalError). It was never meaningful, so
        # pass None explicitly. NOTE(review): confirm edit_image ignores it.
        gd_batch = edit_image(None, latents[0], hyperstyle.decoder, direction_calculator, opts, deltas)
        clip_score = clip_text_metric(gd_batch[0], opts.target_text)
        gd = tensor2im(gd_batch[0])
        lpips_score, ssim_score, id_score = run_metrics(base, gd.resize(resize_to))
        metrics_text += _format_metrics('Global Direction', lpips_score, ssim_score, id_score,
                                        'CLIP Text', clip_score)

    ris = None
    if ris_bool:
        # The reference is inverted with the same inverter as the source.
        _, ref_latents, ref_deltas = invert(ref_input)
        blend_batch, _ = blend_latents(latents, ref_latents,
                                       src_deltas=deltas, ref_deltas=ref_deltas,
                                       generator=ris_gen, device=device)
        ris = tensor2im(blend_batch[0])
        lpips_score, ssim_score, id_score = run_metrics(base, ris.resize(resize_to))
        clip_score = clip_hair(base, ris.resize(resize_to))[1]
        metrics_text += _format_metrics('RIS', lpips_score, ssim_score, id_score,
                                        'CLIP Hair', clip_score)

    return [inverted, mapped, gd, ris, metrics_text]


def submit(
    src, align_img, inverter_bools, n_iterations, invert_bool,
    mapper_bool, mapper_choice, mapper_alpha,
    gd_bool, neutral_text, target_text, alpha, beta,
    ris_bool, ref_img,
):
    """Run the selected inverters and edits on one source image.

    Args:
        src: source image path (opened with PIL or passed to align_face).
        align_img: face-align ``src`` (and ``ref_img``) with dlib first.
        inverter_bools: collection checked for 'Hyperstyle', 'E4E' and 'PTI'.
        n_iterations, invert_bool: accepted for the UI but unused here.
        mapper_bool, mapper_choice, mapper_alpha: StyleCLIP mapper edit, the
            mapper_dict key to load, and its strength.
        gd_bool, neutral_text, target_text, alpha, beta: global-direction edit;
            the texts/alpha/beta are written onto the module-level ``opts``.
        ris_bool, ref_img: RIS hairstyle transfer from the reference image.

    Returns:
        A list of 14 entries: ``[inverted, mapped, gd, ris, metrics_text]`` for
        HyperStyle, the same five for E4E, then four PTI placeholders (None).
        Unselected inverters yield ``[None, None, None, None, '']``.
    """
    if device == 'cuda':
        torch.cuda.empty_cache()
    # Only load the mapper checkpoint when the mapper edit is actually requested.
    mapper, mapper_args = _load_mapper(mapper_choice) if mapper_bool else (None, None)
    resize_to = ((256, 256) if hyperstyle_args.resize_outputs
                 else (hyperstyle_args.output_size, hyperstyle_args.output_size))

    with torch.no_grad():
        output_imgs = []
        input_img = _load_input(src, align_img)
        if gd_bool:
            opts.neutral_text = neutral_text
            opts.target_text = target_text
            opts.alpha = alpha
            opts.beta = beta
        # Fixed: the unaligned reference was previously opened from ``src``.
        ref_input = _load_input(ref_img, align_img) if ris_bool else None

        for name, invert in (('Hyperstyle', _invert_hyperstyle), ('E4E', _invert_e4e)):
            if name in inverter_bools:
                output_imgs.extend(_run_inverter(
                    invert, input_img, ref_input, resize_to, mapper, mapper_args,
                    mapper_bool, mapper_alpha, gd_bool, ris_bool))
            else:
                output_imgs.extend([None, None, None, None, ''])

        if 'PTI' in inverter_bools:
            # PTI is still work in progress: the projection runs but produces
            # no outputs yet. NOTE(review): it runs under torch.no_grad(); if
            # the projection optimizes anything it may need grad enabled.
            manipulator.set_real_img_projection(src, inv_mode='w+', pti_mode='s')
        output_imgs.extend([None, None, None, None])
    return output_imgs