import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import pickle
import warnings
warnings.filterwarnings("ignore")

from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
import glob
from pathlib import Path
import lpips
import argparse
import gc
import cv2
import dlib
from PIL import Image

from models.pti.e4e_projection import projection
from model import *
from util import *
from e4e.models.psp import pSp
import torchvision.transforms as transforms
from torch.nn import DataParallel
import torchvision.transforms.functional as TF
from FaceQualityMetrics.utils import FaceMetric
from hyperstyle.utils.model_utils import load_model
from configs.paths_config import model_paths
from models.pti.manipulator import Manipulator
from models.pti.wrapper import Generator_wrapper


def run_inversion(inputs, net, n_iters_per_batch, return_intermediate_results=False,
                  resize_outputs=False, weights_deltas=None):
    y_hat, latent, weights_deltas, codes = None, None, weights_deltas, None

    if return_intermediate_results:
        results_batch = {idx: [] for idx in range(inputs.shape[0])}
        results_latent = {idx: [] for idx in range(inputs.shape[0])}
        results_deltas = {idx: [] for idx in range(inputs.shape[0])}
    else:
        results_batch, results_latent, results_deltas = None, None, None

    if weights_deltas is None:
        for _ in range(n_iters_per_batch):
            y_hat, latent, weights_deltas, codes, _ = net.forward(inputs,
                                                                  y_hat=y_hat,
                                                                  codes=codes,
                                                                  weights_deltas=weights_deltas,
                                                                  return_latents=True,
                                                                  resize=resize_outputs,
                                                                  randomize_noise=False,
                                                                  return_weight_deltas_and_codes=True)
            # weights_deltas[14] = None
            # weights_deltas[20] = None
            # weights_deltas[21] = None
            # weights_deltas[23] = None
            # weights_deltas[24] = None
            if return_intermediate_results:
                store_intermediate_results(results_batch, results_latent, results_deltas,
                                           y_hat, latent, weights_deltas)
            # resize input to 256 before feeding into next iteration
            y_hat = net.face_pool(y_hat)
    else:
        for _ in range(n_iters_per_batch):
            y_hat, latent, _, codes, _ = net.forward(inputs,
                                                     y_hat=y_hat,
                                                     codes=codes,
                                                     weights_deltas=weights_deltas,
                                                     return_latents=True,
                                                     resize=resize_outputs,
                                                     randomize_noise=False,
                                                     return_weight_deltas_and_codes=True)
            if return_intermediate_results:
                store_intermediate_results(results_batch, results_latent, results_deltas,
                                           y_hat, latent, weights_deltas)
            # resize input to 256 before feeding into next iteration
            y_hat = net.face_pool(y_hat)

    if return_intermediate_results:
        return results_batch, results_latent, results_deltas

    return y_hat, latent, weights_deltas, codes


def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
    for idx in range(y_hat.shape[0]):
        results_batch[idx].append(y_hat[idx])
        results_latent[idx].append(latent[idx].cpu().numpy())
        results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None
                                    for w in weights_deltas])


# compute M given a style code.
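# M assigns each style channel to one of the semantic classes in labels2idx:
# the generator's feature maps at stop_idx are labelled by the clusterer loaded
# from catalog.pkl, labels_map groups those cluster ids into classes, and each
# layer's instance-normalized, squared activations are summed under every class
# mask and normalized by the spatial size. The per-class channel scores are
# softmaxed across classes (temperature 0.1), thresholded at 0.8, and the torgb
# entries are zeroed out, giving a binary [batch, n_class, channels] tensor.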
@torch.no_grad()
def compute_M(w, weights_deltas=None, device='cuda'):
    M = []

    # get segmentation
    # _, outputs = generator(w, is_cluster=1)
    _, outputs = generator(w, weights_deltas=weights_deltas)
    cluster_layer = outputs[stop_idx][0]
    activation = flatten_act(cluster_layer)
    seg_mask = clusterer.predict(activation)
    b, c, h, w = cluster_layer.size()

    # create masks for each feature
    all_seg_mask = []
    seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w, 1).to(device)

    for key in range(n_class):
        # combine masks for all indices for a particular segmentation class
        indices = labels_map[key].view(1, 1, 1, 1, -1)
        key_mask = (seg_mask == indices.to(device)).any(-1)  # [b,1,h,w]
        all_seg_mask.append(key_mask)

    all_seg_mask = torch.stack(all_seg_mask, 1)

    # go through each activation layer and compute M
    for layer_idx in range(len(outputs)):
        layer = outputs[layer_idx][1].to(device)
        b, c, h, w = layer.size()
        layer = F.instance_norm(layer)
        layer = layer.pow(2)

        # resize the segmentation masks to current activations' resolution
        layer_seg_mask = F.interpolate(all_seg_mask.flatten(0, 1).float(), align_corners=False,
                                       size=(h, w), mode='bilinear').view(b, -1, 1, h, w)

        masked_layer = layer.unsqueeze(1) * layer_seg_mask  # [b,k,c,h,w]
        masked_layer = masked_layer.sum([3, 4]) / (h * w)   # [b,k,c]

        M.append(masked_layer.to(device))

    M = torch.cat(M, -1)  # [b, k, c]

    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / .1, 1)
    # simple thresholding
    M = (M > .8).float()

    # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
    for i in range(n_class):
        part_M = style2list(M[:, i])
        for j in range(len(part_M)):
            if j in rgb_layer_idx:
                part_M[j].zero_()
        part_M = list2style(part_M)
        M[:, i] = part_M

    return M

#====
# for i in range(len(blend_deltas)):
#     if blend_deltas[i] is not None:
#         print(f'{i}: {part_M_mask[i].sum()}/{sum(part_M_mask[i].shape)}')
#         if part_M_mask[i].sum() >= sum(part_M_mask[i].shape)/2:
#             print(i)
#             blend_deltas[i] = ref_deltas[i]


def tensor2img(tensor):
    tensor = tensor.cpu().clamp(-1, 1)
    img = topil(tensor.squeeze())
    return img


def hair_transfer_hyperstyle(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)

        source_batch, source_latent, source_deltas, source_codes = run_inversion(
            source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(
            ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)

        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)

        source_out, _ = generator(source, weights_deltas=source_deltas, randomize_noise=False)
        ref_out, _ = generator(ref, weights_deltas=ref_deltas, randomize_noise=False)

        source_M = compute_M(source, weights_deltas=source_deltas, device='cpu')
        ref_M = compute_M(ref, weights_deltas=ref_deltas, device='cpu')

        blend_deltas = source_deltas

        max_M = torch.max(source_M.expand_as(ref_M), ref_M)
        max_M = add_pose(max_M, labels2idx)
        idx = labels2idx['hair']
        part_M = max_M[:, idx].to(device)
        part_M_mask = style2list(part_M)

        blend = style2list(add_direction(source, ref, part_M, 1.3))
        blend_out, _ = generator(blend, weights_deltas=blend_deltas)
        source_out = tensor2img(source_out)
        ref_out = tensor2img(ref_out)
        blend_out = tensor2img(blend_out)

        lpips_face, _ = lpips(blend_out, source_out)
        ssim_face, _ = ssim(blend_out, source_out)
        id_face, _ = id_score(blend_out, source_out)
        _, lpips_hair = lpips(blend_out, ref_out)
        _, ssim_hair = ssim(blend_out, ref_out)
        _, clip_hair = clip(blend_out, source_out)
        out_str = (f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\n'
                   f'ssim_face: {ssim_face}\nssim_hair: {ssim_hair}\n'
                   f'id_face: {id_face}\nclip_hair: {clip_hair}')

        e4e_blend_out, _ = generator(blend)
        e4e_blend_out = tensor2img(e4e_blend_out)
        _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)

        source_out_np = np.array(source_out)
        blend_np = np.array(e4e_blend_out).astype(np.uint8)
        e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8) * 255
        mask_dilate = cv2.dilate(e4e_blend_hair_mask,
                                 kernel=np.ones((50, 50), np.uint8))
        mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
        mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
        face_mask = 255 - mask_dilate_blur

        # Poisson-blend the source face region (inverse of the dilated hair mask) back into the blended result
        index = np.where(face_mask > 0)
        cy = (np.min(index[0]) + np.max(index[0])) // 2
        cx = (np.min(index[1]) + np.max(index[1])) // 2
        center = (cx, cy)
        clone_out = cv2.seamlessClone(source_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)

        return source_out, ref_out, blend_out, out_str, clone_out


def hair_transfer_e4e(source_img_path, ref_img_path):
    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        source_tensor = transform(source_img).unsqueeze(0).to(device)
        ref_tensor = transform(ref_img).unsqueeze(0).to(device)

        source_batch, source_latent, source_deltas, source_codes = run_inversion(
            source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
        ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(
            ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)

        source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
        ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)

        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)

        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')

        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']
        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)

        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))
        e4e_blend_out, _ = generator(e4e_blend)

        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)
        e4e_blend_out = tensor2img(e4e_blend_out)

        e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
        e4e_out_str = (f'e4e_lpips_face: {e4e_lpips_face}\ne4e_lpips_hair: {e4e_lpips_hair}\n'
                       f'e4e_ssim_face: {e4e_ssim_face}\ne4e_ssim_hair: {e4e_ssim_hair}\n'
                       f'e4e_id_face: {e4e_id_face}\ne4e_clip_hair: {e4e_clip_hair}')

        return e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str
def hair_transfer_PTI(source_img_path, ref_img_path):
    ckpt = 'pretrained/ffhq.pkl'
    G = Generator_wrapper(ckpt, device)
    manipulator = Manipulator(G, device)
    manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')

    with torch.no_grad():
        source_img = align_face(source_img_path, predictor=predictor)
        ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))

        projection(source_img, 'source', generator, device)
        projection(ref_img, 'ref', generator, device)
        source = load_source('source', generator, device)
        ref = load_source('ref', generator, device)

        e4e_source_out, _ = generator(source, randomize_noise=False)
        e4e_ref_out, _ = generator(ref, randomize_noise=False)

        e4e_source_M = compute_M(source, device='cpu')
        e4e_ref_M = compute_M(ref, device='cpu')

        e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
        e4e_max_M = add_pose(e4e_max_M, labels2idx)
        e4e_idx = labels2idx['hair']
        e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
        e4e_part_M_mask = style2list(e4e_part_M)

        e4e_blend = style2list(add_direction(source, ref, e4e_part_M, 1.3))

        e4e_source_out = tensor2img(e4e_source_out)
        e4e_ref_out = tensor2img(e4e_ref_out)

        # e4e_blend_out = tensor2img(e4e_blend_out)
        # e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
        # e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
        # e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        # _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        # _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        # _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)

        # per-layer style names of the StyleGAN2-ADA synthesis network; zip them with the blended styles
        keys = ['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine',
                'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine',
                'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine',
                'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine',
                'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine',
                'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine',
                'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine',
                'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine',
                'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine']
        test_dict = dict(zip(keys, e4e_blend))
        manipulator_list = [test_dict]

        all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
        PTI_outstr = 'PTI_outstr'
        blend_out = tensor2img(all_imgs[0])

        return e4e_source_out, e4e_ref_out, blend_out, PTI_outstr

# _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
# blend_out_np = np.array(blend_out)
# blend_np = np.array(e4e_blend_out).astype(np.uint8)
# e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8)*255
# mask_dilate = cv2.dilate(e4e_blend_hair_mask,
#                          kernel=np.ones((50, 50), np.uint8))
# mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
# mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
# face_mask = 255 - mask_dilate_blur
# index = np.where(face_mask > 0)
# cy = (np.min(index[0]) + np.max(index[0])) // 2
# cx = (np.min(index[1]) + np.max(index[1])) // 2
# center = (cx, cy)
# clone_out = cv2.seamlessClone(blend_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
# out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
# seg_out = torch.tensor(face_mask).float().unsqueeze(-1).repeat(1,1,3)
# seg_out = seg_out.cpu().numpy().astype(np.uint8)
# # seg_out *= 255
# seg_out = Image.fromarray(seg_out)
# return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, seg_out

## Set source_tensor requires_grad=True
# source_tensor.requires_grad = True
# ref_tensor.requires_grad = True
# ckpt = 'pretrained/ffhq.pkl'
# G = Generator_wrapper(ckpt, device)
# manipulator = Manipulator(G, device)
# manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
# blend = style2list(add_direction(source_tensor, ref_tensor, part_M, 1.3))
# keys = ['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine',
#         'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine',
#         'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine',
#         'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine',
#         'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine',
#         'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine',
#         'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine',
#         'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine',
#         'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine']
# test_dict = dict(zip(keys, blend))
# manipulator_list = []
# manipulator_list.append(test_dict)
# all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
# return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, all_imgs


# argument for choosing the encoder (e4e or hyperstyle)
parser = argparse.ArgumentParser()
parser.add_argument('--encoder', type=str, default='hyperstyle')
opt = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

lpips = FaceMetric(metric_type='lpips', device=device)
ssim = FaceMetric(metric_type='ms-ssim', device=device)
id_score = FaceMetric(metric_type='id', device=device)
clip = FaceMetric(metric_type='cliphair', device=device)

generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt['g_ema'], strict=False)

ckpt = 'pretrained/ffhq.pkl'
G = Generator_wrapper(ckpt, device)
manipulator = Manipulator(G, device)

# dlib landmark predictor used by align_face in both encoder branches
predictor = dlib.shape_predictor('pretrained_models/shape_predictor_68_face_landmarks.dat')

if opt.encoder == 'e4e':
    from util import align_face
    model_path = 'e4e_ffhq_encode.pt'
    ensure_checkpoint_exists(model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    net = pSp(opts, device).eval().to(device)
elif opt.encoder == 'hyperstyle':
    from hyperstyle.scripts.align_faces_parallel import align_face
    model_path = 'pretrained_models/hyperstyle_ffhq.pt'
    net, _ = load_model(model_path)
else:
    raise ValueError('invalid encoder')

truncation = 0.5
stop_idx = 11      # index of the generator output whose features are clustered for segmentation
n_clusters = 18
clusterer = pickle.load(open('catalog.pkl', 'rb'))

labels2idx = {
    'nose': 0,
    'eyes': 1,
    'mouth': 2,
    'hair': 3,
    'background': 4,
    'cheek': 5,
    'neck': 6,
    'clothes': 7,
}

labels_map = {
    0: torch.tensor([7]),
    1: torch.tensor([1, 6]),
    2: torch.tensor([4]),
    3: torch.tensor([0, 3, 5, 8, 10, 15, 16]),
    4: torch.tensor([11, 13, 14]),
    5: torch.tensor([9]),
    6: torch.tensor([17]),
    7: torch.tensor([2, 12]),
}

idx2labels = dict((v, k) for k, v in labels2idx.items())
n_class = len(idx2labels)

segid_map = dict.fromkeys(labels_map[0].tolist(), 0)
segid_map.update(dict.fromkeys(labels_map[1].tolist(), 1))
segid_map.update(dict.fromkeys(labels_map[2].tolist(), 2))
segid_map.update(dict.fromkeys(labels_map[3].tolist(), 3))
segid_map.update(dict.fromkeys(labels_map[4].tolist(), 4))
segid_map.update(dict.fromkeys(labels_map[5].tolist(), 5))
segid_map.update(dict.fromkeys(labels_map[6].tolist(), 6))
segid_map.update(dict.fromkeys(labels_map[7].tolist(), 7))

torch.manual_seed(0)

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

topil = transforms.Compose(
    [
        transforms.Normalize([-1, -1, -1], [2, 2, 2]),
        transforms.ToPILImage(),
        transforms.Resize(1024),
    ]
)

e4e_ris_demo = gr.Interface(hair_transfer_e4e,
                            inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')],
                            outputs=["image", "image", "image", "text"])
hyperstyle_ris_demo = gr.Interface(hair_transfer_hyperstyle,
                                   inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')],
                                   outputs=["image", "image", "image", "text", "image"])
PTI_ris_demo = gr.Interface(hair_transfer_PTI,
                            inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')],
                            outputs=["image", "image", "image", "text"])

ris_demo = gr.TabbedInterface(interface_list=[hyperstyle_ris_demo, e4e_ris_demo, PTI_ris_demo],
                              tab_names=["hyperstyle", "e4e", "PTI"])
ris_demo.launch(share=True)
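# Usage sketch (the script's filename is not fixed by the code above; app.py is
# only an assumption):
#   python app.py --encoder hyperstyle   # default: HyperStyle inversion
#   python app.py --encoder e4e          # e4e inversion
# The PTI tab is registered regardless of --encoder; hair_transfer_PTI builds
# its own Generator_wrapper from pretrained/ffhq.pkl on each call.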