ViTMatte / app.py
hysts's picture
hysts HF staff
Update
9aeb2df
#!/usr/bin/env python
import os
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor
# Markdown text rendered at the top of the demo page (passed to gr.Markdown
# when the UI is built).
DESCRIPTION = """\
# [ViTMatte](https://github.com/hustvl/ViTMatte)
This is the demo for [ViTMatte](https://github.com/hustvl/ViTMatte), an image matting application.
You can matte any subject in a given image.
If you wish to replace background of the image, simply select the checkbox and drag and drop your background image.
You can draw your own foreground mask and unknown (border) mask using the canvas.
"""
# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Both settings can be overridden through environment variables of the same
# name; the literals below are the fallbacks.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.environ.get("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

# Load the image processor and the matting model once, at import time, and
# move the model to the selected device.
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID)
model = model.to(device)
def check_image_size(image: PIL.Image.Image | None) -> None:
    """Reject input images whose longer side exceeds ``MAX_IMAGE_SIZE``.

    This is registered as the ``change`` handler of the input image
    component. An empty component is expected to arrive here as ``None``
    (e.g. after the user clears the image); there is nothing to validate
    in that case, so the function simply returns.

    Args:
        image: The current input image, or ``None`` when no image is set.

    Raises:
        gr.Error: If the larger of the image's width and height is greater
            than ``MAX_IMAGE_SIZE``.
    """
    if image is None:
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")
def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Threshold a mask at 128 and return the result as a 0/1 array.

    Values below 128 map to 0 and every other value maps to 1. The input
    array is left untouched: a new array with the same shape and dtype is
    returned, so callers that pass a view of a larger buffer do not get
    that buffer overwritten.

    Args:
        mask: Array of mask intensities.

    Returns:
        A new array of the same shape and dtype as ``mask`` holding only
        0 and 1.
    """
    return (mask >= 128).astype(mask.dtype)
def update_trimap(foreground_mask: dict[str, np.ndarray], unknown_mask: dict[str, np.ndarray]) -> np.ndarray:
foreground = foreground_mask["mask"][:, :, 0]
foreground = binarize_mask(foreground)
unknown = unknown_mask["mask"][:, :, 0]
unknown = binarize_mask(unknown)
trimap = np.zeros_like(foreground)
trimap[unknown > 0] = 128
trimap[foreground > 0] = 255
return trimap
def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    """Scale and center-crop a background so it exactly covers ``target_size``.

    The background is resized uniformly by the larger of the two
    width/height ratios, so that it is at least as large as the target in
    both dimensions, and then cropped around its center.

    Args:
        background_image: Image to fit.
        target_size: Desired ``(width, height)`` of the result.

    Returns:
        A ``target_size`` crop taken from the center of the resized
        background.
    """
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size
    scale = max(target_w / bg_w, target_h / bg_h)

    # Floating-point error can make ``bg * scale`` land just below the target
    # (e.g. 1 / 49 * 49 == 0.9999999999999999), and int() would then truncate
    # to one pixel short, pushing the crop box outside the resized image.
    # Clamp so the resized image always covers the target.
    new_bg_w = max(target_w, int(bg_w * scale))
    new_bg_h = max(target_h, int(bg_h * scale))
    resized = background_image.resize((new_bg_w, new_bg_h))

    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    return resized.crop((left, top, left + target_w, top + target_h))
def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> np.ndarray | None:
    """Composite ``image`` over a new background using ``alpha`` as the matte.

    The background is converted to RGB and fitted to the size of ``image``
    with ``adjust_background_image`` before blending.

    Args:
        image: Foreground image; must be in ``"RGB"`` mode.
        alpha: 2-D matte that is broadcast over the colour channels. It is
            used directly as the blend weight, so it is presumably in
            [0, 1] and aligned with the image's pixel grid (confirm
            against the caller).
        background_image: Replacement background, or ``None`` when the
            user supplied none.

    Returns:
        The blended picture as a ``uint8`` array, or ``None`` when
        ``background_image`` is ``None``.

    Raises:
        gr.Error: If ``image`` is not in RGB mode.
    """
    if background_image is None:
        return None
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")

    fitted_background = adjust_background_image(background_image.convert("RGB"), image.size)

    # Work in floating point on a 0..1 scale, then go back to 8-bit.
    foreground_rgb = np.array(image).astype(float) / 255
    background_rgb = np.array(fitted_background).astype(float) / 255
    weight = alpha[:, :, None]
    blended = foreground_rgb * weight + background_rgb * (1 - weight)
    return (blended * 255).astype(np.uint8)
@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image | None,
    trimap: PIL.Image.Image | None,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, np.ndarray | None]:
    """Predict an alpha matte for ``image`` and build the derived outputs.

    Args:
        image: Input picture; must be RGB and no larger than
            ``MAX_IMAGE_SIZE`` on its longer side.
        trimap: Trimap with the same size as ``image``, in ``"L"`` mode.
        apply_background_replacement: Whether to also composite the subject
            over ``background_image``.
        background_image: Replacement background; only used when
            ``apply_background_replacement`` is true.

    Returns:
        A tuple ``(alpha, foreground, replaced)`` where ``alpha`` is the
        predicted matte cropped to the input size, ``foreground`` is the
        input composited over white, and ``replaced`` is the result of
        ``replace_background`` (``None`` when replacement is disabled or no
        background was given).

    Raises:
        gr.Error: If an input is missing or fails any of the validations
            described above.
    """
    # Empty image components are expected to arrive as None; report that as a
    # readable error instead of failing on ``.size`` below.
    if image is None or trimap is None:
        raise gr.Error("Please provide both an input image and a trimap.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()

    # Crop the prediction back to the input size (the model output may be
    # larger than the input, presumably because the processor pads it;
    # TODO confirm).
    w, h = image.size
    alpha = alpha[:h, :w]

    # Composite the subject over a white background: on a 0..1 scale this is
    # image * alpha + 1 * (1 - alpha).
    weight = alpha[:, :, None]
    foreground = np.array(image).astype(float) / 255 * weight + (1 - weight)
    foreground = PIL.Image.fromarray((foreground * 255).astype(np.uint8))

    if apply_background_replacement:
        res_bg_replacement = replace_background(image, alpha, background_image)
    else:
        res_bg_replacement = None
    return alpha, foreground, res_bg_replacement
# Build the demo UI and wire its events.
# NOTE(review): the source this was recovered from had lost its indentation,
# so the exact nesting of the layout containers below is a best-effort
# reconstruction; confirm it against the intended page layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    # Either upload a ready-made trimap ...
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    # ... or paint the foreground / unknown regions by hand.
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
            # The initial state of a checkbox is set through ``value``;
            # ``checked`` is not a Checkbox parameter.
            apply_background_replacement = gr.Checkbox(label="Apply background replacement", value=False)
            background_image = gr.Image(label="Background image", type="pil", height=500, visible=False)
            run_button = gr.Button("Run")
        # Right column: outputs.
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)
                out_background_replacement = gr.Image(label="Background replacement", height=500, visible=False)

    # Component lists shared by the examples and the Run button; their order
    # matches the parameters and return values of ``run``.
    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]
    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    # Validate the size of the input image whenever it changes.
    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch canvases.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    # Turn the two hand-drawn masks into a trimap.
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Show the background input and output only while replacement is enabled.
    apply_background_replacement.change(
        fn=lambda checked: (gr.Image(visible=checked), gr.Image(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        queue=False,
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
        api_name="run",
    )
if __name__ == "__main__":
    # Enable request queuing (max_size=20) and then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()