File size: 3,969 Bytes
6d314be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import subprocess
import tempfile
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import transforms
from torchvision.utils import save_image

from models.Net import get_segmentation


def equal_replacer(images: list[torch.Tensor]) -> list[torch.Tensor]:
    """Rescale uint8 images to floats and alias numerically equal entries.

    The list is modified in place and also returned.

    1. Every ``torch.uint8`` tensor is replaced by ``tensor / 255`` (a float
       tensor); tensors of any other dtype are left untouched.
    2. For every pair ``i < j`` whose tensors are ``torch.allclose``,
       ``images[j]`` is rebound to the very same object as ``images[i]``,
       so later code can detect duplicates with an identity check.

    Tensors that differ in shape, dtype or device are never considered
    equal. ``torch.allclose`` would otherwise raise ``RuntimeError`` (or
    silently broadcast) for such pairs.

    Args:
        images: Image tensors to normalise and de-duplicate.

    Returns:
        The same list object that was passed in.
    """
    for idx, image in enumerate(images):
        if image.dtype == torch.uint8:
            images[idx] = image / 255

    for i, reference in enumerate(images):
        for j in range(i + 1, len(images)):
            candidate = images[j]
            if candidate is reference:
                # Already aliased by an earlier pass; nothing to compare.
                continue
            if (
                candidate.shape != reference.shape
                or candidate.dtype != reference.dtype
                or candidate.device != reference.device
            ):
                # Not comparable element-wise, hence not duplicates.
                continue
            if torch.allclose(reference, candidate):
                images[j] = reference
    return images


class DilateErosion:
    """Iterated binary dilation and erosion with a 3x3 cross-shaped kernel.

    Both operations are implemented with one ``F.conv2d`` call per pass:
    a pixel is set in the dilated mask when the convolution response is
    positive, and in the eroded mask when the response equals the kernel
    sum (every pixel under the cross is set). The convolution zero-pads
    (``padding='same'``), so pixels on the image border are eroded.

    Args:
        dilate_erosion: Number of convolution passes to apply.
        device: Device on which the kernel is stored. Masks passed to
            the methods must live on the same device.
    """

    def __init__(self, dilate_erosion=5, device='cuda'):
        self.dilate_erosion = dilate_erosion
        cross = torch.tensor(
            [
                [0.0, 1.0, 0.0],
                [1.0, 1.0, 1.0],
                [0.0, 1.0, 0.0],
            ],
            dtype=torch.float32,
        )
        # conv2d expects an (out_channels, in_channels, kH, kW) kernel.
        self.weight = cross.unsqueeze(0).unsqueeze(0).to(device)

    def hair_from_mask(self, mask):
        """Return dilated and eroded 256x256 masks of the label-13 region.

        Args:
            mask: Label map; entries equal to 13 are selected.
                Presumably shaped (N, 1, H, W) with a dtype that
                ``F.interpolate`` accepts -- TODO confirm against callers.

        Returns:
            Tuple ``(dilated, eroded)`` of float masks.
        """
        is_selected = mask == 13
        binary = torch.where(is_selected, torch.ones_like(mask), torch.zeros_like(mask))
        resized = F.interpolate(binary, size=(256, 256), mode='nearest')
        return self.mask(resized)

    def mask(self, mask):
        """Dilate and erode ``mask`` for ``self.dilate_erosion`` passes.

        Args:
            mask: Batch of single-channel masks (the kernel has one input
                channel). The input tensor itself is not modified.

        Returns:
            Tuple ``(dilated, eroded)`` of float tensors, each with the
            same batch size as ``mask``.
        """
        batch = len(mask)
        kernel_total = self.weight.sum().item()

        # First half of the stack is dilated, second half is eroded, so a
        # single convolution per pass serves both operations.
        stacked = torch.cat((mask, mask), dim=0).float()

        for _ in range(self.dilate_erosion):
            response = F.conv2d(stacked, self.weight, padding='same')
            grown = (response[:batch] > 0).float()
            shrunk = (response[batch:] == kernel_total).float()
            stacked = torch.cat((grown, shrunk), dim=0)

        return stacked[:batch], stacked[batch:]


