File size: 3,980 Bytes
e6a22e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import torch
import torch.nn.functional as F

from typing import Tuple


class SoftErosion(torch.nn.Module):
    """Blur a mask with a cone-shaped kernel and split it at a threshold.

    The kernel weight falls off linearly with the distance from the kernel
    centre and is normalised to sum to one. After blurring, values at or above
    ``threshold`` are set to 1 and the remaining values are rescaled by their
    own maximum.

    Args:
        kernel_size: Side length of the square kernel. ``kernel_size // 2`` is
            used as the convolution padding.
        threshold: Blurred values greater than or equal to this become 1.
        iterations: Number of blur passes. Every pass except the last takes
            the element-wise minimum of the input and its blurred version.
    """

    def __init__(
        self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1
    ):
        super().__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Distance of every kernel cell from the centre. Broadcasting a row of
        # offsets against a column of offsets gives the same grid as
        # torch.meshgrid without depending on its ``indexing`` argument.
        offsets = torch.arange(0.0, kernel_size) - r
        dist = torch.sqrt(offsets.unsqueeze(0) ** 2 + offsets.unsqueeze(1) ** 2)
        kernel = dist.max() - dist
        total = kernel.sum()
        if total > 0:
            kernel = kernel / total
        else:
            # kernel_size == 1: the only cell has weight 0, so normalising
            # would yield NaN. Fall back to an identity kernel.
            kernel = torch.ones_like(kernel)
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer("weight", kernel)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Blur ``x`` and return ``(soft_mask, hard_mask)``.

        Args:
            x: Input passed to ``F.conv2d``; ``x.shape[1]`` is used as the
                channel count. The input tensor itself is not modified.

        Returns:
            A tuple ``(x, mask)`` where ``mask`` is the boolean tensor
            ``blurred >= threshold`` and ``x`` is the blurred tensor with
            masked entries set to 1 and the others divided by their maximum.
        """
        channels = x.shape[1]
        # One copy of the kernel per channel so that the grouped (depthwise)
        # convolution also accepts inputs with more than one channel.
        weight = self.weight.expand(channels, -1, -1, -1).contiguous()

        for _ in range(self.iterations - 1):
            x = torch.min(
                x,
                F.conv2d(x, weight=weight, groups=channels, padding=self.padding),
            )
        x = F.conv2d(x, weight=weight, groups=channels, padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0

        # Rescale the below-threshold values. Skip this when there are none:
        # max() of an empty tensor raises.
        soft = ~mask
        if soft.any():
            # add small epsilon to avoid NaNs
            x[soft] /= x[soft].max() + 1e-7

        return x, mask


def encode_segmentation_rgb(
    segmentation: np.ndarray, no_neck: bool = True
) -> np.ndarray:
    """Turn a label map into a two-channel face / mouth mask.

    Label ids follow
    https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py

    Args:
        segmentation: Label map; its first two dimensions give the size of the
            output channels.
        no_neck: When ``False``, label ids 7, 8 and 14 are also counted as
            part of the face channel.

    Returns:
        Array of shape ``(H, W, 2)``: channel 0 is 255 where the label is one
        of the face ids, channel 1 is 255 where the label is the mouth id (11).
        All other entries are 0.
    """
    if no_neck:
        face_labels = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        face_labels = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_label = 11

    height, width = segmentation.shape[0], segmentation.shape[1]

    face_channel = np.zeros((height, width))
    face_channel[np.isin(segmentation, face_labels)] = 255

    mouth_channel = np.zeros((height, width))
    mouth_channel[segmentation == mouth_label] = 255

    return np.stack((face_channel, mouth_channel), axis=2)


def encode_segmentation_rgb_batch(
    segmentation: torch.Tensor, no_neck: bool = True
) -> torch.Tensor:
    """Tensor version of the face / mouth mask encoding.

    Label ids follow
    https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py

    Args:
        segmentation: Label tensor; it is converted with ``.int()`` before the
            comparison. The two maps are concatenated along dimension 1.
        no_neck: When ``False``, label ids 7, 8 and 14 are also counted as
            part of the face map.

    Returns:
        Integer tensor holding the face map followed by the mouth map along
        dimension 1. Entries are 255 where the label matches and 0 elsewhere.
    """
    if no_neck:
        face_labels = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        face_labels = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_label = 11

    labels = segmentation.int()
    black = torch.zeros_like(labels)
    white = black + 255

    # Collect every face label into a single boolean mask, then paint once.
    is_face = torch.zeros_like(labels, dtype=torch.bool)
    for label in face_labels:
        is_face |= labels == label

    face_map = torch.where(is_face, white, black)
    mouth_map = torch.where(labels == mouth_label, white, black)

    return torch.cat((face_map, mouth_map), dim=1)


def _module_device(module: torch.nn.Module) -> torch.device:
    """Return the device of ``module``'s first parameter or buffer.

    Falls back to CUDA when available, otherwise CPU, for modules that own no
    tensors.
    """
    for tensor in (*module.parameters(), *module.buffers()):
        return tensor.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def postprocess(
    swapped_face: np.ndarray,
    target: np.ndarray,
    target_mask: np.ndarray,
    smooth_mask: torch.nn.Module,
) -> np.ndarray:
    """Blend ``swapped_face`` into ``target`` using a softened face mask.

    Args:
        swapped_face: Image to blend in where the mask is set.
        target: Image kept where the mask is not set.
        target_mask: Array whose last axis holds at least two channels with
            values in the 0..255 range (presumably the output of
            ``encode_segmentation_rgb`` -- confirm against the caller). The
            first two channels are summed into one mask.
        smooth_mask: Module called on a ``(1, 1, H, W)`` float tensor; the
            first element of its return value is used as the blend weight.

    Returns:
        ``swapped_face * m + target * (1 - m)`` with the last axis reversed,
        where ``m`` is the softened mask broadcast over the last axis.
    """
    # Run on whatever device the smoothing module lives on instead of
    # assuming CUDA, so that CPU-only machines work too.
    device = _module_device(smooth_mask)

    mask_tensor = (
        torch.from_numpy(target_mask.copy().transpose((2, 0, 1)))
        .float()
        .mul_(1 / 255.0)
        .to(device)
    )
    face_mask_tensor = mask_tensor[0] + mask_tensor[1]

    soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor[None, None])
    # Drop only the batch and channel axes; a bare squeeze() would also
    # collapse a height or width of 1.
    soft_face_mask = soft_face_mask_tensor[0, 0].detach().cpu().numpy()
    soft_face_mask = soft_face_mask[:, :, np.newaxis]

    result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
    # Reverse the channel order of the last axis.
    result = result[:, :, ::-1]
    return result