| import numpy as np |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import cv2 |
| import random |
| import math |
| import argparse |
| import torch |
| from torch.utils import data |
| from torch.nn import functional as F |
| from torch import autograd |
| from torch.nn import init |
| import torchvision.transforms as transforms |
| from model.stylegan.op import conv2d_gradfix |
| from model.encoder.encoders.psp_encoders import GradualStyleEncoder |
| from model.encoder.align_all_parallel import get_landmark |
| |
def visualize(img_arr, dpi):
    """Show a (C, H, W) image tensor in a matplotlib window.

    Values are mapped from [-1, 1] to [0, 255] via ``(x + 1) * 127.5`` and
    displayed on a 10x10-inch figure with the axes hidden.

    Args:
        img_arr: Tensor laid out channels-first; it is transposed to
            (H, W, C) for ``imshow``.
        dpi: Dots-per-inch of the matplotlib figure.
    """
    pixels = (img_arr.detach().cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5
    # Clip before the uint8 cast: NumPy's float->uint8 conversion of values
    # outside [0, 255] is undefined (typically wraps), so slightly
    # out-of-range inputs would otherwise render as speckle noise.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    plt.figure(figsize=(10, 10), dpi=dpi)
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()
|
|
def save_image(img, filename):
    """Write a (C, H, W) RGB tensor with values in [-1, 1] to ``filename``.

    The tensor is mapped to [0, 255], clipped, converted to uint8 and from RGB
    to BGR channel order before being handed to ``cv2.imwrite``.

    Returns:
        bool: the success flag reported by ``cv2.imwrite``. The original
        discarded it; returning it lets callers detect a failed write while
        callers that ignore the result are unaffected.
    """
    tmp = (img.detach().cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5
    # Clip first: casting out-of-range floats to uint8 is undefined in NumPy.
    tmp = np.clip(tmp, 0, 255).astype(np.uint8)
    return cv2.imwrite(filename, cv2.cvtColor(tmp, cv2.COLOR_RGB2BGR))
| |
def load_image(filename):
    """Load an image file as a (1, 3, H, W) float tensor normalized to [-1, 1].

    The image is converted to RGB first. Grayscale, palette and RGBA files
    would otherwise produce 1 or 4 channels, which the 3-channel ``Normalize``
    below rejects.

    Args:
        filename: Path of the image file to read.

    Returns:
        Tensor with a leading batch dimension of size 1.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # The context manager closes the underlying file; convert() returns a
    # loaded, independent copy that stays valid after the file is closed.
    with Image.open(filename) as img:
        img = img.convert('RGB')
    return transform(img).unsqueeze(dim=0)
|
|
def data_sampler(dataset, shuffle, distributed):
    """Choose a torch sampler for ``dataset``.

    Distributed runs get a ``DistributedSampler`` (which handles shuffling
    itself). Otherwise a ``RandomSampler`` is returned when ``shuffle`` is
    truthy, and a ``SequentialSampler`` when it is not.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
|
|
|
|
def requires_grad(model, flag=True):
    """Toggle gradient tracking for all of ``model``'s parameters.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        flag: New ``requires_grad`` value; ``False`` freezes the parameters.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
|
|
|
|
def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` in place (moving average).

    Every named parameter ``w1`` of ``model1`` becomes
    ``decay * w1 + (1 - decay) * w2`` where ``w2`` is the same-named parameter
    of ``model2``. Only parameters are updated; buffers are left untouched.
    A missing name in ``model2`` raises ``KeyError``.
    """
    source = dict(model2.named_parameters())
    keep = decay
    blend = 1 - decay
    for name, target in dict(model1.named_parameters()).items():
        # Operate on .data so the update is not recorded by autograd.
        target.data.mul_(keep).add_(source[name].data, alpha=blend)
|
|
|
|
def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it after each pass.

    Args:
        loader: Any re-iterable source of batches (e.g. a ``DataLoader``).

    Raises:
        ValueError: if a complete pass over ``loader`` yields nothing. This
            covers an empty loader, or a one-shot iterator that has been
            exhausted. Without the check the generator would loop forever
            without yielding, hanging the caller's ``next()``.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("sample_data: loader yielded no batches")
|
|
|
|
def d_logistic_loss(real_pred, fake_pred):
    """Logistic discriminator loss on raw logits.

    Returns ``mean(softplus(-real_pred)) + mean(softplus(fake_pred))``. This
    is the binary cross-entropy with target 1 for real and 0 for fake
    predictions, since ``softplus(-x) = -log(sigmoid(x))``.
    """
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()
|
|
|
|
def d_r1_loss(real_pred, real_img):
    """R1 penalty: batch mean of the squared L2 norm of dD/d(real_img).

    ``create_graph=True`` keeps the gradient differentiable so the returned
    penalty can itself be backpropagated.
    NOTE(review): ``conv2d_gradfix.no_weight_gradients()`` presumably stops the
    custom conv op from computing weight gradients during this call; confirm
    against that module.
    """
    with conv2d_gradfix.no_weight_gradients():
        (gradients,) = autograd.grad(
            outputs=real_pred.sum(),
            inputs=real_img,
            create_graph=True,
        )
    flat = gradients.pow(2).reshape(gradients.shape[0], -1)
    per_sample = flat.sum(1)
    return per_sample.mean()
|
|
|
|
def g_nonsaturating_loss(fake_pred):
    """Non-saturating generator loss: ``mean(softplus(-fake_pred))``."""
    return F.softplus(-fake_pred).mean()
|
|
|
|
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization term for the generator.

    Projects ``fake_img`` onto random noise scaled by ``1/sqrt(H*W)`` and
    differentiates the result w.r.t. ``latents``. The per-sample path length
    is ``sqrt(mean over dim 1 of sum over dim 2 of grad**2)``. Assumes
    ``latents`` is 3-D, e.g. (batch, n_latent, dim); TODO confirm with callers.

    Returns:
        (penalty, updated running mean (detached), per-sample path lengths).
        The running mean moves toward the batch mean by a factor ``decay``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)
    projection = (fake_img * noise).sum()
    (grad,) = autograd.grad(outputs=projection, inputs=latents, create_graph=True)
    path_lengths = grad.pow(2).sum(2).mean(1).sqrt()

    # Exponential moving average of the path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths
|
|
|
|
def make_noise(batch, latent_dim, n_noise, device):
    """Sample standard-normal latent codes on ``device``.

    Returns:
        A single (batch, latent_dim) tensor when ``n_noise == 1``. Otherwise a
        tuple of ``n_noise`` such tensors, sliced from one
        (n_noise, batch, latent_dim) draw.
    """
    if n_noise != 1:
        stacked = torch.randn(n_noise, batch, latent_dim, device=device)
        return stacked.unbind(0)
    return torch.randn(batch, latent_dim, device=device)
|
|
|
|
def mixing_noise(batch, latent_dim, prob, device):
    """Sample latents for optional style mixing.

    With probability ``prob`` returns two latent tensors (as produced by
    ``make_noise(..., 2, ...)``); otherwise a one-element list holding a
    single latent tensor. ``random.random()`` is only consumed when
    ``prob > 0``.
    """
    use_mixing = prob > 0 and random.random() < prob
    if not use_mixing:
        return [make_noise(batch, latent_dim, 1, device)]
    return make_noise(batch, latent_dim, 2, device)
|
|
|
|
def set_grad_none(model, targets):
    """Clear ``.grad`` (set it to ``None``) on the named parameters in ``targets``.

    Membership is tested with ``name in targets``; parameters not listed
    keep their gradients.
    """
    for name, param in model.named_parameters():
        if name not in targets:
            continue
        param.grad = None
|
|
| |
def weights_init(m):
    """Initialise a single module in place, dispatching on its class name.

    Takes one module, so it can be used with ``nn.Module.apply``.
    - ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.
    - class names containing ``Conv`` or ``Linear``: Kaiming-normal weight
      (``a=0``, fan-in mode), bias = 0.
    - anything else is left unchanged.
    """
    classname = m.__class__.__name__

    def zero_bias():
        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias.data, 0.0)

    if 'BatchNorm2d' in classname:
        # affine=False batch norms have weight/bias set to None.
        if hasattr(m, 'weight') and m.weight is not None:
            init.normal_(m.weight.data, 1.0, 0.02)
        zero_bias()
        return

    is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
    if hasattr(m, 'weight') and is_conv_or_linear:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        zero_bias()
