MTTR committed
Commit 260d870
Parent: 9c0f938

Upload app.py

Files changed (1)
  1. app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
# -*- coding: utf-8 -*-
"""
End-to-End Referring Video Object Segmentation with Multimodal Transformers

This app provides a (limited) hands-on demonstration of MTTR.

Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment the referred object instance throughout the video.

### Disclaimer
This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used whatsoever.
Hence, the model's performance may be limited, especially on instances from unseen categories.

Additionally, slow processing times may be encountered, depending on the input clip's length and/or resolution, and due to limited computational resources.

Finally, we emphasize that this demonstration is intended for academic purposes only. We do not take any responsibility for how the created content is used or distributed, and we discourage users from infringing the copyright of YouTube videos. <br><br>

And now, with all formalities aside, let's begin!
"""

import gradio as gr
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from tqdm import trange, tqdm

class NestedTensor(object):
    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask

def nested_tensor_from_videos_list(videos_list):
    def _max_by_axis(the_list):
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    max_size = _max_by_axis([list(img.shape) for img in videos_list])
    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))
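
# For illustration (hypothetical shapes): given two clips of shape (8, 3, 320, 480) and
# (10, 3, 360, 640), nested_tensor_from_videos_list zero-pads both to (10, 3, 360, 640)
# and stacks them, returning tensors of shape (10, 2, 3, 360, 640) (time first, after the
# transpose) together with a boolean mask of shape (10, 2, 360, 640) that is False on real
# pixels and True on padding.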

def apply_mask(image, mask, color, transparency=0.7):
    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
    mask = mask * transparency
    color_matrix = np.ones(image.shape, dtype=float) * color  # np.float is deprecated; use the builtin float
    out_image = color_matrix * mask + image * (1.0 - mask)
    return out_image
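
# With the default transparency of 0.7, each masked pixel ends up as
# 0.7 * color + 0.3 * original pixel, while pixels outside the mask are left unchanged.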

def process(text_query, full_video_path):
    start_pt, end_pt = 0, 10
    input_clip_path = '/tmp/input.mp4'
    # extract the relevant subclip:
    with VideoFileClip(full_video_path) as video:
        subclip = video.subclip(start_pt, end_pt)
        subclip.write_videofile(input_clip_path)

    checkpoint_path = './refer-youtube-vos_window-12.pth.tar'
    model, postprocessor = torch.hub.load('Randl/MTTR:main', 'mttr_refer_youtube_vos', get_weights=False)

    model_state_dict = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in model_state_dict.keys():
        model_state_dict = model_state_dict['model_state_dict']
    model.load_state_dict(model_state_dict, strict=True)

    text_queries = [text_query]
    window_length = 24  # length of window during inference
    window_overlap = 6  # overlap (in frames) between consecutive windows

    with torch.inference_mode():
        # read and preprocess the video clip:
        video, audio, meta = torchvision.io.read_video(filename=input_clip_path)
        video = rearrange(video, 't h w c -> t c h w')
        input_video = F.resize(video, size=360, max_size=640)
        input_video = input_video.to(torch.float).div_(255)
        input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

        # partition the clip into overlapping windows of frames:
        windows = [input_video[i:i+window_length] for i in range(0, len(input_video), window_length - window_overlap)]
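        # For example (hypothetical 240-frame clip): with window_length=24 and window_overlap=6
        # the window start indices are 0, 18, 36, ..., so each window re-processes the last 6
        # frames of the previous one; the overlapping predictions are simply overwritten by the
        # later window when the masks are written into pred_masks below.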
        # clean up the text queries:
        text_queries = [" ".join(q.lower().split()) for q in text_queries]

        pred_masks_per_query = []
        t, _, h, w = video.shape
        for text_query in tqdm(text_queries, desc='text queries'):
            pred_masks = torch.zeros(size=(t, 1, h, w))
            for i, window in enumerate(tqdm(windows, desc='windows')):
                window = nested_tensor_from_videos_list([window])
                valid_indices = torch.arange(len(window.tensors))
                outputs = model(window, valid_indices, [text_query])
                window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']
                win_start_idx = i * (window_length - window_overlap)
                pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
            pred_masks_per_query.append(pred_masks)

    # Finally, we apply the generated instance masks and their corresponding text queries on the input clip for visualization:

    # RGB colors for instance masks:
    light_blue = (41, 171, 226)
    purple = (237, 30, 121)
    dark_green = (35, 161, 90)
    orange = (255, 148, 59)
    colors = np.array([light_blue, purple, dark_green, orange])

    # width (in pixels) of the black strip above the video on which the text queries will be displayed:
    text_border_height_per_query = 40

    video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0
    # del video
    pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()
    masked_video = []
    for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
        # apply the masks:
        for inst_mask, color in zip(frame_masks, colors):
            vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
        vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))
        # visualize the text queries:
        vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries) * text_border_height_per_query, 0, 0))
        W, H = vid_frame.size
        draw = ImageDraw.Draw(vid_frame)
        font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)
        for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1):
            w, h = draw.textsize(text_query, font=font)
            draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 8),
                      text_query, fill=tuple(color) + (255,), font=font)
        masked_video.append(np.array(vid_frame))

    # generate and save the output clip:
    output_clip_path = '/tmp/output_clip.mp4'
    clip = ImageSequenceClip(sequence=masked_video, fps=meta['video_fps'])
    clip = clip.set_audio(AudioFileClip(input_clip_path))
    clip.write_videofile(output_clip_path, fps=meta['video_fps'], audio=True)
    del masked_video

    return output_clip_path


title = "Interactive demo: MTTR"

description = "To use it, upload a video file. Right now we only suggest using .mp4 files."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.14821'>End-to-End Referring Video Object Segmentation with Multimodal Transformers</a> | <a href='https://github.com/mttr2021/MTTR'>Github Repo</a></p>"

iface = gr.Interface(fn=process,
                     inputs=[gr.inputs.Textbox(label="text query"),
                             gr.inputs.Video(label="Input video. First 10 seconds of the video are used.")],
                     outputs='video',
                     title=title,
                     description=description,
                     enable_queue=True,
                     # examples=[[420, 'skate_jump.mp4']],  # Not working for some reason...
                     article=article)

iface.launch(debug=True)
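
# `enable_queue=True` asks Gradio to queue incoming requests instead of letting long
# predictions hit the request timeout, which matters here because a single call to
# `process` can take several minutes (an assumption about the hosting hardware);
# `launch(debug=True)` keeps the app running in the foreground so errors surface in the
# console output.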