#!/usr/bin/env python

import os

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor

DESCRIPTION = """\
# [ViTMatte](https://github.com/hustvl/ViTMatte)

This is the demo for [ViTMatte](https://github.com/hustvl/ViTMatte), an image matting application.
You can matte any subject in a given image.

If you wish to replace background of the image, simply select the checkbox and drag and drop your background image.

You can draw your own foreground mask and unknown (border) mask using the canvas.
"""

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Longest allowed side (in pixels) of an input image; configurable through the environment.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1500"))
# Hub checkpoint used for both the image processor and the model.
MODEL_ID = os.getenv("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID).to(device)


def check_image_size(image: PIL.Image.Image | None) -> None:
    """Reject images whose longest side exceeds ``MAX_IMAGE_SIZE``.

    Args:
        image: The uploaded image, or ``None`` (the input component's ``change`` event can
            deliver ``None`` when the image is cleared; there is nothing to check then).

    Raises:
        gr.Error: If ``max(image.size) > MAX_IMAGE_SIZE``.
    """
    if image is None:
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")


def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Return a 0/1 copy of ``mask``: values >= 128 become 1, everything else 0.

    The input array is not modified and the result keeps the input's dtype.
    """
    return (mask >= 128).astype(mask.dtype)


def update_trimap(
    foreground_mask: dict[str, np.ndarray] | None, unknown_mask: dict[str, np.ndarray] | None
) -> np.ndarray:
    """Build a trimap from the two sketch canvases.

    Only channel 0 of each ``"mask"`` array is read. Pixels painted on the foreground canvas
    become 255, pixels painted on the unknown canvas become 128 (foreground wins where both
    are painted), and all remaining pixels are 0 (background).

    Raises:
        gr.Error: If either canvas has no value yet (nothing was loaded into it).
    """
    if foreground_mask is None or unknown_mask is None:
        raise gr.Error("Load an image into both canvases before setting the trimap.")
    foreground = binarize_mask(foreground_mask["mask"][:, :, 0])
    unknown = binarize_mask(unknown_mask["mask"][:, :, 0])

    trimap = np.zeros_like(foreground)
    trimap[unknown > 0] = 128
    trimap[foreground > 0] = 255
    return trimap


def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    """Scale ``background_image`` to cover ``target_size``, then center-crop to exactly that size.

    The image is resized by one uniform factor (aspect ratio preserved) chosen so both sides
    reach at least the target, and the overflow is trimmed equally from both ends.

    Args:
        background_image: Image to fit.
        target_size: ``(width, height)`` of the result, in ``PIL.Image.size`` order.

    Returns:
        A new image whose size is ``target_size``.
    """
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size
    scale = max(target_w / bg_w, target_h / bg_h)
    # Round rather than truncate, and never go below the target: ``int(bg_w * scale)`` can come
    # out one pixel short through float error, and ``crop`` would then zero-fill that row/column.
    new_bg_w = max(target_w, round(bg_w * scale))
    new_bg_h = max(target_h, round(bg_h * scale))
    background_image = background_image.resize((new_bg_w, new_bg_h))

    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    return background_image.crop((left, top, left + target_w, top + target_h))


def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> PIL.Image.Image | None:
    """Composite ``image`` over ``background_image`` using the alpha matte.

    Args:
        image: RGB foreground image.
        alpha: Matte indexed as ``alpha[row, col]`` with the same height/width as ``image``;
            values are expected in [0, 1].
        background_image: New background; it is converted to RGB and cover-fitted to ``image``.

    Returns:
        The composited RGB image, or ``None`` when no background image was given.

    Raises:
        gr.Error: If ``image`` is not in RGB mode.
    """
    if background_image is None:
        return None
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")

    background_image = adjust_background_image(background_image.convert("RGB"), image.size)

    fg = np.asarray(image, dtype=float) / 255
    bg = np.asarray(background_image, dtype=float) / 255
    a = alpha[:, :, None]  # broadcast the single-channel matte over the RGB channels
    result = fg * a + bg * (1 - a)
    return PIL.Image.fromarray((result * 255).astype(np.uint8))


@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image | None,
    trimap: PIL.Image.Image | None,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, PIL.Image.Image | None]:
    """Predict an alpha matte for ``image`` guided by ``trimap``.

    Args:
        image: RGB input image.
        trimap: Grayscale ("L") trimap with the same size as ``image``.
        apply_background_replacement: Whether to also composite the subject onto ``background_image``.
        background_image: Replacement background; only used when ``apply_background_replacement``.

    Returns:
        ``(alpha, foreground, replaced)``: the alpha matte cropped to the image's height/width,
        the subject composited over white, and the subject composited over ``background_image``
        (``None`` when replacement is off or no background was given).

    Raises:
        gr.Error: If an input is missing, the sizes differ, the image is too large,
            or an image has the wrong mode.
    """
    if image is None or trimap is None:
        raise gr.Error("Please provide both an input image and a trimap.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    check_image_size(image)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()
    # The prediction can be larger than the input (presumably padding added by the processor),
    # so crop it back to the original image size.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Composite the subject over a white background: rgb * alpha + 1.0 * (1 - alpha).
    foreground = np.asarray(image, dtype=float) / 255 * alpha[:, :, None] + (1 - alpha[:, :, None])
    foreground = PIL.Image.fromarray((foreground * 255).astype(np.uint8))

    if apply_background_replacement:
        res_bg_replacement = replace_background(image, alpha, background_image)
    else:
        res_bg_replacement = None
    return alpha, foreground, res_bg_replacement


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
                # ``value`` is Checkbox's initial-state parameter (``checked`` is not an accepted kwarg).
                apply_background_replacement = gr.Checkbox(label="Apply background replacement", value=False)
                background_image = gr.Image(label="Background image", type="pil", height=500, visible=False)
            run_button = gr.Button("Run")
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)
                out_background_replacement = gr.Image(label="Background replacement", height=500, visible=False)

    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]

    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    # Early feedback on oversized uploads, before the user presses "Run".
    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch canvases so the user can paint on it.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Show/hide the background input and its result together with the checkbox.
    apply_background_replacement.change(
        fn=lambda checked: (gr.Image(visible=checked), gr.Image(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        queue=False,
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()