# ObjectClear / app.py
# Last commit 1854c04 by jixin0101: "Change pipeline loading method"
import gradio as gr
import os
from PIL import Image
import torch
from diffusers.utils import load_image, check_min_version
from pipeline_objectclear import ObjectClearPipeline
from tools.download_util import load_file_from_url
from tools.painter import mask_painter
import argparse
from safetensors.torch import load_file
from model import CLIPImageEncoder, PostfuseModule
import numpy as np
import torchvision.transforms.functional as TF
from scipy.ndimage import convolve, zoom
import cv2
import time
from huggingface_hub import hf_hub_download
import spaces
from tools.interact_tools import SamControler
from tools.misc import get_device
import json
# Guard against an older diffusers install; check_min_version (diffusers.utils)
# is expected to raise when the installed version is older than 0.30.2.
check_min_version("0.30.2")
def parse_augment():
    """Parse the demo's command-line options.

    Options:
        --device: torch device string; when omitted (or empty) it falls back
            to ``str(get_device())``.
        --sam_model_type: SAM variant; restricted to the variants that have a
            checkpoint URL in ``sam_checkpoint_url_dict`` so a typo is rejected
            at parse time instead of surfacing later as a KeyError.
        --port: only useful when running gradio applications.

    Returns:
        argparse.Namespace with ``device``, ``sam_model_type`` and ``port``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument(
        '--sam_model_type',
        type=str,
        default="vit_h",
        # Keep in sync with the keys of sam_checkpoint_url_dict.
        choices=("vit_h", "vit_l", "vit_b"),
    )
    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    args = parser.parse_args()
    if not args.device:
        args.device = str(get_device())
    return args
def pad_to_multiple(image: np.ndarray, multiple: int = 8):
    """Reflect-pad the bottom and right edges up to a multiple of ``multiple``.

    Only the first two axes (rows, columns) are padded; a third axis, if
    present, is left untouched.

    Returns:
        ``(padded, h, w)`` where ``h`` and ``w`` are the original row and
        column counts, suitable for ``crop_to_original``.
    """
    rows, cols = image.shape[:2]
    # -n % m is the distance from n up to the next multiple of m.
    extra_rows = -rows % multiple
    extra_cols = -cols % multiple
    pad_width = ((0, extra_rows), (0, extra_cols))
    if image.ndim == 3:
        pad_width += ((0, 0),)
    return np.pad(image, pad_width, mode='reflect'), rows, cols
def crop_to_original(image: np.ndarray, h: int, w: int):
    """Undo ``pad_to_multiple``: keep the top-left ``h`` rows and ``w`` columns."""
    row_window = slice(0, h)
    col_window = slice(0, w)
    return image[row_window, col_window]
def wavelet_blur_np(image: np.ndarray, radius: int):
    """Blur every channel of a channel-first array.

    Each plane ``image[c]`` is convolved with a normalised 3x3 binomial kernel
    (edge-replicating borders). For ``radius > 1`` the plane is additionally
    low-passed by bilinearly downsampling it by ``1 / radius`` and upsampling
    it back to its original size.

    Args:
        image: array whose first axis indexes channels (e.g. CHW).
        radius: down/up-sampling factor; values <= 1 skip the resampling step.

    Returns:
        Array with the same shape and dtype as ``image``.
    """
    kernel = np.array([
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625]
    ], dtype=np.float32)
    blurred = np.empty_like(image)
    for c in range(image.shape[0]):
        blurred_c = convolve(image[c], kernel, mode='nearest')
        if radius > 1:
            small = zoom(blurred_c, 1 / radius, order=1)
            # Upsample to the exact original plane size. Zooming back by
            # ``radius`` only round-trips when every side is divisible by
            # ``radius``; otherwise the shapes differ and the assignment below
            # would fail. For divisible sides these factors equal ``radius``,
            # so the result is unchanged.
            factors = tuple(
                full / part for full, part in zip(blurred_c.shape, small.shape)
            )
            blurred_c = zoom(small, factors, order=1)
        blurred[c] = blurred_c
    return blurred
def wavelet_decomposition_np(image: np.ndarray, levels=5):
    """Split ``image`` into high- and low-frequency components.

    ``wavelet_blur_np`` is applied ``levels`` times with radii 1, 2, 4, ...,
    each time to the previous level's blurred result. ``high_freq`` accumulates
    the detail removed at every level, so ``high_freq + low_freq`` equals
    ``image`` up to floating-point rounding.

    Args:
        image: channel-first array, as expected by ``wavelet_blur_np``.
        levels: number of blur levels; ``levels <= 0`` performs no blurring.

    Returns:
        ``(high_freq, low_freq)``. With ``levels <= 0`` this is
        ``(zeros_like(image), image)``.
    """
    high_freq = np.zeros_like(image)
    # Initialised up front so levels <= 0 returns the unblurred input instead
    # of raising UnboundLocalError.
    low_freq = image
    for i in range(levels):
        low_freq = wavelet_blur_np(image, 2 ** i)
        high_freq += image - low_freq
        image = low_freq
    return high_freq, low_freq
def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
    """Combine the fine detail of ``content_feat`` with the coarse part of ``style_feat``.

    Both inputs are decomposed with ``wavelet_decomposition_np`` (default
    levels); the result is content's high-frequency part plus style's
    low-frequency part.
    """
    detail = wavelet_decomposition_np(content_feat)[0]
    base = wavelet_decomposition_np(style_feat)[1]
    return detail + base
def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Wavelet colour transfer between two HxWxC images on a 0-255 scale.

    Despite its name, ``mask`` is an image here: the result keeps the
    high-frequency detail of ``fused`` and takes the low-frequency content
    (overall colour) from ``mask``.

    Returns:
        uint8 HxWxC array, clipped to [0, 255] (fractional parts truncated).
    """
    # Normalise to [0, 1] and move channels first for the per-channel blur.
    content_chw = (fused.astype(np.float32) / 255.0).transpose(2, 0, 1)
    style_chw = (mask.astype(np.float32) / 255.0).transpose(2, 0, 1)
    merged_hwc = wavelet_reconstruction_np(content_chw, style_chw).transpose(1, 2, 0)
    return np.clip(merged_hwc * 255.0, 0, 255).astype(np.uint8)
