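# Gradio demo for RIS (Retrieve in Style) hair transfer. Three inversion
# backends are exposed as tabs: HyperStyle, e4e, and a work-in-progress PTI
# path. Each pipeline aligns and inverts a source and a reference face,
# blends the hair-related style channels, and re-synthesizes the result.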
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm import tqdm
import pickle
import warnings
warnings.filterwarnings("ignore")
from spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
import glob
from pathlib import Path
import lpips
import argparse
import gc
import cv2
import dlib  # used directly below for the face-landmark predictor
from PIL import Image  # Image is referenced directly; avoid relying on the star imports
from models.pti.e4e_projection import projection
from model import *
from util import *
from e4e.models.psp import pSp
import torchvision.transforms as transforms
from torch.nn import DataParallel
import torchvision.transforms.functional as TF
from FaceQualityMetrics.utils import FaceMetric
from hyperstyle.utils.model_utils import load_model
from configs.paths_config import model_paths
from models.pti.manipulator import Manipulator
from models.pti.wrapper import Generator_wrapper
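# Iterative HyperStyle-style refinement: each pass through `net` updates the
# predicted per-layer weight deltas and latent codes. If `weights_deltas` is
# supplied up front, the deltas are held fixed and only the codes are refined.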
def run_inversion(inputs, net, n_iters_per_batch, return_intermediate_results=False, resize_outputs=False, weights_deltas=None):
y_hat, latent, weights_deltas, codes = None, None, weights_deltas, None
if return_intermediate_results:
results_batch = {idx: [] for idx in range(inputs.shape[0])}
results_latent = {idx: [] for idx in range(inputs.shape[0])}
results_deltas = {idx: [] for idx in range(inputs.shape[0])}
else:
results_batch, results_latent, results_deltas = None, None, None
if weights_deltas is None:
        for _ in range(n_iters_per_batch):
y_hat, latent, weights_deltas, codes, _ = net.forward(inputs,
y_hat=y_hat,
codes=codes,
weights_deltas=weights_deltas,
return_latents=True,
resize=resize_outputs,
randomize_noise=False,
return_weight_deltas_and_codes=True)
# weights_deltas[14]= None
# weights_deltas[20]= None
# weights_deltas[21]= None
# weights_deltas[23]= None
# weights_deltas[24]= None
if return_intermediate_results:
store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
# resize input to 256 before feeding into next iteration
y_hat = net.face_pool(y_hat)
else:
        for _ in range(n_iters_per_batch):
y_hat, latent, _, codes, _ = net.forward(inputs,
y_hat=y_hat,
codes=codes,
weights_deltas=weights_deltas,
return_latents=True,
resize=resize_outputs,
randomize_noise=False,
return_weight_deltas_and_codes=True)
if return_intermediate_results:
store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas)
# resize input to 256 before feeding into next iteration
y_hat = net.face_pool(y_hat)
if return_intermediate_results:
return results_batch, results_latent, results_deltas
return y_hat, latent, weights_deltas, codes
def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas):
for idx in range(y_hat.shape[0]):
results_batch[idx].append(y_hat[idx])
results_latent[idx].append(latent[idx].cpu().numpy())
results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None for w in weights_deltas])
# Compute M for a style code w: a [batch, n_class, channels] matrix that softly
# assigns each style channel to a semantic class, thresholded to a binary mask.
@torch.no_grad()
def compute_M(w, weights_deltas=None, device='cuda'):
M = []
# get segmentation
# _, outputs = generator(w, is_cluster=1)
_, outputs = generator(w, weights_deltas=weights_deltas)
cluster_layer = outputs[stop_idx][0]
activation = flatten_act(cluster_layer)
seg_mask = clusterer.predict(activation)
    b, c, h, w_spatial = cluster_layer.size()  # w_spatial: avoid clobbering the style code w
    # create masks for each feature
    all_seg_mask = []
    seg_mask = torch.from_numpy(seg_mask).view(b, 1, h, w_spatial, 1).to(device)
for key in range(n_class):
# combine masks for all indices for a particular segmentation class
indices = labels_map[key].view(1,1,1,1,-1)
key_mask = (seg_mask == indices.to(device)).any(-1) #[b,1,h,w]
all_seg_mask.append(key_mask)
all_seg_mask = torch.stack(all_seg_mask, 1)
# go through each activation layer and compute M
for layer_idx in range(len(outputs)):
layer = outputs[layer_idx][1].to(device)
b,c,h,w = layer.size()
layer = F.instance_norm(layer)
layer = layer.pow(2)
# resize the segmentation masks to current activations' resolution
layer_seg_mask = F.interpolate(all_seg_mask.flatten(0,1).float(), align_corners=False,
size=(h,w), mode='bilinear').view(b,-1,1,h,w)
masked_layer = layer.unsqueeze(1) * layer_seg_mask # [b,k,c,h,w]
        masked_layer = masked_layer.sum([3, 4]) / (h * w)  # [b,k,c]
M.append(masked_layer.to(device))
M = torch.cat(M, -1) #[b, k, c]
    # softmax to assign each channel to a particular segmentation class
    M = F.softmax(M / 0.1, dim=1)
    # simple thresholding
    M = (M > 0.8).float()
# zero out torgb transfers, from https://arxiv.org/abs/2011.12799
for i in range(n_class):
part_M = style2list(M[:, i])
for j in range(len(part_M)):
if j in rgb_layer_idx:
part_M[j].zero_()
part_M = list2style(part_M)
M[:, i] = part_M
return M
# ==== WIP scratch: blend weight deltas per layer based on mask coverage (not wired in) ====
# for i in range(len(blend_deltas)):
# if blend_deltas[i] is not None:
# print(f'{i}: {part_M_mask[i].sum()}/{sum(part_M_mask[i].shape)}')
# if part_M_mask[i].sum() >= sum(part_M_mask[i].shape)/2:
# print(i)
# blend_deltas[i] = ref_deltas[i]
def tensor2img(tensor):
tensor = tensor.cpu().clamp(-1, 1)
img = topil(tensor.squeeze())
return img
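# HyperStyle pipeline: align both faces, invert them into latents plus weight
# deltas, compute the channel-to-class matrix M for each, blend the reference
# hair channels into the source, then paste the original face region back via
# seamless cloning. Returns source/ref/blend images, a metrics string, and the
# cloned composite.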
def hair_transfer_hyperstyle(source_img_path, ref_img_path):
with torch.no_grad():
source_img = align_face(source_img_path, predictor=predictor)
ref_img = align_face(ref_img_path, predictor=predictor)
source_img = Image.fromarray(np.uint8(source_img))
ref_img = Image.fromarray(np.uint8(ref_img))
source_tensor = transform(source_img).unsqueeze(0).to(device)
ref_tensor = transform(ref_img).unsqueeze(0).to(device)
source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
source_out, _ = generator(source, weights_deltas=source_deltas, randomize_noise=False)
ref_out, _ = generator(ref, weights_deltas=ref_deltas, randomize_noise=False)
source_M = compute_M(source, weights_deltas=source_deltas, device='cpu')
ref_M = compute_M(ref, weights_deltas=ref_deltas, device='cpu')
blend_deltas = source_deltas
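        # Take the elementwise max of the two channel masks (add_pose folds in
        # pose-related channels) and keep only the channels assigned to 'hair'.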
max_M = torch.max(source_M.expand_as(ref_M), ref_M)
max_M = add_pose(max_M, labels2idx)
idx = labels2idx['hair']
part_M = max_M[:, idx].to(device)
part_M_mask = style2list(part_M)
blend = style2list((add_direction(source, ref, part_M, 1.3)))
blend_out, _ = generator(blend, weights_deltas=blend_deltas)
source_out = tensor2img(source_out)
ref_out = tensor2img(ref_out)
blend_out = tensor2img(blend_out)
lpips_face, _ = lpips(blend_out, source_out)
ssim_face, _ = ssim(blend_out, source_out)
id_face, _ = id_score(blend_out, source_out)
        _, lpips_hair = lpips(blend_out, ref_out)
        _, ssim_hair = ssim(blend_out, ref_out)
        _, clip_hair = clip(blend_out, source_out)  # note: compares against the source, unlike the other hair metrics
        out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\nclip_hair: {clip_hair}'
e4e_blend_out, _ = generator(blend)
e4e_blend_out = tensor2img(e4e_blend_out)
_, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
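        # Paste the source face back over the blend: dilate and blur the hair
        # mask, invert it to get a face mask, and seamless-clone the source
        # pixels into the blended image centered on that region.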
source_out_np = np.array(source_out)
        blend_np = np.array(e4e_blend_out).astype(np.uint8)
e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8)*255
mask_dilate = cv2.dilate(e4e_blend_hair_mask,
kernel=np.ones((50, 50), np.uint8))
mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
face_mask = 255 - mask_dilate_blur
index = np.where(face_mask > 0)
cy = (np.min(index[0]) + np.max(index[0])) // 2
cx = (np.min(index[1]) + np.max(index[1])) // 2
center = (cx, cy)
clone_out = cv2.seamlessClone(source_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
return source_out, ref_out, blend_out, out_str, clone_out
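# e4e pipeline: same channel blending as above, but the generator is run
# without HyperStyle weight deltas (plain latent-space inversion).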
def hair_transfer_e4e(source_img_path, ref_img_path):
with torch.no_grad():
source_img = align_face(source_img_path, predictor=predictor)
ref_img = align_face(ref_img_path, predictor=predictor)
source_img = Image.fromarray(np.uint8(source_img))
ref_img = Image.fromarray(np.uint8(ref_img))
source_tensor = transform(source_img).unsqueeze(0).to(device)
ref_tensor = transform(ref_img).unsqueeze(0).to(device)
source_batch, source_latent, source_deltas, source_codes = run_inversion(source_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
ref_batch, ref_latent, ref_deltas, ref_codes = run_inversion(ref_tensor, net, n_iters_per_batch=5, return_intermediate_results=False)
source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
e4e_source_out, _ = generator(source, randomize_noise=False)
e4e_ref_out, _ = generator(ref, randomize_noise=False)
e4e_source_M = compute_M(source, device='cpu')
e4e_ref_M = compute_M(ref, device='cpu')
e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
e4e_max_M = add_pose(e4e_max_M, labels2idx)
e4e_idx = labels2idx['hair']
e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
e4e_part_M_mask = style2list(e4e_part_M)
e4e_blend = style2list((add_direction(source, ref, e4e_part_M, 1.3)))
e4e_blend_out, _ = generator(e4e_blend)
e4e_source_out = tensor2img(e4e_source_out)
e4e_ref_out = tensor2img(e4e_ref_out)
e4e_blend_out = tensor2img(e4e_blend_out)
e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
        _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
        _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
        _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)  # note: compares against the source, unlike the other hair metrics
        e4e_out_str = f'e4e_lpips_face: {e4e_lpips_face}\ne4e_lpips_hair: {e4e_lpips_hair}\ne4e_ssim_face: {e4e_ssim_face}\ne4e_ssim_hair: {e4e_ssim_hair}\ne4e_id_face: {e4e_id_face}\ne4e_clip_hair: {e4e_clip_hair}'
return e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str
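# PTI pipeline (work in progress): run PTI pivotal tuning on the source image,
# project both images with e4e, blend the hair channels, and synthesize the
# blend through the PTI Manipulator. Metric computation is still commented out.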
def hair_transfer_PTI(source_img_path, ref_img_path):
    # PTI is image-specific: rebuild the wrapped generator and run pivotal
    # tuning against this particular source image on every call.
    ckpt = 'pretrained/ffhq.pkl'
    G = Generator_wrapper(ckpt, device)
    manipulator = Manipulator(G, device)
    manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
with torch.no_grad():
source_img = align_face(source_img_path, predictor=predictor)
ref_img = align_face(ref_img_path, predictor=predictor)
        source_img = Image.fromarray(np.uint8(source_img))
        ref_img = Image.fromarray(np.uint8(ref_img))
        projection(source_img, 'source', generator, device)
        projection(ref_img, 'ref', generator, device)
source = load_source('source', generator, device)
ref = load_source('ref', generator, device)
e4e_source_out, _ = generator(source, randomize_noise=False)
e4e_ref_out, _ = generator(ref, randomize_noise=False)
e4e_source_M = compute_M(source, device='cpu')
e4e_ref_M = compute_M(ref, device='cpu')
e4e_max_M = torch.max(e4e_source_M.expand_as(e4e_ref_M), e4e_ref_M)
e4e_max_M = add_pose(e4e_max_M, labels2idx)
e4e_idx = labels2idx['hair']
e4e_part_M = e4e_max_M[:, e4e_idx].to(device)
e4e_part_M_mask = style2list(e4e_part_M)
e4e_blend = style2list((add_direction(source, ref, e4e_part_M, 1.3)))
e4e_source_out = tensor2img(e4e_source_out)
e4e_ref_out = tensor2img(e4e_ref_out)
# e4e_blend_out = tensor2img(e4e_blend_out)
# e4e_lpips_face, _ = lpips(e4e_blend_out, e4e_source_out)
# e4e_ssim_face, _ = ssim(e4e_blend_out, e4e_source_out)
# e4e_id_face, _ = id_score(e4e_blend_out, e4e_source_out)
# _, e4e_lpips_hair = lpips(e4e_blend_out, e4e_ref_out)
# _, e4e_ssim_hair = ssim(e4e_blend_out, e4e_ref_out)
# _, e4e_clip_hair = clip(e4e_blend_out, e4e_source_out)
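        # Map the blended style list onto the Manipulator's StyleSpace layer
        # names, in generator order: b4 has a single conv, and every later
        # block contributes conv0, conv1, and a torgb layer.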
        keys = []
        for res in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
            if res > 4:
                keys.append(f'G.synthesis.b{res}.conv0.affine')
            keys.append(f'G.synthesis.b{res}.conv1.affine')
            keys.append(f'G.synthesis.b{res}.torgb.affine')
        test_dict = dict(zip(keys, e4e_blend))
        manipulator_list = [test_dict]
all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
        PTI_outstr = 'PTI metrics not implemented yet'  # placeholder; metric code above is still commented out
blend_out = tensor2img(all_imgs[0])
return e4e_source_out, e4e_ref_out, blend_out, PTI_outstr
# _, _, e4e_blend_hair_mask = lpips.parser(e4e_blend_out)
# blend_out_np = np.array(blend_out)
# blend_np =np.array(e4e_blend_out).astype(np.uint8)
# e4e_blend_hair_mask = e4e_blend_hair_mask.cpu().numpy().astype(np.uint8)*255
# mask_dilate = cv2.dilate(e4e_blend_hair_mask,
# kernel=np.ones((50, 50), np.uint8))
# mask_dilate_blur = cv2.blur(mask_dilate, ksize=(30, 30))
# mask_dilate_blur = (e4e_blend_hair_mask + (255 - e4e_blend_hair_mask) / 255 * mask_dilate_blur).astype(np.uint8)
# face_mask = 255 - mask_dilate_blur
# index = np.where(face_mask > 0)
# cy = (np.min(index[0]) + np.max(index[0])) // 2
# cx = (np.min(index[1]) + np.max(index[1])) // 2
# center = (cx, cy)
# clone_out = cv2.seamlessClone(blend_out_np, blend_np, face_mask, center, cv2.NORMAL_CLONE)
# out_str = f'lpips_face: {lpips_face}\nlpips_hair: {lpips_hair}\nssim_face: {ssim_face}\nssim_hair: {ssim_hair}\nid_face: {id_face}\n clip_hair: {clip_hair}'
# seg_out = torch.tensor(face_mask).float().unsqueeze(-1).repeat(1,1,3)
# seg_out = seg_out.cpu().numpy().astype(np.uint8)
# # seg_out*=255
# seg_out = Image.fromarray(seg_out)
# # return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, seg_out
# ## Set source_tensor requires_grad=True
# source_tensor.requires_grad = True
# ref_tensor.requires_grad = True
# ckpt = 'pretrained/ffhq.pkl'
# G = Generator_wrapper(ckpt, device)
# manipulator = Manipulator(G, device)
# manipulator.set_real_img_projection(source_img_path, inv_mode='w+', pti_mode='s')
# blend = style2list((add_direction(source_tensor, ref_tensor, part_M, 1.3)))
# keys = (['G.synthesis.b4.conv1.affine', 'G.synthesis.b4.torgb.affine', 'G.synthesis.b8.conv0.affine', 'G.synthesis.b8.conv1.affine', 'G.synthesis.b8.torgb.affine', 'G.synthesis.b16.conv0.affine', 'G.synthesis.b16.conv1.affine', 'G.synthesis.b16.torgb.affine', 'G.synthesis.b32.conv0.affine', 'G.synthesis.b32.conv1.affine', 'G.synthesis.b32.torgb.affine', 'G.synthesis.b64.conv0.affine', 'G.synthesis.b64.conv1.affine', 'G.synthesis.b64.torgb.affine', 'G.synthesis.b128.conv0.affine', 'G.synthesis.b128.conv1.affine', 'G.synthesis.b128.torgb.affine', 'G.synthesis.b256.conv0.affine', 'G.synthesis.b256.conv1.affine', 'G.synthesis.b256.torgb.affine', 'G.synthesis.b512.conv0.affine', 'G.synthesis.b512.conv1.affine', 'G.synthesis.b512.torgb.affine', 'G.synthesis.b1024.conv0.affine', 'G.synthesis.b1024.conv1.affine', 'G.synthesis.b1024.torgb.affine'])
# test_dict = dict(zip(keys, blend))
# manipulator_list = []
# manipulator_list.append(test_dict)
# all_imgs = manipulator.synthesis_from_styles(manipulator_list, 0, 1)
# return source_out, ref_out, blend_out, out_str, e4e_source_out, e4e_ref_out, e4e_blend_out, e4e_out_str, clone_out, all_imgs
# Command-line argument to choose the inversion encoder (e4e or hyperstyle).
# Usage sketch (the script name is whatever this file is saved as):
#   python app.py --encoder hyperstyle
#   python app.py --encoder e4e
parser = argparse.ArgumentParser()
parser.add_argument('--encoder', type=str, default='hyperstyle')
opt = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
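# Note: these FaceMetric wrappers shadow the `lpips` module imported above;
# from here on, `lpips` refers to the metric (and its `parser` segmenter).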
lpips = FaceMetric(metric_type='lpips', device=device)
ssim = FaceMetric(metric_type='ms-ssim', device=device)
id_score = FaceMetric(metric_type='id', device=device)
clip = FaceMetric(metric_type='cliphair', device=device)
generator = Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt['g_ema'], strict=False)
# Both align_face variants need a dlib 68-point landmark predictor, so create
# it before the branch (previously it was only defined on the hyperstyle path,
# leaving `predictor` undefined for e4e).
predictor = dlib.shape_predictor('pretrained_models/shape_predictor_68_face_landmarks.dat')
if opt.encoder == 'e4e':
    from util import align_face
    model_path = 'e4e_ffhq_encode.pt'
    ensure_checkpoint_exists(model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    net = pSp(opts, device).eval().to(device)
elif opt.encoder == 'hyperstyle':
    from hyperstyle.scripts.align_faces_parallel import align_face
    model_path = 'pretrained_models/hyperstyle_ffhq.pt'
    net, _ = load_model(model_path)
else:
    raise ValueError('invalid encoder')
truncation = 0.5
stop_idx = 11  # generator activation index whose features are clustered for segmentation
n_clusters = 18  # clusters in the pretrained catalog
clusterer = pickle.load(open('catalog.pkl', 'rb'))  # pre-fitted spherical k-means (the sKmeans import above is needed to unpickle it)
labels2idx = {
'nose': 0,
'eyes': 1,
'mouth': 2,
'hair': 3,
'background': 4,
'cheek': 5,
'neck': 6,
'clothes': 7,
}
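# Map each semantic class above to the spherical k-means cluster ids
# (0..n_clusters-1) that compose it in the pretrained catalog.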
labels_map = {
0: torch.tensor([7]),
1: torch.tensor([1,6]),
2: torch.tensor([4]),
3: torch.tensor([0,3,5,8,10,15,16]),
4: torch.tensor([11,13,14]),
5: torch.tensor([9]),
6: torch.tensor([17]),
7: torch.tensor([2,12]),
}
idx2labels = dict((v, k) for k, v in labels2idx.items())
n_class = len(idx2labels)
segid_map = {seg_id: cls for cls, ids in labels_map.items() for seg_id in ids.tolist()}
torch.manual_seed(0)
transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
topil = transforms.Compose(
[
transforms.Normalize([-1, -1, -1], [2, 2, 2]),
transforms.ToPILImage(),
transforms.Resize(1024),
]
)
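# One demo tab per inversion backend. Each takes a source and a reference
# image path and returns (source, reference, blend, metrics text); the
# HyperStyle tab additionally returns the seamless-clone composite.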
e4e_ris_demo = gr.Interface(hair_transfer_e4e, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=['image', 'image', 'image', 'text'])
hyperstyle_ris_demo = gr.Interface(hair_transfer_hyperstyle, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=['image', 'image', 'image', 'text', 'image'])
PTI_ris_demo = gr.Interface(hair_transfer_PTI, inputs=[gr.Image(type='filepath'), gr.Image(type='filepath')], outputs=['image', 'image', 'image', 'text'])
ris_demo = gr.TabbedInterface(interface_list=[hyperstyle_ris_demo, e4e_ris_demo, PTI_ris_demo], tab_names=['hyperstyle', 'e4e', 'PTI'])
ris_demo.launch(share=True)