# -*- coding: utf-8 -*-
"""
End-to-End Referring Video Object Segmentation with Multimodal Transformers

This notebook provides a (limited) hands-on demonstration of MTTR.
Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment the referred object instance throughout the video.

### Disclaimer
This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used whatsoever.
Hence, the model's performance may be limited, especially on instances from unseen categories.
Additionally, slow processing times may be encountered, depending on the input clip length and/or resolution, and due to HuggingFace's limited computational resources (no GPU acceleration unfortunately).
Finally, we emphasize that this demonstration is intended to be used for academic purposes only. We do not take any responsibility for how the created content is used or distributed.
"""
import gradio as gr
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from tqdm import trange, tqdm


class NestedTensor(object):
    """Container pairing a batch of padded video tensors with their padding masks."""
    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask


def nested_tensor_from_videos_list(videos_list):
    def _max_by_axis(the_list):
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    # pad every video in the batch to the largest temporal and spatial size in the batch:
    max_size = _max_by_axis([list(img.shape) for img in videos_list])
    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))
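# The returned NestedTensor holds the padded clips as a (T, B, C, H, W) tensor and the
# padding masks as a (T, B, H, W) boolean tensor, in which True marks padded (invalid) pixels.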


def apply_mask(image, mask, color, transparency=0.7):
    # blend a single-channel mask (values in [0, 1]) onto an RGB image (values in [0, 1]):
    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
    mask = mask * transparency
    color_matrix = np.ones(image.shape, dtype=float) * color
    out_image = color_matrix * mask + image * (1.0 - mask)
    return out_image
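# Example (illustrative): with `frame` as an (H, W, 3) float array in [0, 1] and `m` as an
# (H, W) mask in [0, 1], apply_mask(frame, m, np.array([1.0, 0.0, 0.0])) tints the masked
# region red while leaving the rest of the frame unchanged.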


def process(text_query, full_video_path):
    start_pt, end_pt = 0, 10
    input_clip_path = '/tmp/input.mp4'

    # extract the relevant subclip (the first 10 seconds of the uploaded video):
    with VideoFileClip(full_video_path) as video:
        subclip = video.subclip(start_pt, end_pt)
        subclip.write_videofile(input_clip_path)

    # load the model architecture from torch.hub and restore the trained weights:
    checkpoint_path = './refer-youtube-vos_window-12.pth.tar'
    model, postprocessor = torch.hub.load('Randl/MTTR:main', 'mttr_refer_youtube_vos', get_weights=False)
    model_state_dict = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in model_state_dict.keys():
        model_state_dict = model_state_dict['model_state_dict']
    model.load_state_dict(model_state_dict, strict=True)

    text_queries = [text_query]
    window_length = 24  # length of window during inference
    window_overlap = 6  # overlap (in frames) between consecutive windows
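    # consecutive windows therefore start window_length - window_overlap = 18 frames apart;
    # predictions for frames in the overlap are overwritten by the later window (see below).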

    with torch.inference_mode():
        # read and preprocess the video clip:
        video, audio, meta = torchvision.io.read_video(filename=input_clip_path)
        video = rearrange(video, 't h w c -> t c h w')
        input_video = F.resize(video, size=360, max_size=640)
        input_video = input_video.to(torch.float).div_(255)
        input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

        # partition the clip into overlapping windows of frames:
        windows = [input_video[i:i + window_length] for i in range(0, len(input_video), window_length - window_overlap)]

        # clean up the text queries:
        text_queries = [" ".join(q.lower().split()) for q in text_queries]

        pred_masks_per_query = []
        t, _, h, w = video.shape
        for text_query in tqdm(text_queries, desc='text queries'):
            pred_masks = torch.zeros(size=(t, 1, h, w))
            for i, window in enumerate(tqdm(windows, desc='windows')):
                window = nested_tensor_from_videos_list([window])
                valid_indices = torch.arange(len(window.tensors))
                outputs = model(window, valid_indices, [text_query])
                window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']
                win_start_idx = i * (window_length - window_overlap)
                pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
            pred_masks_per_query.append(pred_masks)
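
    # pred_masks_per_query now holds one (t, 1, h, w) mask tensor per text query
    # (a single query in this demo).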
"""Finally, we apply the generated instance masks and their corresponding text queries on the input clip for visualization:""" | |
# RGB colors for instance masks: | |
light_blue = (41, 171, 226) | |
purple = (237, 30, 121) | |
dark_green = (35, 161, 90) | |
orange = (255, 148, 59) | |
colors = np.array([light_blue, purple, dark_green, orange]) | |
# width (in pixels) of the black strip above the video on which the text queries will be displayed: | |
text_border_height_per_query = 40 | |
video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0 | |
# del video | |
pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy() | |
    masked_video = []
    for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
        # apply the masks:
        for inst_mask, color in zip(frame_masks, colors):
            vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
        vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))

        # visualize the text queries:
        vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries) * text_border_height_per_query, 0, 0))
        W, H = vid_frame.size
        draw = ImageDraw.Draw(vid_frame)
        font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)
        for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1):
            w, h = draw.textsize(text_query, font=font)
            draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 8),
                      text_query, fill=tuple(color) + (255,), font=font)
        masked_video.append(np.array(vid_frame))

    # generate and save the output clip:
    output_clip_path = '/tmp/output_clip.mp4'
    clip = ImageSequenceClip(sequence=masked_video, fps=meta['video_fps'])
    clip = clip.set_audio(AudioFileClip(input_clip_path))
    clip.write_videofile(output_clip_path, fps=meta['video_fps'], audio=True)
    del masked_video
    return output_clip_path
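

# Illustrative direct call (the query and path below are hypothetical; any short .mp4 works):
# output_path = process('a person riding a bike', 'some_clip.mp4')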
title = "End-to-End Referring Video Object Segmentation with Multimodal Transformers - Interactive Demo" | |
description = "This notebook provides a (limited) hands-on demonstration of MTTR.\n Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment the referred object instance throughout the video. To use it, upload an .mp4 video file and input a text query which describes one of the instances in that video. \n Disclaimer: \n This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used whatsoever. Hence, the model's performance may be limited, especially on instances from unseen categories. Additionally, slow processing times may be encountered, depending on the input clip length and/or resolution, and due to HuggingFace's limited computational resources (no GPU acceleration unfortunately).Finally, we emphasize that this demonstration is intended to be used for academic purposes only. We do not take any responsibility for how the created content is used or distributed. " | |
article = "<p style='text-align: center'><a href='https://github.com/mttr2021/MTTR'>Github Repo</a></p>" | |

iface = gr.Interface(fn=process,
                     inputs=[gr.inputs.Textbox(label="text query"),
                             gr.inputs.Video(label="Input video. First 10 seconds of the video are used.")],
                     outputs='video',
                     title=title,
                     description=description,
                     enable_queue=True,
                     examples=[['a black bike used to perform tricks', 'bike_tricks_1.mp4']],  # Not working for some reason...
                     article=article)

iface.launch(debug=True)