def fuse_with_wavelet(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
    """Blend the inpainted image into the original using the attention map.

    Inside the (dilated, feathered) binarised attention region the output is
    taken from ``removed``. Everywhere else it comes from a wavelet colour-fixed
    image that keeps the high-frequency detail of ``ori`` and the
    low-frequency content of ``removed``. Both inputs have the un-dilated
    attention region zeroed before the colour fix.

    Args:
        ori: original image; ``shape[:2]`` is read as (H, W) and it is
            multiplied by a 3-channel mask, so presumably HxWx3 on a 0-255
            scale — confirm with callers.
        removed: inpainted image with the same shape as ``ori``.
        attn_map: 2-D attention map; values above 128 are treated as the
            object region (the caller scales it by 255). It is resized to
            (W, H) with nearest-neighbour interpolation.
        multiple: both images are reflect-padded to a multiple of this before
            the wavelet step and cropped back afterwards.
            NOTE(review): wavelet_decomposition_np uses blur radii up to 16;
            confirm wavelet_blur_np handles padded sides that are multiples of
            8 but not 16 (process() sidesteps this by resizing both sides to
            multiples of 16).

    Returns:
        uint8 array with the same shape as ``ori``.
    """
    H, W = ori.shape[:2]
    # Binarise: values > 128 become 255, everything else 0.
    attn_map = attn_map.astype(np.float32)
    _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
    am = attn_map.astype(np.float32)
    am = am/255.0
    # Nearest-neighbour keeps the upsampled mask strictly 0/1.
    am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
    # Grow the region with a 21x21 elliptical dilation, then soften its border
    # with a Gaussian blur so the final blend has no hard seam.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21,21))
    am_d = cv2.dilate(am_up, kernel, iterations=1)
    am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9,9), sigmaX=2)
    # The maximum keeps the original (un-dilated) region at full weight 1.
    am_merged = np.maximum(am_up, am_d)
    am_merged = np.clip(am_merged, 0, 1)
    attn_up_3c = np.stack([am_merged]*3, axis=-1)
    attn_up_ori_3c = np.stack([am_up]*3, axis=-1)
    # Zero the un-dilated attention region in both inputs before the colour fix.
    ori_out = ori * (1 - attn_up_ori_3c)
    rem_out = removed * (1 - attn_up_ori_3c)
    ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
    rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
    # High frequencies from ori, low frequencies (colour) from removed.
    wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
    wave = crop_to_original(wave_rgb, h0, w0)
    # fusion: feathered attention region from `removed`, the rest from `wave`;
    # astype(np.uint8) truncates the fractional part.
    fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
    return fused
