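"""Anime background-removal demo.

Segments the subject with the SkyTNT anime-segmentation ISNet model (ONNX),
refines the matte via a CarveKit trimap and PyMatting closed-form matting,
and serves the pipeline through a Gradio interface.
"""
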
import PIL.Image
import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
from PIL import ImageOps
from carvekit.trimap.generator import TrimapGenerator
from pymatting import estimate_alpha_cf, estimate_foreground_ml, stack_images

providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
rmbg_model = rt.InferenceSession(model_path, providers=providers)

trimapGenerator = TrimapGenerator()


# def custom_background(background, foreground):
#     foreground = ImageOps.contain(foreground, background.size)
#     x = (background.size[0] - foreground.size[0]) // 2
#     y = (background.size[1] - foreground.size[1]) // 2
#     background.paste(foreground, (x, y), foreground)
#     return background

def custom_background(background: PIL.Image.Image, foreground: np.ndarray):
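    """Composite an RGBA cutout (ndarray) onto the centre of a PIL background image."""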
    final_foreground = PIL.Image.fromarray(foreground)
    # Centre the foreground over the background and crop a matching region.
    x = (background.size[0] - final_foreground.size[0]) // 2
    y = (background.size[1] - final_foreground.size[1]) // 2
    box = (x, y, final_foreground.size[0] + x, final_foreground.size[1] + y)
    crop = background.crop(box)
    final_image = crop.copy()
    # Paste the foreground onto the crop, using its own alpha channel as the mask.
    paste_box = (0, final_image.size[1] - final_foreground.size[1], final_image.size[0], final_image.size[1])
    final_image.paste(final_foreground, paste_box, mask=final_foreground)
    return np.array(final_image)


def get_mask(img, s=1024):
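    """Predict a soft foreground mask with the ISNet ONNX model.

    The image is resized (preserving aspect ratio) and zero-padded into an
    s x s square for inference; the predicted mask is then cropped and resized
    back to the original resolution. Returns an HxWx1 float array in [0, 1].
    """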
    img = (img / 255).astype(np.float32)
    h, w = h0, w0 = img.shape[:-1]
    h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
    ph, pw = s - h, s - w
    img_input = np.zeros([s, s, 3], dtype=np.float32)
    img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
    img_input = np.transpose(img_input, (2, 0, 1))
    img_input = img_input[np.newaxis, :]
    mask = rmbg_model.run(None, {'img': img_input})[0][0]
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
    return mask


def change_background_color(image, color="blue"):
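    """Segment the subject and flatten it onto a solid-colour background."""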
    mask = get_mask(image)
    image = (mask * image + 255 * (1 - mask)).astype(np.uint8)
    mask = (mask * 255).astype(np.uint8)
    image = np.concatenate([image, mask], axis=2, dtype=np.uint8)
    image = PIL.Image.fromarray(image)
    background = PIL.Image.new('RGB', image.size, color)
    background.paste(image, (0, 0), image)
    return background


def generate_trimap(probs, size=7, conf_threshold=0.95):
    """
    This function creates a trimap based on simple dilation algorithm
    Inputs [3]: an image with probabilities of each pixel being the foreground, size of dilation kernel,
    foreground confidence threshold
    Output    : a trimap
    """
    mask = (probs > 0.05).astype(np.uint8) * 255
    pixels = 2 * size + 1
    kernel = np.ones((pixels, pixels), np.uint8)
    dilation = cv2.dilate(mask, kernel, iterations=1)
    remake = np.zeros_like(mask)
    remake[dilation == 255] = 127  # Dilated band around the mask: unknown region.
    remake[probs > conf_threshold] = 255  # High-confidence pixels: definite foreground.
    return remake


def image2gray(image):
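    """Convert an image array to a greyscale float array in [0, 1]."""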
    image = PIL.Image.fromarray(image).convert("L")
    return np.array(image) / 255.0


def paste(img_orig, alpha):
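    """Attach `alpha` to `img_orig` as an alpha channel and save it to new_back.png."""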
    img_ = img_orig.astype(np.float32) / 255
    alpha_ = cv2.resize(alpha, (img_.shape[1], img_.shape[0]), interpolation=cv2.INTER_LANCZOS4)
    fg_alpha = np.concatenate([img_, alpha_[:, :, np.newaxis]], axis=2)
    cv2.imwrite("new_back.png", (fg_alpha * 255).astype(np.uint8))


def predict(image, new_background):
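    """Remove the background from `image`, optionally compositing onto `new_background`.

    Pipeline: ISNet predicts a coarse mask, CarveKit's TrimapGenerator expands it
    into a trimap, and PyMatting's closed-form matting refines the alpha before the
    foreground is re-estimated and stacked into an RGBA cutout.
    Returns (mask, trimap, result) for the three Gradio outputs.
    """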
    mask = get_mask(image)
    mask = (mask * 255).astype(np.uint8)
    mask = mask.repeat(3, axis=2)

    original = PIL.Image.fromarray(image)
    mask = PIL.Image.fromarray(mask).convert("L")
    # Expand the binary mask into a trimap (foreground / unknown / background).
    trimap = trimapGenerator(original_image=original, mask=mask)
    trimap = np.array(trimap) / 255.0

    foreground = image / 255
    alpha = estimate_alpha_cf(foreground, trimap)
    foreground = estimate_foreground_ml(foreground, alpha)
    cutout = stack_images(foreground, alpha)
    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)

    if new_background is not None:
        return mask, trimap, custom_background(new_background, cutout)
    return alpha, trimap, cutout


# Experimental effects below; they are not wired into the Gradio UI.
def serendipity(image, new_background):
    mask = get_mask(image)
    mask = 255 - mask
    image = (mask * image + 255 * (1 - mask)).astype(np.uint8)
    mask = (mask * 255).astype(np.uint8)
    image = np.concatenate([image, mask], axis=2, dtype=np.uint8)
    return mask, image


def negative(image, new_background):
    mask = get_mask(image)
    image = (mask * image + 255 * (1 - mask)).astype(np.uint8)
    image = 255 - image
    mask = (mask * 255).astype(np.uint8)
    image = np.concatenate([image, mask], axis=2, dtype=np.uint8)
    return mask, image


def checkit(image, new_background):
    mask = get_mask(image)
    mask = 255 - mask
    image = (mask / image - 255 / (1 + mask)).astype(np.uint8)
    mask = (mask * 255).astype(np.uint8)
    mask = 255 - mask
    image = np.concatenate([image, mask], axis=2, dtype=np.uint8)
    mask = mask.repeat(3, axis=2)
    # if new_background is not None:
    #     foreground = PIL.Image.fromarray(image)
    #     return mask, custom_background(new_background, foreground)
    return mask, image


footer = r"""
<center>
<b>
Demo based on <a href='https://github.com/SkyTNT/anime-segmentation'>SkyTNT Anime Segmentation</a>
</b>
</center>
"""

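# Gradio UI: input image plus an optional custom background; outputs the mask,
# the trimap, and the final cutout / composited result.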
with gr.Blocks(title="Face Shine") as app:
    gr.HTML("<center><h1>Anime Remove Background</h1></center>")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="numpy", image_mode="RGB", label="Input image")
            new_img = gr.Image(type="pil", image_mode="RGBA", label="Custom background")
            run_btn = gr.Button(variant="primary")
        with gr.Column():
            with gr.Accordion(label="Image mask", open=False):
                output_mask = gr.Image(type="numpy", label="mask")
                output_trimap = gr.Image(type="numpy", label="trimap")
            output_img = gr.Image(type="numpy", label="result")

    run_btn.click(predict, [input_img, new_img], [output_mask, output_trimap, output_img])

    with gr.Row():
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Dataset(components=[input_img], samples=examples_data)
        examples.click(lambda x: x[0], [examples], [input_img])

    with gr.Row():
        gr.HTML(footer)

app.launch(share=False, debug=True, enable_queue=True, show_error=True)