| |
| |
def load_psp_standalone(checkpoint_path, device='cuda'):
    """Build a ``GradualStyleEncoder`` from a pSp checkpoint file.

    Only the ``encoder.*`` entries of ``ckpt['state_dict']`` are loaded, with
    that prefix removed. The module is put in eval mode and moved to
    ``device``. A forward hook then adds ``ckpt['latent_avg']``, repeated over
    the batch dimension, to every output.

    Args:
        checkpoint_path: Path to a checkpoint containing ``opts``,
            ``state_dict`` and ``latent_avg``.
        device: Device for the encoder and ``latent_avg``.

    Returns:
        The hooked encoder.
    """
    # NOTE(security): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    if 'output_size' not in opts:
        opts['output_size'] = 1024
    # math.log2 is exact for powers of two. math.log(x, 2) is a ratio of two
    # rounded logs and is not guaranteed to be exact before int() truncates.
    opts['n_styles'] = int(math.log2(opts['output_size'])) * 2 - 2
    opts = argparse.Namespace(**opts)

    psp = GradualStyleEncoder(50, 'ir_se', opts)
    prefix = 'encoder.'
    # Strip the prefix only. str.replace would also delete any later
    # "encoder." occurring inside a nested submodule name.
    psp_dict = {
        k[len(prefix):]: v
        for k, v in ckpt['state_dict'].items()
        if k.startswith(prefix)
    }
    psp.load_state_dict(psp_dict)
    psp.eval()
    psp = psp.to(device)
    latent_avg = ckpt['latent_avg'].to(device)

    def add_latent_avg(model, inputs, outputs):
        # Returning a value from a forward hook replaces the module output.
        return outputs + latent_avg.repeat(outputs.shape[0], 1, 1)

    psp.register_forward_hook(add_latent_avg)
    return psp