def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
    """Resize a PIL image so its shorter side equals ``target_short``.

    The longer side is scaled proportionally (truncated to an int) and then
    rounded up to the next multiple of 16. Square images are treated as
    landscape. The short side itself is not rounded, so it is a multiple of
    16 only if ``target_short`` is.

    Returns:
        The resized image, produced with ``image.resize(..., resample=resample)``.
    """
    width, height = image.size
    portrait = width < height
    short_side, long_side = (width, height) if portrait else (height, width)
    scaled_long = int(long_side * target_short / short_side)
    scaled_long = (scaled_long + 15) // 16 * 16
    new_size = (target_short, scaled_long) if portrait else (scaled_long, target_short)
    return image.resize(new_size, resample=resample)
# Turn new click input into SAM prompt state
def get_prompt(click_state, click_input):
    """Append new clicks to ``click_state`` and build a SAM prompt dict.

    Args:
        click_state: ``[points, labels]``; both lists are extended in place.
        click_input: JSON text of ``[[x, y, label], ...]`` entries.

    Returns:
        dict with ``prompt_type``, ``input_point`` / ``input_label`` (the very
        lists held in ``click_state``) and ``multimask_output``.
    """
    new_clicks = json.loads(click_input)
    point_list = click_state[0]
    label_list = click_state[1]
    for click in new_clicks:
        point_list.append(click[:2])
        label_list.append(click[2])
    click_state[0], click_state[1] = point_list, label_list
    # NOTE(review): multimask_output is the *string* "True", not a bool.
    return {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
# Run SAM for one click on the input image
@spaces.GPU
def sam_refine(image_state, point_prompt, click_state, evt:gr.SelectData):
    """Register a click on the input image and refresh the SAM mask.

    Args:
        image_state: session dict; reads ``origin_image`` and stores the new
            ``mask``, ``logit`` and ``painted_image``.
        point_prompt: ``"Positive"`` for a foreground click; any other value
            is treated as a negative click.
        click_state: ``[points, labels]`` accumulated across clicks; extended
            in place by ``get_prompt``.
        evt: Gradio select event; ``evt.index`` holds the clicked coordinates.

    Returns:
        (painted preview, image_state, click_state).
    """
    template = "[[{},{},1]]" if point_prompt == "Positive" else "[[{},{},0]]"
    coordinate = template.format(evt.index[0], evt.index[1])
    source = image_state["origin_image"]
    predictor = model.samcontroler.sam_controler
    # Re-embed the image on every click, presumably because the single global
    # SAM controller may currently hold another session's image.
    predictor.reset_image()
    predictor.set_image(source)
    prompt = get_prompt(click_state=click_state, click_input=coordinate)
    mask, logit, painted_image = model.first_frame_click(
        image=source,
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    image_state.update(mask=mask, logit=logit, painted_image=painted_image)
    return painted_image, image_state, click_state
def add_multi_mask(image_state, interactive_state, mask_dropdown):
    """Save the current SAM mask under the next ``mask_XXX`` name and select it.

    ``interactive_state`` and ``mask_dropdown`` are extended in place; the
    dropdown list is also sorted in place by ``show_mask``.

    Returns:
        (interactive_state, dropdown update, preview with the selected masks
        painted, empty click state).
    """
    interactive_state["masks"].append(image_state["mask"])
    new_name = "mask_{:03d}".format(len(interactive_state["masks"]))
    interactive_state["mask_names"].append(new_name)
    mask_dropdown.append(new_name)
    preview = show_mask(image_state, interactive_state, mask_dropdown)
    dropdown_update = gr.update(choices=interactive_state["mask_names"], value=mask_dropdown)
    return interactive_state, dropdown_update, preview, [[],[]]
def clear_click(image_state, click_state):
    """Discard pending clicks and show the unannotated image again.

    The incoming ``click_state`` is not modified; a fresh ``[[], []]`` is
    returned in its place.
    """
    return image_state["origin_image"], [[], []]
def remove_multi_mask(interactive_state, click_state, image_state):
    """Delete every saved mask and reset the clicks.

    ``interactive_state`` is emptied in place; the incoming ``click_state``
    is ignored and replaced by ``[[], []]``.

    Returns:
        (interactive_state, dropdown update with no choices or selection,
        the original image, empty click state).
    """
    for key in ("mask_names", "masks"):
        interactive_state[key] = []
    original = image_state["origin_image"]
    return interactive_state, gr.update(choices=[],value=[]), original, [[],[]]
def show_mask(image_state, interactive_state, mask_dropdown):
    """Paint the selected saved masks onto the original image.

    Args:
        image_state: session dict; reads ``origin_image``.
        interactive_state: session dict; ``masks[i]`` is the mask named
            ``mask_{i+1:03d}``.
        mask_dropdown: selected mask names. Sorted in place (callers such as
            ``add_multi_mask`` return this sorted list to the dropdown).

    Returns:
        The painted image, or ``None`` when no image has been uploaded
        (previously this raised UnboundLocalError).
    """
    mask_dropdown.sort()
    select_frame = image_state["origin_image"]
    if select_frame is None:
        return None
    for name in mask_dropdown:
        mask_number = int(name.split("_")[1]) - 1
        mask = interactive_state["masks"][mask_number]
        # Offset the colour index by 2, as in the original painter call.
        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
    return select_frame
@spaces.GPU
def upload_and_reset(image_input, interactive_state):
    """Load a newly uploaded (or example) image and reset all mask state.

    Args:
        image_input: PIL image from the input component.
        interactive_state: session mask store; its ``mask_names`` and
            ``masks`` are replaced with empty lists. ``None`` (the value the
            examples table lists for this input) is tolerated and replaced
            by a fresh dict instead of crashing on item assignment.

    Returns:
        (image_state, image_info, image_input, interactive_state, click_state,
        dropdown update clearing choices and selection), matching the outputs
        wired to the upload and examples events.
    """
    if interactive_state is None:
        interactive_state = {}
    click_state = [[], []]
    interactive_state["mask_names"]= []
    interactive_state["masks"] = []
    image_state, image_info, image_input = update_image_state_on_upload(image_input)
    return (
        image_state,
        image_info,
        image_input,
        interactive_state,
        click_state,
        gr.update(choices=[], value=[]),
    )
def update_image_state_on_upload(image_input):
    """Build fresh per-image state and prime SAM with the new image.

    Args:
        image_input: PIL image (the input component uses ``type='pil'``).

    Returns:
        (image_state, info text, image_input). ``image_state`` holds the image
        as an array, a copy for painting, an all-zero uint8 mask of size
        (height, width) and no logit.
    """
    height, width = image_input.size[1], image_input.size[0]
    image_size = (height, width)
    pixels = np.array(image_input)
    image_state = dict(
        origin_image=pixels,
        painted_image=pixels.copy(),
        mask=np.zeros((height, width), np.uint8),
        logit=None,
    )
    image_info = f"Image Name: uploaded.png,\nImage Size: {image_size}"
    predictor = model.samcontroler.sam_controler
    predictor.reset_image()
    predictor.set_image(pixels)
    return image_state, image_info, image_input
# SAM wrapper used by the Gradio callbacks
class MaskGenerator:
    """Owns the SAM controller shared by the Gradio callbacks.

    Callbacks also reach the underlying predictor directly through
    ``self.samcontroler.sam_controler`` (reset_image / set_image), so the
    ``samcontroler`` attribute name is part of this class's interface.
    """

    def __init__(self, sam_checkpoint, args):
        """Create the controller from a checkpoint path and parsed CLI args.

        Reads ``args.sam_model_type`` and ``args.device``.
        """
        self.args = args
        model_type, device_name = args.sam_model_type, args.device
        self.samcontroler = SamControler(sam_checkpoint, model_type, device_name)

    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
        """Segment ``image`` from click prompts.

        Returns:
            ``(mask, logit, painted_image)`` as a tuple.
        """
        outputs = self.samcontroler.first_frame_click(image, points, labels, multimask)
        mask, logit, painted_image = outputs
        return (mask, logit, painted_image)
# Parse the command-line options (device, SAM variant, port) once at import.
args = parse_augment()
# Official Segment Anything checkpoint URLs, keyed by --sam_model_type.
sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
# NOTE(review): this absolute path appears to assume the Hugging Face Spaces
# layout (/home/user/app); adjust it when running elsewhere.
checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
# Fetch the checkpoint for the selected variant into checkpoint_folder
# (presumably reusing an existing file — TODO confirm load_file_from_url).
# The dict lookup requires --sam_model_type to be one of its keys.
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
# Global SAM wrapper shared by every Gradio callback below.
model = MaskGenerator(sam_checkpoint, args)
# Build the ObjectClear pipeline once at import time (fp16 weights) and move it
# to the first GPU when CUDA is available, otherwise to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): cache_dir is a developer-specific path (/home/jovyan/...);
# on other machines it may not exist or be writable — confirm how
# from_pretrained_with_custom_modules handles it.
# save_cross_attn=True presumably makes the pipeline also return the attention
# maps that process() reads from result[1] — TODO confirm.
pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
    "jixin0101/ObjectClear",
    torch_dtype=torch.float16,
    save_cross_attn=True,
    cache_dir="/home/jovyan/shared/jixinzhao/models",
)
pipe.to(device)
def _merge_selected_masks(interactive_state, mask_dropdown):
    """Merge the saved masks named in ``mask_dropdown`` into one label map.

    Falls back to ``["mask_001"]`` when nothing is selected; a non-empty
    ``mask_dropdown`` is sorted in place. Mask ``mask_XXX`` contributes the
    label value XXX. After each addition the running map is clipped to
    ``[0, XXX]``, so overlaps never exceed the largest label selected so far.
    Background stays 0.
    """
    if len(mask_dropdown) == 0:
        mask_dropdown = ["mask_001"]
    mask_dropdown.sort()
    numbers = [int(name.split("_")[1]) for name in mask_dropdown]
    saved = interactive_state["masks"]
    merged = saved[numbers[0] - 1] * numbers[0]
    for number in numbers[1:]:
        merged = np.clip(merged + saved[number - 1] * number, 0, number)
    return merged


@spaces.GPU
def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed, num_inference_steps, strength):
    """Remove the selected object(s) from the uploaded image with ObjectClear.

    Args:
        image_state: session dict; reads ``origin_image`` and ``mask``. When
            saved masks exist, the merged label map is written back to
            ``mask``.
        interactive_state: session dict whose ``masks`` list holds the masks
            saved with "Add Mask".
        mask_dropdown: selected mask names (``"mask_XXX"``). When it is empty
            but masks exist, the first saved mask is used.
        guidance_scale, num_inference_steps, strength: forwarded to the
            pipeline.
        seed: seeds the torch generator.

    Returns:
        ``(result, (input, result))``. Both images are resized back to the
        uploaded image's size; the pair feeds the comparison slider.

    Raises:
        gr.Error: if no image has been uploaded yet.
    """
    image_np = image_state["origin_image"]
    if image_np is None:
        raise gr.Error("Please upload an image before running ObjectClear.")
    # Create the generator on the device the pipeline was moved to, instead of
    # hard-coding "cuda", which crashes on CPU-only machines. manual_seed needs
    # an int; the seed arrives from a Gradio slider, so cast it defensively.
    generator = torch.Generator(device=device).manual_seed(int(seed))

    if interactive_state["masks"]:
        template_mask = _merge_selected_masks(interactive_state, mask_dropdown)
        image_state["mask"]= template_mask
    else:
        template_mask = image_state["mask"]
    # Any positive label marks a pixel to remove. Binarise before scaling:
    # multiplying label values > 1 by 255 in uint8 wraps around (2 -> 254,
    # 3 -> 253, ...), and a 0/255 mask would even wrap to 1.
    mask = Image.fromarray((np.asarray(template_mask) > 0).astype(np.uint8) * 255)

    image_or = Image.fromarray(image_np)
    image = image_or.convert("RGB")
    mask = mask.convert("RGB")
    # Bicubic for the photo, nearest for the mask so it stays binary.
    image = resize_by_short_side(image, 512, resample=Image.BICUBIC)
    mask = resize_by_short_side(mask, 512, resample=Image.NEAREST)
    w, h = image.size
    result = pipe(
        prompt="remove the instance of object",
        image=image,
        mask_image=mask,
        generator=generator,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=guidance_scale,
        height=h,
        width=w,
    )
    inpainted_img = result[0].images[0]
    attn_map = result[1]
    # Average over dim 1, take the first batch item, and scale to 0-255 as
    # fuse_with_wavelet expects.
    attn_np = attn_map.mean(dim=1)[0].cpu().numpy() * 255.
    fused_img = fuse_with_wavelet(np.array(image), np.array(inpainted_img), attn_np)
    fused_img_pil = Image.fromarray(fused_img.astype(np.uint8))
    original_size = image_or.size
    output = fused_img_pil.resize(original_size)
    return output, (image.resize(original_size), output)
