import gradio as gr
import os
import warnings

# Build the MultiScaleDeformableAttention extension on first launch if the
# prebuilt shared object (compiled for CPython 3.9 on x86-64 Linux) is missing.
so_path = "models/GroundingDINO/ops/MultiScaleDeformableAttention.cpython-39-x86_64-linux-gnu.so"
if not os.path.exists(so_path):
    os.system("python models/GroundingDINO/ops/setup.py build_ext develop --user")

import torchvision.transforms as T
from models import build_model
import torch
import misc as utils
import numpy as np
import torch.nn.functional as F
from torchvision.io import read_video
import torchvision.transforms.functional as Func
from ruamel.yaml import YAML
from easydict import EasyDict
from misc import nested_tensor_from_videos_list
from torch.cuda.amp import autocast
from PIL import Image, ImageDraw
import imageio.v3 as iio
import cv2
import tempfile
import argparse
import time
from huggingface_hub import hf_hub_download

os.environ["TOKENIZERS_PARALLELISM"] = "false"

DURATION = 6  # maximum number of seconds read from the uploaded video
CHECKPOINT = "ryt_mevis_swinb.pth"

transform = T.Compose([
    T.Resize(360),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
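
# For reference: T.Resize(360) scales the shorter side to 360 px while keeping
# the aspect ratio (e.g. a 1280x720 frame is processed at 640x360), and the
# mean/std above are the standard ImageNet normalization statistics.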

color_list = utils.colormap()
color_list = color_list.astype('uint8').tolist()

model = None


def load_model_once(config_path, device='cpu'):
    """Load model once at startup"""
    global model
    if model is None:
        with open(config_path) as f:
            yaml = YAML(typ='safe', pure=True)
            config = yaml.load(f)
            config = {k: v['value'] for k, v in config.items()}

        args = EasyDict(config)
        args.device = device

        model = build_model(args)
        model.to(device)
        cache_file = hf_hub_download(repo_id="liangtm/referdino", filename=CHECKPOINT)

        checkpoint = torch.load(cache_file, map_location='cpu')
        state_dict = checkpoint["model_state_dict"]
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        print("Model loaded successfully!")
    return model
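
# The `v['value']` flattening above assumes a W&B-style config dump where each
# top-level key nests its value under a `value` field, e.g. (hypothetical keys
# for illustration only):
#
#   backbone:
#     value: swin_b_p4w7
#   num_queries:
#     value: 300
#
# which becomes {"backbone": "swin_b_p4w7", "num_queries": 300}.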


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    b = np.stack([
        x_c - 0.5 * w,
        y_c - 0.5 * h,
        x_c + 0.5 * w,
        y_c + 0.5 * h
    ], axis=1)
    return b
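
# Worked example: a normalized center-format box (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4)
# converts to corner format (xmin, ymin, xmax, ymax) = (0.4, 0.3, 0.6, 0.7):
#   box_cxcywh_to_xyxy(np.array([[0.5, 0.5, 0.2, 0.4]]))  # -> [[0.4, 0.3, 0.6, 0.7]]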


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b
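
# Continuing the example: the (0.4, 0.3, 0.6, 0.7) box above, rescaled to a
# 640x360 frame, becomes (256.0, 108.0, 384.0, 252.0) in pixel coordinates.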


def vis_add_mask(img, mask, color, edge_width=3):
    origin_img = np.asarray(img.convert('RGB')).copy()
    color = np.array(color)

    mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8')
    mask = mask > 0.5

    # Dilate the binary mask and subtract the original to get a thin boundary
    # band: the band is painted solid while the interior is alpha-blended.
    kernel = np.ones((edge_width, edge_width), np.uint8)
    mask_dilated = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
    edge_mask = mask_dilated & ~mask

    origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
    origin_img[edge_mask] = color
    origin_img = Image.fromarray(origin_img)
    return origin_img
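
# A minimal usage sketch (synthetic inputs, for illustration only): overlay a
# square mask on a blank 64x64 frame with a 3-pixel crimson edge.
#   img = Image.new('RGB', (64, 64))
#   mask = np.zeros((64, 64), dtype=np.float32); mask[16:48, 16:48] = 1.0
#   out = vis_add_mask(img, mask, [220, 20, 60])  # returns a PIL.Image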


def run_video_inference(input_video, text_prompt, tracking_alpha=0.1, fps=15):
    """Main inference function for Gradio"""
    global model
    model.tracking_alpha = tracking_alpha

    show_box = True
    mask_edge_width = 6

    if input_video is None:
        return None, "Please upload a video file."

    if not text_prompt or text_prompt.strip() == "":
        return None, "Please enter a text prompt."

    # Normalize the expression: lowercase and collapse repeated whitespace.
    exp = " ".join(text_prompt.lower().split())

    # Read at most DURATION seconds of video, then subsample to roughly `fps`.
    video_frames, _, info = read_video(input_video, end_pts=DURATION, pts_unit='sec')

    frame_step = max(round(info['video_fps'] / fps), 1)

    frames = []
    for i in range(0, len(video_frames), frame_step):
        source_frame = Func.to_pil_image(video_frames[i].permute(2, 0, 1))
        frames.append(source_frame)

    video_len = len(frames)
    if video_len == 0:
        return None, "No frames found in the video."

    frames_ids = list(range(video_len))
    imgs = []
    for t in frames_ids:
        img = frames[t]
        origin_w, origin_h = img.size  # every frame shares the original resolution
        imgs.append(transform(img))

    device = next(model.parameters()).device
    imgs = torch.stack(imgs, dim=0).to(device)
    samples = nested_tensor_from_videos_list(imgs[None], size_divisibility=16)
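    # The helper above pads the (1, T, C, H, W) batch so that H and W are
    # divisible by 16 and records the valid region in a mask; the un-padded
    # size is kept in `target` below so predictions can be cropped back.
    # (Description based on the usual DETR-style NestedTensor helpers.)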

    img_h, img_w = imgs.shape[-2:]
    size = torch.as_tensor([int(img_h), int(img_w)]).to(device)
    target = {"size": size}

    start_infer = time.time()
    with torch.no_grad():
        with autocast(True):
            outputs = model(samples, [exp], [target])
    end_infer = time.time()

    pred_logits = outputs["pred_logits"][0]
    pred_masks = outputs["pred_masks"][0]
    pred_boxes = outputs["pred_boxes"][0]

    # Pick the single best object query for the whole clip.
    pred_scores = pred_logits.sigmoid()
    pred_scores = pred_scores.mean(0)
    max_scores, _ = pred_scores.max(-1)
    _, max_ind = max_scores.max(-1)
    max_inds = max_ind.repeat(video_len)
    pred_masks = pred_masks[range(video_len), max_inds, ...]
    pred_masks = pred_masks.unsqueeze(0)
    pred_boxes = pred_boxes[range(video_len), max_inds].cpu().numpy()
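
    # The selection above assumes per-video outputs of shape (T, Q, ...):
    # scores are sigmoided, averaged over the T frames, reduced over the last
    # (class/token) dimension, and the argmax over the Q queries is reused for
    # every frame, so one query tracks the referred object through the clip.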

    # Crop away the divisibility padding, upsample to the original resolution,
    # and binarize at 0.5.
    pred_masks = pred_masks[:, :, :img_h, :img_w].cpu()
    pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
    pred_masks = (pred_masks.sigmoid() > 0.5).squeeze(0).cpu().numpy()

    # Crimson overlay color for both the mask and the box.
    color = np.array([220, 20, 60], dtype=np.uint8)

    start_save = time.time()
    save_imgs = []
    for t, img in enumerate(frames):
        img = vis_add_mask(img, pred_masks[t], color, mask_edge_width)

        draw = ImageDraw.Draw(img)
        draw_boxes = pred_boxes[t][None]
        draw_boxes = rescale_bboxes(draw_boxes, (origin_w, origin_h)).tolist()

        if show_box:
            xmin, ymin, xmax, ymax = draw_boxes[0]
            draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=tuple(color), width=5)

        save_imgs.append(np.asarray(img).copy())

    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        iio.imwrite(tmp_file.name, save_imgs, fps=fps)
        result_video_path = tmp_file.name

    end_save = time.time()

    status = (
        f"Inference Time: {(end_infer - start_infer):.1f}s\n"
        f"Saving Time: {(end_save - start_save):.1f}s"
    )
    return result_video_path, status
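
# A quick smoke test outside Gradio (hypothetical local paths, for
# illustration; the model must be loaded first):
#   load_model_once("configs/ytvos_swinb.yaml", "cpu")
#   path, status = run_video_inference("dogs.mp4", "the dog is drinking water")
#   print(status, path)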


def main():
    config_path = "configs/ytvos_swinb.yaml"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading model...")
    load_model_once(config_path, device)
    print(f"Model loaded on device: {device}")

    with gr.Blocks(
        title="ReferDINO",
        css="""
        #hero { text-align: center; }
        #hero h1, #hero h2, #hero h3, #hero p {
            text-align: center !important;
            margin: 0.25rem 0;
        }
        """
    ) as demo:
        gr.Markdown(
            """
            <h1>Referring Video Object Segmentation with
            <a href="https://github.com/iSEE-Laboratory/ReferDINO">ReferDINO</a>
            </h1>
            <h3>Note that this demo runs on CPU, so the video will be trimmed to ≤6 seconds.</h3>
            """,
            elem_id="hero",
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="📹 Upload Video",
                    height=300
                )

                text_prompt = gr.Textbox(
                    label="📝 Text Description",
                    placeholder="Describe the object you want to segment (e.g., 'red car', 'person in blue shirt')",
                    lines=2
                )

                run_button = gr.Button(
                    "🚀 Run Inference",
                    variant="primary",
                    size="lg"
                )

                tracking_alpha = gr.Slider(
                    label="Momentum",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    info="Controls the memory update rate (lower = longer memory)."
                )

                target_fps = gr.Slider(
                    label="FPS",
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    info="Controls the processing frame rate (lower = faster processing)."
                )

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="🎯 Segmentation Result",
                    height=400
                )

                status_text = gr.Textbox(
                    label="📊 Status",
                    lines=3,
                    interactive=False
                )

        gr.Examples(
            examples=[
                ["dogs.mp4", "the dog is drinking water", 0.1, 10],
                ["dogs.mp4", "the dog is sleeping", 0.1, 10],
            ],
            inputs=[input_video, text_prompt, tracking_alpha, target_fps],
            # run_video_inference returns (video, status), so list both outputs
            outputs=[output_video, status_text],
            fn=run_video_inference,
            cache_examples=False,
            label="📋 Try these examples:"
        )

        run_button.click(
            fn=run_video_inference,
            inputs=[input_video, text_prompt, tracking_alpha, target_fps],
            outputs=[output_video, status_text],
            show_progress=True
        )

    return demo


if __name__ == "__main__":
    demo = main()
    demo.launch(
        show_api=False,
        show_error=True
    )