File size: 7,793 Bytes
d578b5a
 
 
 
 
 
 
 
 
 
 
9aeb2df
 
 
 
 
 
 
 
 
 
d578b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11433a5
d578b5a
 
11433a5
d578b5a
 
 
 
 
 
 
 
85f6a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d578b5a
 
85f6a01
 
 
 
 
 
d578b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f6a01
 
 
 
 
 
d578b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d5bddb
d578b5a
 
 
 
 
 
 
85f6a01
 
d578b5a
 
 
 
 
85f6a01
 
 
 
 
 
 
 
 
 
 
 
 
d578b5a
 
85f6a01
 
d578b5a
85f6a01
 
d578b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f6a01
9e392e3
85f6a01
 
 
 
 
 
d578b5a
 
85f6a01
 
d578b5a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
#!/usr/bin/env python

import os

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor

# Markdown text rendered at the top of the demo page (passed to gr.Markdown below).
DESCRIPTION = """\
# [ViTMatte](https://github.com/hustvl/ViTMatte)

This is the demo for [ViTMatte](https://github.com/hustvl/ViTMatte), an image matting application.
You can matte any subject in a given image.

If you wish to replace background of the image, simply select the checkbox and drag and drop your background image.

You can draw your own foreground mask and unknown (border) mask using the canvas.
"""

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Both settings can be overridden through environment variables of the same name.
# MAX_IMAGE_SIZE is compared against the larger of an image's width and height.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.getenv("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

# Loaded once at import time; `run` uses these module-level objects for every call.
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID).to(device)


def check_image_size(image: PIL.Image.Image | None) -> None:
    """Reject images whose larger side exceeds ``MAX_IMAGE_SIZE``.

    Args:
        image: The image to validate, or ``None`` when the input component
            currently holds no image.

    Raises:
        gr.Error: If the larger of the image's width and height is greater
            than ``MAX_IMAGE_SIZE``.
    """
    # This function is wired to the input's change event, which can deliver
    # None (e.g. when the image is cleared); there is nothing to validate then.
    if image is None:
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")