import base64
# Inline the logo as a base64 data URI so the header HTML is self-contained.
# NOTE(review): "./Logo.png" is resolved against the current working
# directory, not this file's directory.
with open("./Logo.png", "rb") as f:
    img_bytes = f.read()
img_b64 = base64.b64encode(img_bytes).decode()
# Centered logo banner, rendered with gr.HTML(html_img).
html_img = f'''
<div style="display:flex; justify-content:center; align-items:center; width:100%;">
<img src="data:image/png;base64,{img_b64}" style="border:none; width:200px; height:auto;"/>
</div>
'''
# Download the tutorial video shown in the "Video Tutorial" accordion.
# NOTE(review): absolute path that appears to assume the Spaces layout; the
# gr.Video component reads the file from this same directory.
tutorial_url = "https://github.com/zjx0101/ObjectClear/releases/download/media/tutorial.mp4"
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
load_file_from_url(tutorial_url, assets_path)
# Intro text rendered under the logo via gr.Markdown(description).
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/zjx0101/ObjectClear' target='_blank'><b>ObjectClear: Complete Object Removal via Object-Effect Attention</b></a>.<br>
🔥 ObjectClear is an object removal model that can jointly eliminate the target object and its associated effects leveraging Object-Effect Attention, while preserving background consistency.<br>
🖼️ Try to drop your image, assign the target masks with a few clicks, and get the object removal results!<br>
*Note: Due to online GPU memory constraints, all input images will be resized during inference so that the shortest side is 512 pixels.<br>*
"""
article = r"""<h3>
<b>If ObjectClear is helpful, please help to star the <a href='https://github.com/zjx0101/ObjectClear' target='_blank'>Github Repo</a>. Thanks!</b></h3>
<hr>
📑 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@InProceedings{zhao2025ObjectClear,
title = {{ObjectClear}: Complete Object Removal via Object-Effect Attention},
author = {Zhao, Jixin and Zhou, Shangchen and Wang, Zhouxia and Yang, Peiqing and Loy, Chen Change},
booktitle = {arXiv preprint arXiv:2505.22636},
year = {2025}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>jixinzhao0101@gmail.com</b>.
<br>
👏 **Acknowledgement**
<br>
This demo is adapted from [MatAnyone](https://github.com/pq-yang/MatAnyone) and leverages segmentation capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome work!
"""
# Raw CSS passed to gr.Blocks(css=...). It must contain only CSS: HTML tags
# such as <style>/</style> are invalid here, and a leading "<style>" made the
# browser discard the .button-wrapper rule that followed it.
custom_css = """
#input-image {
aspect-ratio: 1 / 1;
width: 100%;
max-width: 100%;
height: auto;
display: flex;
align-items: center;
justify-content: center;
}
#input-image img {
max-width: 100%;
max-height: 100%;
object-fit: contain;
display: block;
}
#main-columns {
gap: 60px;
}
#main-columns > .gr-column {
flex: 1;
}
#compare-image {
width: 100%;
aspect-ratio: 1 / 1;
display: flex;
align-items: center;
justify-content: center;
margin: 0;
padding: 0;
max-width: 100%;
box-sizing: border-box;
}
#compare-image svg.svelte-zyxd38 {
position: absolute !important;
top: 50% !important;
left: 50% !important;
transform: translate(-50%, -50%) !important;
}
#compare-image .icon.svelte-1oiin9d {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
#compare-image {
position: relative;
overflow: hidden;
}
.new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
.new_button:hover {background-color: #4b4b4b !important;}
#start-button {
background: linear-gradient(135deg, #2575fc 0%, #6a11cb 100%);
color: white;
border: none;
padding: 12px 24px;
font-size: 16px;
font-weight: bold;
border-radius: 12px;
cursor: pointer;
box-shadow: 0 0 12px rgba(100, 100, 255, 0.7);
transition: all 0.3s ease;
}
#start-button:hover {
transform: scale(1.05);
box-shadow: 0 0 20px rgba(100, 100, 255, 1);
}
.button-wrapper {
width: 30%;
text-align: center;
}
.wide-button {
width: 83% !important;
background-color: black !important;
color: white !important;
border: none !important;
padding: 8px 0 !important;
font-size: 16px !important;
display: inline-block;
margin: 30px 0px 0px 50px ;
}
.wide-button:hover {
background-color: #656262 !important;
}
"""
# Gradio UI: header, SAM settings, input/output columns, event wiring,
# tutorial video, examples and footer.
with gr.Blocks(css=custom_css) as demo:
    gr.HTML(html_img)
    gr.Markdown(description)
    # Collapsible SAM options: click polarity and which saved masks to use.
    with gr.Group(elem_classes="gr-monochrome-group", visible=True):
        with gr.Row():
            with gr.Accordion('SAM Settings (click to expand)', open=False):
                with gr.Row():
                    point_prompt = gr.Radio(
                        choices=["Positive", "Negative"],
                        value="Positive",
                        label="Point Prompt",
                        info="Click to add positive or negative point for target mask",
                        interactive=True,
                        min_width=100,
                        scale=1)
                    # Choices are the "mask_XXX" names created by add_multi_mask.
                    mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2")
    with gr.Row(elem_id="main-columns"):
        # Left column: input image, mask buttons, run button and settings.
        with gr.Column():
            # Per-session state shared by the callbacks:
            #   click_state       -> [points, labels] of the current clicks
            #   interactive_state -> saved masks and their dropdown names
            #   image_state       -> uploaded image, current SAM mask/logit
            #                        and the painted preview
            click_state = gr.State([[],[]])
            interactive_state = gr.State(
                {
                    "mask_names": [],
                    "masks": []
                }
            )
            image_state = gr.State(
                {
                    "origin_image": None,
                    "painted_image": None,
                    "mask": None,
                    "logit": None
                }
            )
            # Hidden; written by upload_and_reset but never displayed.
            image_info = gr.Textbox(label="Image Info", visible=False)
            input_image = gr.Image(
                label='Input',
                type='pil',
                sources=["upload"],
                image_mode='RGB',
                interactive=True,
                elem_id="input-image"
            )
            with gr.Row(equal_height=True, elem_classes="mask_button_group"):
                clear_button_click = gr.Button(value="Clear Clicks",elem_classes="new_button", min_width=100)
                add_mask_button = gr.Button(value="Add Mask", elem_classes="new_button", min_width=100)
                remove_mask_button = gr.Button(value="Delete Mask", elem_classes="new_button", min_width=100)
            submit_button_component = gr.Button(
                value='Start ObjectClear', elem_id="start-button"
            )
            # These values are passed to process() and on to the pipeline.
            with gr.Accordion('ObjectClear Settings', open=True):
                strength = gr.Radio(
                    choices=[0.99, 1.0],
                    value=0.99,
                    label="Strength",
                    info="0.99 better preserves the background and color; use 1.0 if object/shadow is not fully removed (default: 0.99)"
                )
                guidance_scale = gr.Slider(
                    minimum=1, maximum=10, step=0.5, value=2.5,
                    label="Guidance Scale",
                    info="Higher = stronger removal; lower = better background preservation (default: 2.5)"
                )
                seed = gr.Slider(
                    minimum=0, maximum=1000000, step=1, value=300000,
                    label="Seed Value",
                    info="Different seeds can lead to noticeably different object removal results (default: 300000)"
                )
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=40, step=1, value=20,
                    label="Num Inference Steps",
                    info="Higher values may improve quality but take longer (default: 20)"
                )
        # Right column: fused result and a before/after comparison slider.
        with gr.Column():
            # NOTE(review): elem_id "input-image" duplicates the input image's
            # id (presumably to reuse the #input-image CSS); duplicate DOM ids
            # are invalid HTML — consider elem_classes instead.
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Output', format="png", elem_id="input-image")
            output_compare_image_component = gr.ImageSlider(
                label="Comparison",
                type="pil",
                format='png',
                elem_id="compare-image"
            )
    # New upload: store the image, prime SAM, and clear saved masks and clicks.
    input_image.upload(
        fn=upload_and_reset,
        inputs=[input_image, interactive_state],
        outputs=[
            image_state,
            image_info,
            input_image,
            interactive_state,
            click_state,
            mask_dropdown,
        ]
    )
    # Clicking the input image adds a SAM point prompt and refreshes the mask.
    input_image.select(
        fn=sam_refine,
        inputs=[image_state, point_prompt, click_state],
        outputs=[input_image, image_state, click_state]
    )
    # Save the current SAM mask as a new selectable "mask_XXX".
    add_mask_button.click(
        fn=add_multi_mask,
        inputs=[image_state, interactive_state, mask_dropdown],
        outputs=[interactive_state, mask_dropdown, input_image, click_state]
    )
    # Delete all saved masks and restore the original image.
    remove_mask_button.click(
        fn=remove_multi_mask,
        inputs=[interactive_state, click_state, image_state],
        outputs=[interactive_state, mask_dropdown, input_image, click_state]
    )
    # Drop pending clicks and restore the original image.
    clear_button_click.click(
        fn = clear_click,
        inputs = [image_state, click_state,],
        outputs = [input_image, click_state],
    )
    # Run object removal with the selected masks and settings.
    submit_button_component.click(
        fn=process,
        inputs=[
            image_state,
            interactive_state,
            mask_dropdown,
            guidance_scale,
            seed,
            num_inference_steps,
            strength
        ],
        outputs=[
            output_image_component, output_compare_image_component
        ]
    )
    with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
        with gr.Row():
            # Same directory that the tutorial video is downloaded into at startup.
            gr.Video(value="/home/user/app/hugging_face/assets/tutorial.mp4", elem_classes="video")
    gr.Markdown("---")
    gr.Markdown("## Examples")
    # Expects examples/test0.png ... examples/test9.png next to this file.
    example_images = [
        os.path.join(os.path.dirname(__file__), "examples", f"test{i}.png")
        for i in range(10)
    ]
    # NOTE(review): each example supplies None for interactive_state; confirm
    # which value upload_and_reset actually receives for that input when an
    # example is clicked.
    examples_data = [
        [example_images[i], None] for i in range(len(example_images))
    ]
    # Clicking an example runs upload_and_reset immediately (run_on_click).
    examples = gr.Examples(
        examples=examples_data,
        inputs=[input_image, interactive_state],
        outputs=[image_state, image_info, input_image,
                 interactive_state, click_state, mask_dropdown],
        fn=upload_and_reset,
        run_on_click=True,
        cache_examples=False,
        label="Click below to load example images"
    )
    gr.Markdown(article)
    # Clear the input image component whenever the page loads.
    def pre_update_input_image():
        return gr.update(value=None)
    demo.load(
        fn=pre_update_input_image,
        inputs=[],
        outputs=[input_image]
    )
# NOTE(review): --port from parse_augment is not passed here, so launch()
# uses Gradio's default port.
demo.launch(debug=True, show_error=True)