referdino / app.py
import gradio as gr
import os
import warnings
# Build the GroundingDINO multi-scale deformable attention op if the prebuilt extension is missing
so_path = "models/GroundingDINO/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so"
if not os.path.exists(so_path):
    os.system("python models/GroundingDINO/ops/setup.py build_ext develop --user")
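# The model-related imports below come after the build step, since they may depend on the compiled op.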
import torchvision.transforms as T
from models import build_model
import torch
import misc as utils
import numpy as np
import torch.nn.functional as F
from torchvision.io import read_video
import torchvision.transforms.functional as Func
from ruamel.yaml import YAML
from easydict import EasyDict
from misc import nested_tensor_from_videos_list
from torch.cuda.amp import autocast
from PIL import Image, ImageDraw
import imageio.v3 as iio
import cv2
import tempfile
import argparse
import time
from huggingface_hub import hf_hub_download
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DURATION = 6
CHECKPOINT = "ryt_mevis_swinb.pth"
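# DURATION caps how many seconds of the uploaded video are decoded (used as end_pts in read_video);
# CHECKPOINT is the pretrained ReferDINO weight file downloaded from the Hugging Face Hub at startup.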
# Transform for video frames
transform = T.Compose([
    T.Resize(360),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Colormap
color_list = utils.colormap()
color_list = color_list.astype('uint8').tolist()
# Global model variable
model = None
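# Populated once by load_model_once() at startup and reused by run_video_inference() for every request.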
def load_model_once(config_path, device='cpu'):
    """Load model once at startup"""
    global model
    if model is None:
        # Create args object for model loading
        with open(config_path) as f:
            yaml = YAML(typ='safe', pure=True)
            config = yaml.load(f)
        config = {k: v['value'] for k, v in config.items()}
        args = EasyDict(config)
        args.device = device
        model = build_model(args)
        model.to(device)
        cache_file = hf_hub_download(repo_id="liangtm/referdino", filename=CHECKPOINT)
        # cache_file = 'ckpt/' + CHECKPOINT
        checkpoint = torch.load(cache_file, map_location='cpu')
        state_dict = checkpoint["model_state_dict"]
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        print("Model loaded successfully!")
    return model
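# Box helpers: convert (cx, cy, w, h) predictions to (x1, y1, x2, y2) and rescale them to pixel coordinates.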
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    b = np.stack([
        x_c - 0.5 * w,
        y_c - 0.5 * h,
        x_c + 0.5 * w,
        y_c + 0.5 * h
    ], axis=1)
    return b
def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b
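# Overlay a binary mask on an image: the mask interior is alpha-blended with `color`,
# and a dilated border of `edge_width` pixels is drawn in solid color.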
def vis_add_mask(img, mask, color, edge_width=3):
    origin_img = np.asarray(img.convert('RGB')).copy()
    color = np.array(color)
    mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8')
    mask = mask > 0.5
    # Increase the edge width using dilation
    kernel = np.ones((edge_width, edge_width), np.uint8)
    mask_dilated = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
    edge_mask = mask_dilated & ~mask
    origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
    origin_img[edge_mask] = color
    origin_img = Image.fromarray(origin_img)
    return origin_img
def run_video_inference(input_video, text_prompt, tracking_alpha=0.1, fps=15):
    """Main inference function for Gradio"""
    global model
    model.tracking_alpha = tracking_alpha
    # Set default values for other parameters
    show_box = True
    mask_edge_width = 6
    if input_video is None:
        return None, "Please upload a video file."
    if not text_prompt or text_prompt.strip() == "":
        return None, "Please enter a text prompt."
    # Process text prompt
    exp = " ".join(text_prompt.lower().split())
    # Read video
    video_frames, _, info = read_video(input_video, end_pts=DURATION, pts_unit='sec')  # (T, H, W, C)
    frame_step = max(round(info['video_fps'] / fps), 1)
    frames = []
    for i in range(0, len(video_frames), frame_step):
        source_frame = Func.to_pil_image(video_frames[i].permute(2, 0, 1))
        frames.append(source_frame)
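    # The loop above subsamples frames so the processed clip runs at roughly the requested fps,
    # which bounds how much work each request costs.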
    video_len = len(frames)
    if video_len == 0:
        return None, "No frames found in the video."
    frames_ids = [x for x in range(video_len)]
    imgs = []
    for t in frames_ids:
        img = frames[t]
        origin_w, origin_h = img.size
        imgs.append(transform(img))
    device = next(model.parameters()).device
    imgs = torch.stack(imgs, dim=0).to(device)
    samples = nested_tensor_from_videos_list(imgs[None], size_divisibility=16)
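    # size_divisibility=16 pads each frame so its height and width are multiples of 16.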
    img_h, img_w = imgs.shape[-2:]
    size = torch.as_tensor([int(img_h), int(img_w)]).to(device)
    target = {"size": size}
    start_infer = time.time()
    # Run inference
    with torch.no_grad():
        with autocast(True):
            outputs = model(samples, [exp], [target])
    end_infer = time.time()
    pred_logits = outputs["pred_logits"][0]  # [t, q, k]
    pred_masks = outputs["pred_masks"][0]    # [t, q, h, w]
    pred_boxes = outputs["pred_boxes"][0]    # [t, q, 4]
    # Select the query index according to pred_logits:
    # average scores over time and keep the single best query (object track) for all frames
    pred_scores = pred_logits.sigmoid()      # [t, q, k]
    pred_scores = pred_scores.mean(0)        # [q, k]
    max_scores, _ = pred_scores.max(-1)      # [q,]
    _, max_ind = max_scores.max(-1)          # [1,]
    max_inds = max_ind.repeat(video_len)
    pred_masks = pred_masks[range(video_len), max_inds, ...]  # [t, h, w]
    pred_masks = pred_masks.unsqueeze(0)
    pred_boxes = pred_boxes[range(video_len), max_inds].cpu().numpy()  # [t, 4]
    # Unpad and resize to the original resolution
    pred_masks = pred_masks[:, :, :img_h, :img_w].cpu()
    pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
    pred_masks = (pred_masks.sigmoid() > 0.5).squeeze(0).cpu().numpy()
    # Visualization
    color = np.array([220, 20, 60], dtype=np.uint8)
    start_save = time.time()
    save_imgs = []
    for t, img in enumerate(frames):
        # Draw mask
        img = vis_add_mask(img, pred_masks[t], color, mask_edge_width)
        draw = ImageDraw.Draw(img)
        draw_boxes = pred_boxes[t][None]
        draw_boxes = rescale_bboxes(draw_boxes, (origin_w, origin_h)).tolist()
        # Draw box if enabled
        if show_box:
            xmin, ymin, xmax, ymax = draw_boxes[0]
            draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=tuple(color), width=5)
        save_imgs.append(np.asarray(img).copy())
    # Save result video
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        iio.imwrite(tmp_file.name, save_imgs, fps=fps)
        result_video_path = tmp_file.name
    end_save = time.time()
    status = (
        f"Inference Time: {(end_infer - start_infer):.1f}s\n"
        f"Saving Time: {(end_save - start_save):.1f}s"
    )
    return result_video_path, status
def main():
    # Configuration
    config_path = "configs/ytvos_swinb.yaml"  # Update this path
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "cpu"
    # Load model at startup
    print("Loading model...")
    load_model_once(config_path, device)
    print(f"Model loaded on device: {device}")
    # Create Gradio interface
    with gr.Blocks(
        title="ReferDINO",
        css="""
        #hero { text-align: center; }
        #hero h1, #hero h2, #hero h3, #hero p {
            text-align: center !important;
            margin: 0.25rem 0;
        }
        """
    ) as demo:
        gr.Markdown(
            """
<h1>Referring Video Object Segmentation with
<a href="https://github.com/iSEE-Laboratory/ReferDINO">ReferDINO</a>
</h1>
<h3>Note that this demo runs on CPU, so the video will be trimmed to ≤6 seconds.</h3>
            """,
            elem_id="hero",
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                input_video = gr.Video(
                    label="📹 Upload Video",
                    height=300
                )
                text_prompt = gr.Textbox(
                    label="📝 Text Description",
                    placeholder="Describe the object you want to segment (e.g., 'red car', 'person in blue shirt')",
                    lines=2
                )
                run_button = gr.Button(
                    "🚀 Run Inference",
                    variant="primary",
                    size="lg"
                )
                tracking_alpha = gr.Slider(
                    label="Momentum",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    info="Controls memory updating (lower = longer memory)"
                )
                target_fps = gr.Slider(
                    label="FPS",
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    info="Controls the processing frame rate (lower = faster processing)"
                )
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="🎯 Segmentation Result",
                    height=400
                )
                status_text = gr.Textbox(
                    label="📊 Status",
                    lines=3,
                    interactive=False
                )
        # Examples (cache_examples=False, so example runs execute live)
        gr.Examples(
            examples=[
                ["dogs.mp4", "the dog is drinking water", 0.1, 10],
                ["dogs.mp4", "the dog is sleeping", 0.1, 10],
            ],
            inputs=[input_video, text_prompt, tracking_alpha, target_fps],
            outputs=[output_video],
            fn=run_video_inference,
            cache_examples=False,
            label="📋 Try these examples:"
        )
        # Event handlers
        run_button.click(
            fn=run_video_inference,
            inputs=[input_video, text_prompt, tracking_alpha, target_fps],
            outputs=[output_video, status_text],
            show_progress=True
        )
    return demo
if __name__ == "__main__":
    demo = main()
    demo.launch(
        show_api=False,
        show_error=True
    )