# -*- coding: utf-8 -*-
"""Gradio demo: segment a text-referred object instance in a short video clip with MTTR."""
import functools

import gradio as gr
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from tqdm import trange, tqdm


class NestedTensor(object):
    """Container pairing a padded tensor batch with its padding mask."""

    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask


def nested_tensor_from_videos_list(videos_list):
    """Zero-pad a list of videos to a common shape and stack them into one batch.

    Args:
        videos_list: non-empty list of 4-D video tensors (unpacked below as
            ``t, c, h, w``); all share the dtype/device of the first one.

    Returns:
        NestedTensor whose ``tensors`` has shape ``(t, b, c, h, w)`` (time first)
        and whose ``mask`` has shape ``(t, b, h, w)``, True on padded positions
        and False where real video data was copied in.
    """
    def _max_by_axis(the_list):
        # Element-wise maximum over a list of shape lists.
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    max_size = _max_by_axis([list(img.shape) for img in videos_list])
    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))


def apply_mask(image, mask, color, transparency=0.7):
    """Alpha-blend a solid ``color`` over the pixels selected by ``mask``.

    Args:
        image: float array of shape (h, w, 3); the caller passes values in [0, 1].
        mask: array of shape (h, w) weighting how strongly each pixel is tinted.
        color: 3-element RGB color, on the same scale as ``image``.
        transparency: maximum blend weight given to ``color``.

    Returns:
        A new float array of shape (h, w, 3).
    """
    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
    mask = mask * transparency
    # ``np.float`` was only an alias of the builtin ``float`` and was removed in NumPy 1.24.
    color_matrix = np.ones(image.shape, dtype=float) * color
    out_image = color_matrix * mask + image * (1.0 - mask)
    return out_image


CHECKPOINT_PATH = './refer-youtube-vos_window-12.pth.tar'

# RGB colors for instance masks (one per text query):
light_blue = (41, 171, 226)
purple = (237, 30, 121)
dark_green = (35, 161, 90)
orange = (255, 148, 59)


@functools.lru_cache(maxsize=1)
def _load_model(checkpoint_path=CHECKPOINT_PATH):
    """Build MTTR via torch.hub, load the local checkpoint, and cache the result.

    Cached so the (slow) hub build + checkpoint load happens once per process
    rather than once per request.

    Returns:
        ``(model, postprocessor)`` as returned by the MTTR hub entry point.
    """
    model, postprocessor = torch.hub.load('mttr2021/MTTR:main', 'mttr_refer_youtube_vos', get_weights=False)
    model_state_dict = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in model_state_dict:
        model_state_dict = model_state_dict['model_state_dict']
    model.load_state_dict(model_state_dict, strict=True)
    # nn.Module.eval() disables training-only behaviour (e.g. dropout, if present);
    # inference_mode alone does not.
    model.eval()
    return model, postprocessor


def _predict_masks(model, postprocessor, input_clip_path, text_queries, window_length=24, window_overlap=6):
    """Run MTTR over overlapping frame windows of a clip for each text query.

    Args:
        window_length: number of frames per inference window.
        window_overlap: number of frames shared by consecutive windows.

    Returns:
        ``(video, pred_masks_per_query, meta)``: the decoded uint8 video as
        ``(t, c, h, w)``, one ``(t, 1, h, w)`` mask tensor per query, and the
        metadata dict returned by ``torchvision.io.read_video``.
    """
    window_stride = window_length - window_overlap
    with torch.inference_mode():
        # read and preprocess the video clip:
        video, _, meta = torchvision.io.read_video(filename=input_clip_path)
        video = rearrange(video, 't h w c -> t c h w')
        input_video = F.resize(video, size=360, max_size=640)
        input_video = input_video.to(torch.float).div_(255)
        input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}
        # partition the clip into overlapping windows of frames:
        windows = [input_video[i:i + window_length] for i in range(0, len(input_video), window_stride)]
        pred_masks_per_query = []
        t, _, h, w = video.shape
        for text_query in tqdm(text_queries, desc='text queries'):
            pred_masks = torch.zeros(size=(t, 1, h, w))
            for i, window in enumerate(tqdm(windows, desc='windows')):
                window = nested_tensor_from_videos_list([window])
                valid_indices = torch.arange(len(window.tensors))
                outputs = model(window, valid_indices, [text_query])
                window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']
                win_start_idx = i * window_stride
                # Later windows overwrite the overlapping frames of earlier ones.
                pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
            pred_masks_per_query.append(pred_masks)
    return video, pred_masks_per_query, meta


def _text_size(draw, text, font):
    """Return the ``(width, height)`` that ``text`` occupies when drawn at the origin.

    ``ImageDraw.textsize`` was removed in Pillow 10; ``textbbox`` replaces it.
    The right/bottom bbox edges (which include the font's offset from the
    origin) approximate what the old ``textsize`` reported.
    """
    if hasattr(draw, 'textbbox'):
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    return draw.textsize(text, font=font)  # Pillow < 8.0


