Spaces:
Running
Running
File size: 6,621 Bytes
523fb10 fbf97aa 523fb10 127df95 523fb10 b9be4e6 523fb10 4f18a93 82206b3 4f18a93 82206b3 4f18a93 523fb10 4f18a93 523fb10 f057d66 b9be4e6 f057d66 b9be4e6 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 b9be4e6 523fb10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from third_party.bisenet.bisenet import BiSeNet
from third_party.GPEN.infer_image import GPENImageInfer
def make_abs_path(fn):
    """Resolve ``fn`` relative to the directory that contains this file.

    Args:
        fn: Path relative to this module's directory (may contain ``..``).

    Returns:
        The absolute, normalized path as a string.
    """
    # realpath() follows symlinks, so the result is anchored at the real
    # location of this module rather than at a symlinked alias.
    module_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.abspath(os.path.join(module_dir, fn))
class Trick(object):
    """Post-processing helpers applied to face-swapping results.

    Groups face-parsing mask extraction, mask refinement, tensor/array
    conversion, optional GPEN enhancement and mouth-region blending. The GPEN
    model and the mouth helper generator are created lazily on first use.
    """

    def __init__(self):
        # Heavy models are instantiated on demand by ``gpen`` / ``finetune_mouth``.
        self.gpen_model = None
        self.mouth_helper = None

    @staticmethod
    def get_any_mask(img, par=None, normalized=False):
        """Build a mask of the pixels whose parsed class is listed in ``par``.

        Args:
            img: Image batch indexed as (B, C, H, W); it is resized to 512
                before being fed to the module-level BiSeNet.
            par: Iterable of class indices (see the table below). ``None`` or
                an empty iterable yields an all-zero mask.
            normalized: When False, the input is rescaled with ``x * 0.5 + 0.5``
                and then normalized with ``vgg_mean`` / ``vgg_std``.

        Returns:
            Float tensor of shape (B, 1, H, W) at the input resolution.
        """
        # [0, 'background', 1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
        # 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip',
        # 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
        ori_h, ori_w = img.shape[2], img.shape[3]
        if par is None:
            par = ()
        with torch.no_grad():
            img = F.interpolate(img, size=512, mode="nearest")
            if not normalized:
                # NOTE(review): mean/std normalization is assumed to belong to
                # this branch together with the rescale - confirm against upstream.
                img = img * 0.5 + 0.5
                img = img.sub(vgg_mean.detach()).div(vgg_std.detach())
            out = global_bisenet(img)[0]
            parsing = out.softmax(1).argmax(1)
        # Start from a float mask so an empty ``par`` still reaches the bilinear
        # resize below (it does not accept integer tensors).
        mask = torch.zeros_like(parsing, dtype=torch.float)
        for p in par:
            mask = mask + (parsing == p).float()
        mask = mask.unsqueeze(1)
        mask = F.interpolate(mask, size=(ori_h, ori_w), mode="bilinear", align_corners=True)
        return mask

    @staticmethod
    def finetune_mask(facial_mask: np.ndarray, lmk_98: np.ndarray = None):
        """Dilate, crop and blur a 256-wide facial mask.

        Args:
            facial_mask: Array whose second dimension is 256; it is scaled by
                255 and cast to uint8, so values are presumably in [0, 1].
            lmk_98: Currently unused (the landmark-based ``h_min`` is disabled).

        Returns:
            float32 array: the processed mask divided by 255.
        """
        assert facial_mask.shape[1] == 256
        facial_mask = (facial_mask * 255).astype(np.uint8)
        # h_min = lmk_98[33:41, 0].min() + 20
        h_min = 80
        # NOTE(review): cv2.dilate expects a structuring element as its second
        # argument, not a size; the (40, 40) tuple is presumably interpreted as
        # a tiny 2-element kernel rather than a 40x40 one. Left unchanged
        # because downstream results were produced with this behavior.
        facial_mask = cv2.dilate(facial_mask, (40, 40), iterations=1)
        facial_mask[:h_min] = 0  # black out the top rows
        facial_mask[255 - 20:] = 0  # black out rows from index 235 onwards
        kernel_size = (20, 20)
        blur_size = tuple(2 * j + 1 for j in kernel_size)  # odd size: (41, 41)
        facial_mask = cv2.GaussianBlur(facial_mask, blur_size, 0)
        return facial_mask.astype(np.float32) / 255

    @staticmethod
    def smooth_mask(mask_tensor: torch.Tensor):
        """Run the module-level soft-erosion smoother and return only the mask."""
        mask_tensor, _ = global_smooth_mask(mask_tensor)
        return mask_tensor

    @staticmethod
    def tensor_to_arr(tensor):
        """Convert a (B, C, H, W) tensor to a (B, H, W, C) uint8 array.

        Values are mapped with ``(x + 1) * 127.5``, i.e. [-1, 1] -> [0, 255].
        """
        return ((tensor + 1.) * 127.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    @staticmethod
    def arr_to_tensor(arr, norm: bool = True):
        """Convert a (B, H, W, C) array to a (B, C, H, W) float tensor.

        Args:
            arr: Array-like with values on a 0..255 scale.
            norm: When True the result is mapped to [-1, 1]; otherwise it stays
                in [0, 1].

        Returns:
            Float tensor on ``global_device``.
        """
        tensor = torch.tensor(arr, dtype=torch.float).to(global_device) / 255  # in [0,1]
        tensor = (tensor - 0.5) / 0.5 if norm else tensor  # in [-1,1]
        tensor = tensor.permute(0, 3, 1, 2)
        return tensor

    def gpen(self, img_np: np.ndarray, use_gpen=True):
        """Enhance ``img_np`` with GPEN, loading the model on first use.

        Returns ``img_np`` unchanged when ``use_gpen`` is False.
        """
        if not use_gpen:
            return img_np
        if self.gpen_model is None:
            self.gpen_model = GPENImageInfer(device=global_device)
        img_np = self.gpen_model.image_infer(img_np)
        return img_np

    def finetune_mouth(self, i_s, i_t, i_r):
        """Blend the mouth region of the helper generator's output into ``i_r``.

        Args:
            i_s, i_t: Inputs passed to the mouth helper generator (presumably
                the source and target images).
            i_r: Image batch whose mouth region (classes 11, 12, 13) is
                replaced; presumably the swapped result.

        Returns:
            ``helper_face * mask + i_r * (1 - mask)``.
        """
        if self.mouth_helper is None:
            self.load_mouth_helper()
        helper_face = self.mouth_helper(i_s, i_t)[0]
        i_r_mouth_mask = self.get_any_mask(i_r, par=[11, 12, 13])  # (B,1,H,W)

        # Dilate and blur with cv2.
        # NOTE(review): tensor_to_arr applies ``(x + 1) * 127.5``; for a mask in
        # [0, 1] this yields values in [127, 255], so after
        # ``arr_to_tensor(norm=False)`` the blend weight outside the mouth
        # appears to be ~0.5 rather than 0 - confirm whether this is intended.
        i_r_mouth_mask = self.tensor_to_arr(i_r_mouth_mask)[0]  # (H,W,C)
        # NOTE(review): same tuple-as-kernel usage as in ``finetune_mask``.
        i_r_mouth_mask = cv2.dilate(i_r_mouth_mask, (20, 20), iterations=1)
        kernel_size = (5, 5)
        blur_size = tuple(2 * j + 1 for j in kernel_size)  # odd size: (11, 11)
        i_r_mouth_mask = cv2.GaussianBlur(i_r_mouth_mask, blur_size, 0)  # (H,W,C)
        i_r_mouth_mask = i_r_mouth_mask.squeeze()[None, :, :, None]  # (1,H,W,1)
        i_r_mouth_mask = self.arr_to_tensor(i_r_mouth_mask, norm=False)  # in [0,1]

        return helper_face * i_r_mouth_mask + i_r * (1 - i_r_mouth_mask)

    def load_mouth_helper(self):
        """Instantiate the mouth helper generator and load its weights."""
        # Imported lazily so the module can be used without the generator code
        # unless mouth fine-tuning is actually requested.
        from modules.networks.faceshifter import FSGenerator

        # Legacy loading path, kept for reference:
        # mouth_helper_pl = EvaluatorFaceShifter(
        #     load_path="/apdcephfs/share_1290939/gavinyuan/out/triplet10w_34/epoch=13-step=737999.ckpt",
        #     pt_path=make_abs_path("../ffplus/extracted_ckpt/G_t34_helper_post.pth"),
        #     benchmark=None,
        #     demo_folder=None,
        # )
        pt_path = make_abs_path("../weights/extracted/G_t34_helper_post.pth")
        self.mouth_helper = FSGenerator(
            make_abs_path("../weights/arcface/ms1mv3_arcface_r100_fp16/backbone.pth"),
            mouth_net_param={"use": False},
            in_size=256,
            downup=False,
        )
        self.mouth_helper.load_state_dict(torch.load(pt_path, map_location="cpu"), strict=True)
        self.mouth_helper.eval()
        # Keep the helper on the same device as the other module-level models;
        # its output is blended with tensors created on ``global_device``.
        self.mouth_helper = self.mouth_helper.to(global_device)
        print("[Mouth helper] loaded.")
""" From MegaFS: https://github.com/zyainfal/One-Shot-Face-Swapping-on-Megapixels/tree/main/inference """
class SoftErosion(nn.Module):
    """Soften and erode a mask by repeated smoothing with a radial kernel.

    Args:
        kernel_size: Side length of the square smoothing kernel.
        threshold: Smoothed values at or above this are set to 1.
        iterations: Total number of smoothing passes; every pass but the last
            takes the element-wise minimum with the previous result.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super(SoftErosion, self).__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Radial kernel: weight decreases linearly with the distance from the
        # centre and the whole kernel sums to 1. Built with broadcasting, which
        # gives the same values as a meshgrid on every torch version.
        offsets = torch.arange(0., kernel_size) - r
        dist = torch.sqrt(offsets.view(1, -1) ** 2 + offsets.view(-1, 1) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer('weight', kernel)

    def forward(self, x):
        """Smooth ``x`` and return ``(soft_mask, hard_mask)``.

        Args:
            x: Tensor of shape (B, C, H, W); each channel is filtered
                independently with the same kernel.

        Returns:
            Tuple of the float soft mask and the boolean mask of elements that
            reached ``threshold``. Sub-threshold elements are rescaled by
            their maximum; this is skipped when there are none or when that
            maximum is 0, which would otherwise raise or produce NaNs.
        """
        x = x.float()
        channels = x.shape[1]
        # A depthwise convolution needs one (1, k, k) filter per channel.
        weight = self.weight if channels == 1 else self.weight.repeat(channels, 1, 1, 1)
        for _ in range(self.iterations - 1):
            x = torch.min(x, F.conv2d(x, weight=weight, groups=channels, padding=self.padding))
        x = F.conv2d(x, weight=weight, groups=channels, padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0
        soft = ~mask
        if soft.any():
            peak = x[soft].max()
            if peak != 0:
                x[soft] /= peak
        return x, mask
# Device shared by every model and tensor created in this module: the first
# CUDA device when available, otherwise the CPU. The device type is spelled
# out so the result does not depend on how a bare ordinal is interpreted.
if torch.cuda.is_available():
    global_device = torch.device("cuda:0")
else:
    global_device = torch.device('cpu')

# Per-channel mean / std applied in Trick.get_any_mask before the parsing
# network; shaped (3, 1, 1) so they broadcast over (B, 3, H, W) batches.
vgg_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]],
                        requires_grad=False, device=global_device)
vgg_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]],
                       requires_grad=False, device=global_device)
def load_bisenet():
    """Create the BiSeNet face parser and the soft-erosion mask smoother.

    Returns:
        A ``(bisenet_model, smooth_mask)`` tuple. The parser has 19 classes,
        is in eval mode and carries the weights stored in
        ``../weights/bisenet/79999_iter.pth`` (relative to this file); both
        modules are placed on ``global_device``.
    """
    parser = BiSeNet(n_classes=19)

    # Deserialize on the CPU first, then move the whole model afterwards.
    weight_file = make_abs_path("../weights/bisenet/79999_iter.pth")
    state = torch.load(weight_file, map_location="cpu")
    parser.load_state_dict(state)

    parser.eval()
    parser = parser.to(global_device)

    eroder = SoftErosion(kernel_size=17, threshold=0.9, iterations=7)
    eroder = eroder.to(global_device)

    print('[Global] bisenet loaded.')
    return parser, eroder
# Load the face-parsing network and the mask smoother once, at import time;
# Trick.get_any_mask and Trick.smooth_mask read these module-level globals.
global_bisenet, global_smooth_mask = load_bisenet()
|