def poisson_image_blending(final_image, face_image, dilate_erosion=30, maxn=115):
    """Blend ``face_image`` into ``final_image`` with the external ``fpie`` tool.

    Both images are segmented with ``get_segmentation``. The blending mask
    covers the pixels that carry label 13 (handled as hair in this module)
    in neither image; that mask is resized to 1024x1024 and shrunk by
    dilating its complement ``dilate_erosion`` times. Images and mask are
    written to a temporary directory, and ``fpie`` is invoked with
    ``face_image`` as source (``-s``), ``final_image`` as target (``-t``)
    and the mask (``-m``).

    Requires a CUDA device (inputs are moved with ``.cuda()``, and the
    ``taichi-gpu`` backend is requested) and the ``fpie`` executable on
    ``PATH``.

    Args:
        final_image: Target image as a tensor. Presumably (C, H, W) at
            1024x1024, matching the hard-coded mask size -- TODO confirm.
        face_image: Source image: a file path, a tensor, or any object
            accepted by ``torchvision.transforms.ToTensor``.
        dilate_erosion: Number of dilation passes used to shrink the mask.
        maxn: Value forwarded to ``fpie -n``.

    Returns:
        Tuple ``(blended, mask)`` of PIL images, fully loaded into memory.

    Raises:
        subprocess.CalledProcessError: If ``fpie`` exits with a non-zero
            status.
        FileNotFoundError: If the ``fpie`` executable cannot be found.
    """
    morphology = DilateErosion(dilate_erosion=dilate_erosion)
    to_tensor = transforms.ToTensor()

    if isinstance(face_image, str):
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(face_image) as face_file:
            face_image = to_tensor(face_file)
    elif not isinstance(face_image, torch.Tensor):
        face_image = to_tensor(face_image)

    final_mask = get_segmentation(final_image.cuda().unsqueeze(0), resize=False)
    face_mask = get_segmentation(face_image.cuda().unsqueeze(0), resize=False)

    hair_target = torch.where(final_mask == 13, torch.ones_like(final_mask),
                              torch.zeros_like(final_mask))
    hair_face = torch.where(face_mask == 13, torch.ones_like(face_mask),
                            torch.zeros_like(face_mask))

    # 1 where neither image has hair, 0 elsewhere.
    final_mask = F.interpolate(((1 - hair_target) * (1 - hair_face)).float(), size=(1024, 1024), mode='bicubic')
    # Growing the complement shrinks the blend region away from the hair.
    dilation, _ = morphology.mask(1 - final_mask)
    mask_save = 1 - dilation[0]

    with tempfile.TemporaryDirectory() as temp_dir:
        final_image_path = os.path.join(temp_dir, 'final_image.png')
        face_image_path = os.path.join(temp_dir, 'face_image.png')
        mask_path = os.path.join(temp_dir, 'mask_save.png')
        save_image(final_image, final_image_path)
        save_image(face_image, face_image_path)
        save_image(mask_save, mask_path)

        out_image_path = os.path.join(temp_dir, 'out_image_path.png')
        subprocess.run(
            ["fpie", "-s", face_image_path, "-m", mask_path, "-t", final_image_path, "-o", out_image_path, "-n",
             str(maxn), "-b", "taichi-gpu", "-g", "max"],
            check=True
        )

        # Image.open is lazy: read the pixel data now, while the files
        # still exist, so the returned images do not depend on (or keep
        # open) files inside the directory that is about to be removed.
        blended = Image.open(out_image_path)
        blended.load()
        mask_image = Image.open(mask_path)
        mask_image.load()

    return blended, mask_image


def list_image_files(directory):
    """Return the sorted names of image files directly inside ``directory``.

    An entry qualifies when it is a regular file and its suffix, compared
    case-insensitively, is ``.jpg``, ``.jpeg`` or ``.png``. Subdirectories
    are not descended into.

    Args:
        directory: Path of the directory to scan.

    Returns:
        List of entry names (not full paths) in sorted order.
    """
    accepted_suffixes = ('.jpg', '.jpeg', '.png')

    def is_image(name):
        full_path = os.path.join(directory, name)
        if not os.path.isfile(full_path):
            return False
        return Path(full_path).suffix.lower() in accepted_suffixes

    return [name for name in sorted(os.listdir(directory)) if is_image(name)]