def _render_masked_video(video, pred_masks_per_query, text_queries, fps, audio, output_clip_path):
    """Overlay the predicted masks and query captions on the clip and write it to disk.

    Returns:
        ``output_clip_path``.
    """
    colors = np.array([light_blue, purple, dark_green, orange])
    # height (in pixels) of the black strip above the video, per displayed text query:
    text_border_height_per_query = 40
    video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0
    pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()
    # Loaded once per clip rather than once per frame.
    font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)
    masked_video = []
    for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
        # apply the masks:
        for inst_mask, color in zip(frame_masks, colors):
            vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
        vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))
        # visualize the text queries:
        vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries) * text_border_height_per_query, 0, 0))
        W, _ = vid_frame.size
        draw = ImageDraw.Draw(vid_frame)
        for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1):
            w, h = _text_size(draw, text_query, font)
            draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 8), text_query,
                      fill=tuple(color) + (255,), font=font)
        masked_video.append(np.array(vid_frame))
    # generate and save the output clip:
    clip = ImageSequenceClip(sequence=masked_video, fps=fps)
    if audio:  # attach audio if original subclip had audio
        clip = clip.set_audio(audio)
    clip.write_videofile(output_clip_path, fps=fps, audio=True)
    return output_clip_path


def process(text_query, full_video_path):
    """Segment the object described by ``text_query`` in the first 10 s of a video.

    Args:
        text_query: free-text description of the object instance to segment.
        full_video_path: path to the uploaded input video.

    Returns:
        Path of the rendered output clip (masks + caption overlaid).
    """
    start_pt, max_end_pt = 0, 10
    input_clip_path = '/tmp/input.mp4'
    output_clip_path = '/tmp/output_clip.mp4'
    full_video = VideoFileClip(full_video_path)
    try:
        # extract the relevant subclip:
        subclip = full_video.subclip(start_pt, min(full_video.duration, max_end_pt))
        subclip.write_videofile(input_clip_path)
        model, postprocessor = _load_model()
        # clean up the text queries (lowercase, collapse whitespace):
        text_queries = [" ".join(q.lower().split()) for q in [text_query]]
        video, pred_masks_per_query, meta = _predict_masks(model, postprocessor, input_clip_path, text_queries)
        # Finally, apply the generated instance masks and their corresponding
        # text queries on the input clip for visualization.
        return _render_masked_video(video, pred_masks_per_query, text_queries, meta['video_fps'],
                                    subclip.audio, output_clip_path)
    finally:
        # Release the ffmpeg readers held by the source clip (and its subclip/audio).
        full_video.close()


title = "MTTR - Interactive Demo"
description = (
    "Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment "
    "the referred object instance throughout the video. Select one of the examples below and click 'submit'. "
    "Alternatively, try using your own input by uploading a short .mp4 video file and entering a short text query "
    "which describes one of the object instances in that video. Note - Due to HuggingFace's limited computational "
    "resources (no GPU acceleration unfortunately), processing times may take several minutes, so please be patient. "
    "Check out our Colab notebook (link below) for much faster processing times (GPU acceleration available) and "
    "more options."
)
article = (
    "Check out [MTTR's GitHub page](https://github.com/mttr2021/MTTR) for more info about this project.\n"
    "Also, check out our interactive "
    "[Colab notebook](https://colab.research.google.com/drive/12p0jpSx3pJNfZk-y_L44yeHZlhsKVra-?usp=sharing) "
    "for **much faster** processing (GPU accelerated) and more options!\n"
    "**Disclaimer:**\n"
    "This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on "
    "Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used "
    "whatsoever. Hence, the model's performance may be limited, especially on instances from unseen categories.\n"
    "Additionally, slow processing times may be encountered due to HuggingFace's limited computational resources "
    "(no GPU acceleration unfortunately), and depending on the input clip length and/or resolution.\n"
    "Finally, we emphasize that this demonstration is intended to be used for academic purposes only. We do not take "
    "any responsibility for how the created content is used or distributed."
)
examples = [['guy in white shirt performing tricks on a bike', 'bike_tricks_2.mp4'],
            ['a man riding a surfboard', 'surfing.mp4'],
            ['a guy performing tricks on a skateboard', 'skateboarding.mp4'],
            ['man in red shirt playing tennis', 'tennis.mp4'],
            ['brown and black dog playing', 'dogs_playing_1.mp4'],
            ['a dog to the left playing with a toy', 'dogs_playing_2.mp4'],
            ['person in blue riding a bike', 'blue_biker_riding.mp4'],
            ['a dog to the right', 'dog_and_cat.mp4'],
            ['a person hugging a dog', 'girl_hugging_dog.mp4'],
            ['a black bike used to perform tricks', 'bike_tricks_1.mp4']]
iface = gr.Interface(fn=process,
                     inputs=[gr.inputs.Textbox(label="text query"),
                             gr.inputs.Video(label="input video - first 10 seconds are used")],
                     outputs='video',
                     title=title,
                     description=description,
                     enable_queue=True,
                     examples=examples,
                     examples_per_page=4,
                     allow_flagging=False,
                     article=article)
iface.launch(debug=True)