|
|
def get_video_crop_parameter(filepath, predictor, padding=(200, 200, 200, 200)):
    """Compute a crop window around the eyes after rescaling the frame.

    The frame is rescaled so the horizontal distance between the mean x of
    landmarks 42-47 and of landmarks 36-41 is 64 px. A window extending
    ``padding`` pixels around the midpoint of the two eye centres (in rescaled
    coordinates) is clamped to the rescaled frame and floored to multiples
    of 8.

    Args:
        filepath: Path to an image file, or an already-loaded image array that
            is passed to ``get_landmark`` unchanged.
        predictor: Landmark predictor forwarded to ``get_landmark``.
        padding: (left, right, top, bottom) extents in rescaled pixels.

    Returns:
        ``(h, w, top, bottom, left, right, scale)``, where ``h, w`` is the
        rescaled frame size. Returns ``None`` if no landmarks were found or
        the eye landmarks are degenerate (non-positive horizontal distance).
    """
    if isinstance(filepath, str):
        # Load as an RGB uint8 array. The previous code called
        # dlib.load_rgb_image, but dlib is not imported in this module, so
        # this branch raised NameError; PIL is already imported here.
        with Image.open(filepath) as pil_img:
            img = np.array(pil_img.convert('RGB'))
    else:
        img = filepath
    lm = get_landmark(img, predictor)
    if lm is None:
        return None
    # Only the two eye groups of the 68-point layout are needed.
    lm_eye_left = lm[36:42]
    lm_eye_right = lm[42:48]

    eye_dist = np.mean(lm_eye_right[:, 0]) - np.mean(lm_eye_left[:, 0])
    # `not > 0` also rejects NaN. Zero or negative distances would give an
    # infinite or negative scale, and round() would fail or yield negative sizes.
    if not eye_dist > 0:
        return None
    scale = 64. / eye_dist
    center = ((np.mean(lm_eye_right, axis=0) + np.mean(lm_eye_left, axis=0)) / 2) * scale
    h, w = round(img.shape[0] * scale), round(img.shape[1] * scale)
    left = max(round(center[0] - padding[0]), 0) // 8 * 8
    right = min(round(center[0] + padding[1]), w) // 8 * 8
    top = max(round(center[1] - padding[2]), 0) // 8 * 8
    bottom = min(round(center[1] + padding[3]), h) // 8 * 8
    return h, w, top, bottom, left, right, scale
