LB5's picture
Upload 45 files
e6a22e6
raw
history blame contribute delete
No virus
12.4 kB
import numpy as np
import torch
import torch.nn.functional as F
from typing import Iterable, Tuple, Union
from pathlib import Path
from torchvision import transforms
import kornia
from omegaconf import DictConfig
from src.FaceDetector.face_detector import Detection
from src.FaceAlign.face_align import align_face, inverse_transform_batch
from src.PostProcess.utils import SoftErosion
from src.model_loader import get_model
from src.Misc.types import CheckpointType, FaceAlignmentType
from src.Misc.utils import tensor2img
class SimSwap:
    """Face-swap inference pipeline.

    A call on a target ("attribute") image runs, in order: face detection and
    alignment, identity embedding of the ID image, the SimSwap generator on
    every aligned target crop, optional GFPGAN enhancement, face parsing to
    build a soft mask, and finally blending of the swapped crops back into the
    full frame.

    Identity latents are computed lazily on the first call and cached in
    ``id_latent`` / ``specific_latent``; reset those attributes to ``None``
    after replacing ``id_image`` / ``specific_id_image`` to force a recompute.
    """

    def __init__(
        self,
        config: DictConfig,
        id_image: Union[np.ndarray, None] = None,
        specific_image: Union[np.ndarray, None] = None,
    ):
        """Load all sub-models and apply the tunable parameters from ``config``.

        Args:
            config: Configuration object. The fields read here are ``device``,
                ``face_detector_weights``, ``face_id_weights``,
                ``parsing_model_weights``, ``simswap_weights``,
                ``blend_module_weights``, ``enhance_output`` and (only when
                ``enhance_output`` is truthy) ``gfpgan_weights``, plus every
                field consumed by :meth:`set_parameters`.
            id_image: Image providing the identity to swap in. May be ``None``
                if ``id_latent`` is assigned before the first call.
            specific_image: Optional image of the one person in the target
                that should be swapped. When ``None`` every detected face is
                swapped.

        Raises:
            ValueError: If any tunable parameter in ``config`` is out of range.
        """
        self.id_image: Union[np.ndarray, None] = id_image
        self.id_latent: Union[torch.Tensor, None] = None
        self.specific_id_image: Union[np.ndarray, None] = specific_image
        self.specific_latent: Union[torch.Tensor, None] = None

        self.use_mask: Union[bool, None] = True
        self.crop_size: Union[int, None] = None
        self.checkpoint_type: Union[CheckpointType, None] = None
        self.face_alignment_type: Union[FaceAlignmentType, None] = None

        self.smooth_mask_iter: Union[int, None] = None
        self.smooth_mask_kernel_size: Union[int, None] = None
        self.smooth_mask_threshold: Union[float, None] = None

        self.face_detector_threshold: Union[float, None] = None
        self.specific_latent_match_threshold: Union[float, None] = None

        # The device must be known before set_parameters(): the smooth-mask
        # setters rebuild the SoftErosion module and move it to self.device.
        self.device = torch.device(config.device)

        self.set_parameters(config)

        # For BiSeNet and for official_224 SimSwap
        self.to_tensor_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # For SimSwap models trained with the updated code
        self.to_tensor = transforms.ToTensor()

        # NOTE: "load_state_dice" is the keyword name used by every get_model
        # call in this file; it is kept verbatim.
        self.face_detector = get_model(
            "face_detector",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_detector_weights),
            det_thresh=self.face_detector_threshold,
            det_size=(640, 640),
            mode="ffhq",
        )

        self.face_id_net = get_model(
            "arcface",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.face_id_weights),
        )

        self.bise_net = get_model(
            "parsing_model",
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.parsing_model_weights),
            n_classes=19,
        )

        # The generator variant is chosen by crop size: 512 selects the
        # "generator_512" model with deep=True, anything else "generator_224".
        is_512 = self.crop_size == 512
        gen_model = "generator_512" if is_512 else "generator_224"
        self.simswap_net = get_model(
            gen_model,
            device=self.device,
            load_state_dice=True,
            model_path=Path(config.simswap_weights),
            input_nc=3,
            output_nc=3,
            latent_size=512,
            n_blocks=9,
            deep=is_512,
            use_last_act=self.checkpoint_type == CheckpointType.OFFICIAL_224,
        )

        self.blend = get_model(
            "blend_module",
            device=self.device,
            load_state_dice=False,
            model_path=Path(config.blend_module_weights),
        )

        # The GFPGAN model is only loaded when enhancement is requested.
        self.enhance_output = config.enhance_output
        if config.enhance_output:
            self.gfpgan_net = get_model(
                "gfpgan",
                device=self.device,
                load_state_dice=True,
                model_path=Path(config.gfpgan_weights),
            )

    def set_parameters(self, config) -> None:
        """Apply every tunable parameter found in ``config`` via its setter.

        Reads ``crop_size``, ``checkpoint_type``, ``face_alignment_type``,
        ``face_detector_threshold``, ``specific_latent_match_threshold``,
        ``smooth_mask_kernel_size``, ``smooth_mask_threshold`` and
        ``smooth_mask_iter``.

        Raises:
            ValueError: If any value is rejected by its setter.
        """
        self.set_crop_size(config.crop_size)
        self.set_checkpoint_type(config.checkpoint_type)
        self.set_face_alignment_type(config.face_alignment_type)
        self.set_face_detector_threshold(config.face_detector_threshold)
        self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
        self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size)
        self.set_smooth_mask_threshold(config.smooth_mask_threshold)
        self.set_smooth_mask_iter(config.smooth_mask_iter)

    def set_crop_size(self, crop_size: int) -> None:
        """Set the side length of the aligned face crops.

        Raises:
            ValueError: If ``crop_size`` is negative.
        """
        if crop_size < 0:
            raise ValueError("Invalid crop_size! Must be a positive value.")

        self.crop_size = crop_size

    def set_checkpoint_type(self, checkpoint_type: str) -> None:
        """Set which kind of SimSwap checkpoint is being used.

        Args:
            checkpoint_type: Value convertible to ``CheckpointType``.

        Raises:
            ValueError: If the converted value is neither ``OFFICIAL_224`` nor
                ``UNOFFICIAL``.
        """
        parsed = CheckpointType(checkpoint_type)
        if parsed not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
            raise ValueError(
                "Invalid checkpoint_type! Must be one of the predefined values."
            )

        self.checkpoint_type = parsed

    def set_face_alignment_type(self, face_alignment_type: str) -> None:
        """Set the face alignment mode used when cropping faces.

        Args:
            face_alignment_type: Value convertible to ``FaceAlignmentType``.

        Raises:
            ValueError: If the converted value is neither ``FFHQ`` nor
                ``DEFAULT``.
        """
        parsed = FaceAlignmentType(face_alignment_type)
        if parsed not in (
            FaceAlignmentType.FFHQ,
            FaceAlignmentType.DEFAULT,
        ):
            raise ValueError(
                "Invalid face_alignment_type! Must be one of the predefined values."
            )

        self.face_alignment_type = parsed

    def set_face_detector_threshold(self, face_detector_threshold: float) -> None:
        """Set the detection score threshold stored on this object.

        The value is passed to the face detector as ``det_thresh`` when the
        detector is constructed in ``__init__``; this setter itself only
        updates the stored attribute.

        Raises:
            ValueError: If the value is outside ``[0.0, 1.0]``.
        """
        if face_detector_threshold < 0.0 or face_detector_threshold > 1.0:
            raise ValueError(
                "Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
            )

        self.face_detector_threshold = face_detector_threshold

    def set_specific_latent_match_threshold(
        self, specific_latent_match_threshold: float
    ) -> None:
        """Set the maximum latent distance for a face to count as a match.

        Only used when a specific-person image was supplied.

        Raises:
            ValueError: If the value is negative.
        """
        if specific_latent_match_threshold < 0.0:
            raise ValueError(
                "Invalid specific_latent_match_th! Must be a positive value."
            )

        self.specific_latent_match_threshold = specific_latent_match_threshold

    def re_initialize_soft_mask(self):
        """Rebuild ``self.smooth_mask`` from the current smooth-mask settings.

        Called by each smooth-mask setter, so during ``set_parameters`` the
        module is rebuilt several times and only the last build has all three
        settings populated.
        """
        self.smooth_mask = SoftErosion(
            kernel_size=self.smooth_mask_kernel_size,
            threshold=self.smooth_mask_threshold,
            iterations=self.smooth_mask_iter,
        ).to(self.device)

    def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None:
        """Set the soft-erosion kernel size and rebuild the smoothing module.

        Even values are bumped up by one so the stored kernel size is odd.

        Raises:
            ValueError: If the value is negative.
        """
        if smooth_mask_kernel_size < 0:
            raise ValueError(
                "Invalid smooth_mask_kernel_size! Must be a positive value."
            )

        if smooth_mask_kernel_size % 2 == 0:
            smooth_mask_kernel_size += 1

        self.smooth_mask_kernel_size = smooth_mask_kernel_size
        self.re_initialize_soft_mask()

    def set_smooth_mask_threshold(self, smooth_mask_threshold: float) -> None:
        """Set the soft-erosion threshold and rebuild the smoothing module.

        Raises:
            ValueError: If the value is outside ``[0, 1]``.
        """
        if smooth_mask_threshold < 0 or smooth_mask_threshold > 1.0:
            raise ValueError(
                "Invalid smooth_mask_threshold! Must be within 0...1 range."
            )

        self.smooth_mask_threshold = smooth_mask_threshold
        self.re_initialize_soft_mask()

    def set_smooth_mask_iter(self, smooth_mask_iter: int) -> None:
        """Set the soft-erosion iteration count and rebuild the smoothing module.

        Raises:
            ValueError: If the value is negative.
        """
        if smooth_mask_iter < 0:
            raise ValueError("Invalid smooth_mask_iter! Must be a positive value..")

        self.smooth_mask_iter = smooth_mask_iter
        self.re_initialize_soft_mask()

    def run_detect_align(
        self, image: np.ndarray, for_id: bool = False
    ) -> Tuple[
        Union[Iterable[np.ndarray], None],
        Union[Iterable[np.ndarray], None],
        np.ndarray,
    ]:
        """Detect faces in ``image`` and produce aligned crops.

        Args:
            image: Image to search for faces.
            for_id: When ``True`` the image is an identity source: only the
                highest-scoring face is kept, and a missing face is an error.

        Returns:
            ``(aligned_crops, transforms, scores)``. The first two are ``None``
            when no face was found and ``for_id`` is ``False``. ``scores`` is
            always the detector's score output for the whole image.

        Raises:
            RuntimeError: If ``for_id`` is ``True`` and no face was detected.
        """
        detection: Detection = self.face_detector(image)

        if detection.bbox is None:
            if for_id:
                raise RuntimeError("Can't detect a face! Please change the ID image!")
            return None, None, detection.score

        kps = detection.key_points

        if for_id:
            # Keep only the best-scoring face, re-adding a leading axis so the
            # key points keep the same rank as in the multi-face case.
            max_score_ind = np.argmax(detection.score, axis=0)
            kps = detection.key_points[max_score_ind]
            kps = kps[None, ...]

        align_mode = (
            "ffhq" if self.face_alignment_type == FaceAlignmentType.FFHQ else "none"
        )
        # Named "align_transforms" so the torchvision "transforms" module
        # imported at file level is not shadowed.
        align_imgs, align_transforms = align_face(
            image,
            kps,
            crop_size=self.crop_size,
            mode=align_mode,
        )

        return align_imgs, align_transforms, detection.score

    def __call__(self, att_image: np.ndarray) -> np.ndarray:
        """Swap the configured identity onto the faces in ``att_image``.

        Args:
            att_image: Target ("attribute") frame.

        Returns:
            The blended frame converted by ``tensor2img``; or ``att_image``
            itself, unchanged, when no face is detected or when a specific
            person was requested and no detected face matches closely enough.

        Raises:
            ValueError: If no ID latent is cached and no ID image was given.
            RuntimeError: If no face can be detected in the ID image or in the
                specific-person image.
        """
        if self.id_latent is None:
            if self.id_image is None:
                raise ValueError(
                    "No ID image provided! Set id_image (or id_latent) before "
                    "calling SimSwap."
                )
            align_id_imgs, _, _ = self.run_detect_align(self.id_image, for_id=True)
            # normalize=True, because official SimSwap model trained with normalized id_lattent
            self.id_latent: torch.Tensor = self.face_id_net(
                align_id_imgs, normalize=True
            )

        if self.specific_id_image is not None and self.specific_latent is None:
            align_specific_imgs, _, _ = self.run_detect_align(
                self.specific_id_image, for_id=True
            )
            self.specific_latent: torch.Tensor = self.face_id_net(
                align_specific_imgs, normalize=False
            )

        # for_id=False, because we want to get all faces
        align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
            att_image, for_id=False
        )

        if align_att_imgs is None and att_transforms is None:
            return att_image

        # Select specific crop from the target image
        if self.specific_latent is not None:
            att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
            # Per-face mean squared distance to the specific-person latent.
            latent_dist = torch.mean(
                F.mse_loss(
                    att_latent,
                    self.specific_latent.repeat(att_latent.shape[0], 1),
                    reduction="none",
                ),
                dim=-1,
            )

            att_detection_score = torch.tensor(
                att_detection_score, device=latent_dist.device
            )

            # The candidate is picked on distance weighted by detection score,
            # but the threshold is tested against the raw distance.
            min_index = int(torch.argmin(latent_dist * att_detection_score))
            min_value = latent_dist[min_index]

            if min_value < self.specific_latent_match_threshold:
                align_att_imgs = [align_att_imgs[min_index]]
                att_transforms = [att_transforms[min_index]]
            else:
                return att_image

        swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)

        if self.enhance_output:
            swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)

        # Put all crops/transformations into a batch
        align_att_img_batch_for_parsing_model: torch.Tensor = torch.stack(
            [self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch_for_parsing_model = (
            align_att_img_batch_for_parsing_model.to(self.device)
        )

        att_transform_batch: torch.Tensor = torch.stack(
            [torch.tensor(x).float() for x in att_transforms], dim=0
        )
        att_transform_batch = att_transform_batch.to(self.device, non_blocking=True)

        align_att_img_batch: torch.Tensor = torch.stack(
            [self.to_tensor(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)

        # Get face masks for the attribute image
        face_mask, ignore_mask_ids = self.bise_net.get_mask(
            align_att_img_batch_for_parsing_model, self.crop_size
        )

        inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transform_batch)

        soft_face_mask, _ = self.smooth_mask(face_mask)

        # Crops listed in ignore_mask_ids keep their original pixels instead
        # of the generator output.
        swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]

        frame_size = (att_image.shape[0], att_image.shape[1])

        att_frame = (
            self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)
        )

        # Warp the swapped crops and their masks back into full-frame
        # coordinates using the inverted alignment transforms.
        target_image = kornia.geometry.transform.warp_affine(
            swapped_img,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        soft_face_mask = kornia.geometry.transform.warp_affine(
            soft_face_mask,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        result = self.blend(target_image, soft_face_mask, att_frame)

        return tensor2img(result)