def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Threshold a mask into a 0/1 array.

    Values of 128 or more become 1 and every smaller value becomes 0.

    The input array is left untouched. Callers in this file pass slices that
    are views into larger arrays, so writing in place would silently alter
    the data those views belong to.

    Args:
        mask: Array of mask intensities.

    Returns:
        A new array with the same shape and dtype as ``mask`` containing only
        0 and 1.
    """
    return (mask >= 128).astype(mask.dtype)


def update_trimap(
    foreground_mask: dict[str, np.ndarray] | None, unknown_mask: dict[str, np.ndarray] | None
) -> np.ndarray:
    """Combine the sketched foreground and unknown masks into one trimap.

    Only the first channel of each ``"mask"`` array is used. Pixels marked in
    the unknown mask get the value 128, pixels marked in the foreground mask
    get 255 (foreground wins where both are marked), and all remaining pixels
    stay 0.

    Args:
        foreground_mask: Sketch payload whose ``"mask"`` entry marks the
            foreground region, or ``None`` when nothing has been loaded.
        unknown_mask: Sketch payload whose ``"mask"`` entry marks the unknown
            (border) region, or ``None`` when nothing has been loaded.

    Returns:
        A 2-D array with the same shape and dtype as the first mask channel,
        holding the values 0, 128 and 255.

    Raises:
        gr.Error: If either payload is missing or the two masks differ in
            shape.
    """
    if foreground_mask is None or unknown_mask is None:
        raise gr.Error("Load an image and draw both masks before setting the trimap.")

    # The channel slices are views into the payload arrays and binarize_mask
    # may modify its argument, so hand it copies to keep the inputs intact.
    foreground = binarize_mask(foreground_mask["mask"][:, :, 0].copy())
    unknown = binarize_mask(unknown_mask["mask"][:, :, 0].copy())

    # Boolean indexing below requires matching shapes; fail with a clear message.
    if foreground.shape != unknown.shape:
        raise gr.Error("Foreground and unknown masks must have the same size.")

    trimap = np.zeros_like(foreground)
    trimap[unknown > 0] = 128
    # Written last so that foreground overrides unknown where they overlap.
    trimap[foreground > 0] = 255
    return trimap


def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    """Scale and center-crop a background image so it covers ``target_size``.

    The image is resized uniformly by the smallest factor that makes it at
    least as large as the target in both dimensions, then cropped around its
    center to exactly ``target_size``.

    Args:
        background_image: The image to fit.
        target_size: Desired ``(width, height)`` in pixels.

    Returns:
        A new image of size ``target_size``.
    """
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size

    scale = max(target_w / bg_w, target_h / bg_h)
    # Floating-point rounding can leave ``bg * scale`` a hair below the target
    # (e.g. 49 * (1 / 49) evaluates to 0.9999999999999999), and int() would
    # then truncate to one pixel short, making the crop box reach outside the
    # resized image. Clamp so the resized image always covers the target.
    new_bg_w = max(target_w, int(bg_w * scale))
    new_bg_h = max(target_h, int(bg_h * scale))
    resized = background_image.resize((new_bg_w, new_bg_h))

    # Center the crop window inside the resized image.
    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    return resized.crop((left, top, left + target_w, top + target_h))


def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> np.ndarray | None:
    """Composite ``image`` over a new background using ``alpha`` as the matte.

    The background is converted to RGB and fitted to the image size with
    ``adjust_background_image`` before blending.

    Args:
        image: The foreground image; must be in ``"RGB"`` mode.
        alpha: 2-D matte indexed as ``[row, column]``. It is used directly as
            the blend weight, so values are presumably in ``[0, 1]`` -- confirm
            against the producer of ``alpha``.
        background_image: The replacement background, or ``None``.

    Returns:
        The blended picture as a ``uint8`` array of shape
        ``(height, width, 3)``, or ``None`` when no background was given.

    Raises:
        gr.Error: If ``image`` is not in ``"RGB"`` mode.
    """
    if background_image is None:
        return None

    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")

    fitted_background = adjust_background_image(background_image.convert("RGB"), image.size)

    # Blend in normalized float space, then scale back to 8-bit.
    foreground_rgb = np.array(image).astype(float) / 255
    background_rgb = np.array(fitted_background).astype(float) / 255
    weight = alpha[:, :, None]  # add a channel axis so the matte broadcasts over RGB
    blended = foreground_rgb * weight + background_rgb * (1 - weight)
    return (blended * 255).astype(np.uint8)


@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image | None,
    trimap: PIL.Image.Image | None,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, np.ndarray | None]:
    """Predict an alpha matte and build the derived output images.

    Args:
        image: Input picture; must be in ``"RGB"`` mode.
        trimap: Trimap for ``image``; must be in ``"L"`` mode and have the
            same size as ``image``.
        apply_background_replacement: Whether to also composite the subject
            over ``background_image``.
        background_image: Replacement background, or ``None``.

    Returns:
        A tuple ``(alpha, foreground, replaced)`` where ``alpha`` is the
        predicted matte cropped to the image size, ``foreground`` is the
        subject composited over white, and ``replaced`` is the result of
        ``replace_background`` (``None`` when replacement is disabled or no
        background was supplied).

    Raises:
        gr.Error: If an input is missing, the sizes differ, the image is too
            large, or either input has the wrong mode.
    """
    # Fail with a readable message instead of an AttributeError on None.
    if image is None or trimap is None:
        raise gr.Error("Please provide both an input image and a trimap.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    # Same limit (and message) that is applied when the image is uploaded.
    check_image_size(image)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()

    # Keep only the region matching the input; the prediction can be larger,
    # presumably because the processor pads its input -- TODO confirm.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Composite over white: the background term is 1.0 in normalized space.
    weight = alpha[:, :, None]
    foreground = np.array(image).astype(float) / 255 * weight + (1 - weight)
    foreground = (foreground * 255).astype(np.uint8)
    foreground = PIL.Image.fromarray(foreground)

    if apply_background_replacement:
        res_bg_replacement = replace_background(image, alpha, background_image)
    else:
        res_bg_replacement = None

    return alpha, foreground, res_bg_replacement

# UI definition: inputs in the left column, results in the right column.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    # Either upload a ready-made trimap ...
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    # ... or sketch the foreground and unknown regions by hand.
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
                # The checkbox's initial state is set through `value`.
                apply_background_replacement = gr.Checkbox(label="Apply background replacement", value=False)
                # Hidden until background replacement is switched on.
                background_image = gr.Image(label="Background image", type="pil", height=500, visible=False)
                run_button = gr.Button("Run")
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)
                out_background_replacement = gr.Image(label="Background replacement", height=500, visible=False)

    # Order matches the parameters and return values of `run`.
    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]
    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    # Validate the size as soon as the input image changes.
    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch canvases.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    # Turn the two sketches into a trimap and place it in the trimap input.
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Show or hide the background input and its result together with the checkbox.
    apply_background_replacement.change(
        fn=lambda checked: (gr.Image(visible=checked), gr.Image(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        queue=False,
        api_name=False,
    )

    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
        api_name="run",
    )

if __name__ == "__main__":
    # Enable request queuing (max_size=20) first, then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()