|
|
def tensor2cv2(img):
    """Convert a (C, H, W) RGB tensor in [-1, 1] to an (H, W, C) BGR uint8 array.

    The result is in the channel order expected by OpenCV.
    """
    # detach() so tensors that require grad can be converted (Tensor.numpy()
    # refuses them); this matches save_image.
    tmp = (img.detach().cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5
    # Clip first: casting out-of-range floats to uint8 is undefined in NumPy.
    tmp = np.clip(tmp, 0, 255).astype(np.uint8)
    return cv2.cvtColor(tmp, cv2.COLOR_RGB2BGR)
|
|
| |
def gather_params(G):
    """Group ``G``'s buffers and parameters by synthesis-layer index.

    Names are visited in sorted order and bucketed by prefix:
    ``conv1`` -> 0, ``to_rgb1`` -> 1, ``convs.i`` -> i + 1,
    ``to_rgbs.i`` -> 2 * i + 3. Everything else goes under ``"others"``.
    Indices that map past 17 raise ``KeyError``.

    Returns:
        dict with keys 0..17 and ``"others"``, each mapping name -> tensor.
    """
    groups = {layer: {} for layer in range(18)}
    groups["others"] = {}

    for name, tensor in sorted([*G.named_buffers(), *G.named_parameters()]):
        index = name.split(".")[1] if "." in name else None
        if name.startswith("convs"):
            bucket = groups[int(index) + 1]
        elif name.startswith("to_rgbs"):
            bucket = groups[int(index) * 2 + 3]
        elif name.startswith("conv1"):
            bucket = groups[0]
        elif name.startswith("to_rgb1"):
            bucket = groups[1]
        else:
            bucket = groups["others"]
        bucket[name] = tensor
    return groups
|
|
| |
| |
def blend_models(G_low, G_high, weight=(1,) * 7 + (0,) * 11):
    """Blend two models layer by layer into one flat name -> tensor dict.

    Tensors are grouped with ``gather_params``. For each layer index ``res``
    in 0..17, every tensor becomes
    ``G_high * (1 - weight[res]) + G_low * weight[res]``. Tensors outside
    those layers (``"others"``) are taken from ``G_high`` unchanged.

    Args:
        G_low, G_high: Models whose layer tensors share names.
        weight: Sequence of at least 18 blend factors for ``G_low``. The
            default (a tuple, to avoid a shared mutable default) takes layers
            0-6 from ``G_low`` and 7-17 from ``G_high``.

    Returns:
        dict mapping parameter/buffer names to tensors.
    """
    params_low = gather_params(G_low)
    params_high = gather_params(G_high)

    # No autograd graph is needed for a weight blend. Without no_grad the
    # results would keep graphs referencing both models' parameters alive.
    with torch.no_grad():
        for res in range(18):
            w = weight[res]
            layer_high = params_high[res]
            layer_low = params_low[res]
            # Reassigning existing keys while iterating is safe because the
            # dict size does not change.
            for n in layer_high:
                layer_high[n] = layer_high[n] * (1 - w) + layer_low[n] * w

    state_dict = {}
    for group in params_high.values():
        state_dict.update(group)
    return state_dict
|
|
|
|