yiyixuxu committed on
Commit
e572140
1 Parent(s): 0f2175b

limit video size, also add code to clean up the saved videos

Browse files
Files changed (1) hide show
  1. app.py +104 -84
app.py CHANGED
@@ -17,46 +17,63 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model, preprocess = clip.load("ViT-B/32")
18
 
19
 
20
- def select_video_format(url, format_note='480p', ext='mp4'):
21
  defaults = ['480p', '360p','240p','144p']
22
  ydl_opts = {}
23
  ydl = youtube_dl.YoutubeDL(ydl_opts)
24
  info_dict = ydl.extract_info(url, download=False)
25
  formats = info_dict.get('formats', None)
 
 
 
 
26
  available_format_notes = set([f['format_note'] for f in formats])
27
- if format_note not in available_format_notes:
28
- format_note = [d for d in defaults if d in available_format_notes][0]
29
- formats = [f for f in formats if f['format_note'] == format_note and f['ext'] == ext and f['vcodec'].split('.')[0] != 'av01']
30
- format = formats[0]
31
- format_id = format.get('format_id', None)
32
- fps = format.get('fps', None)
33
- print(f'format selected: {format}')
 
 
 
 
 
34
  return(format, format_id, fps)
35
 
36
-
37
- # to-do: delete saved videos
38
- # testing aria2c
39
- def download_video(url,format_id, n_keep=10):
40
- ydl_opts = {
41
- 'format':format_id,
42
- 'cachedir': False,
43
- 'external_downloader' : 'aria2c',
44
- 'external_downloader_args' :['--max-connection-per-server=16','--dir=videos'],
45
- 'outtmpl': "videos/%(id)s.%(ext)s"}
46
- # create a directory for saved videos
47
- video_path = Path('videos')
48
  try:
49
- video_path.mkdir(parents=True)
50
  except FileExistsError:
51
  pass
52
- with youtube_dl.YoutubeDL(ydl_opts) as ydl:
53
- try:
54
- ydl.cache.remove()
55
- meta = ydl.extract_info(url)
56
- save_location = 'videos/' + meta['id'] + '.' + meta['ext']
57
- except youtube_dl.DownloadError as error:
58
- print(f'error with download_video function: {error}')
59
- return(save_location)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def process_video_parallel(video, skip_frames, dest_path, num_processes, process_number):
62
  cap = cv2.VideoCapture(video)
@@ -76,35 +93,30 @@ def process_video_parallel(video, skip_frames, dest_path, num_processes, process
76
  cap.release()
77
 
78
 
79
- def vid2frames(url, sampling_interval=1, ext='mp4'):
80
  # create folder for extracted frames - if folder exists, delete and create a new one
81
- dest_path = Path('frames')
82
  try:
83
- dest_path.mkdir(parents=True)
84
  except FileExistsError:
85
- shutil.rmtree(dest_path)
86
- dest_path.mkdir(parents=True)
87
- # figure out the format for download,
88
- # by default select 480p and .mp4
89
- format, format_id, fps = select_video_format(url, format_note='480p', ext='mp4')
90
  # download the video
91
- video = download_video(url,format_id)
92
- # calculate skip_frames
93
- try:
94
- skip_frames = int(fps * sampling_interval)
95
- except:
96
- skip_frames = int(30 * sampling_interval)
97
-
98
- print(f'video saved at: {video}, fps:{fps}, skip_frames: {skip_frames}')
99
  # extract video frames at given sampling interval with multiprocessing -
100
- n_workers = min(os.cpu_count(), 12)
101
-
102
- print(f'now extracting frames with {n_workers} process...')
103
-
104
- with Pool(n_workers) as pool:
105
- pool.map(partial(process_video_parallel, video, skip_frames, dest_path, n_workers), range(n_workers))
106
- return(skip_frames, dest_path)
107
 
 
 
 
 
 
108
 
109
 
110
  def captioned_strip(images, caption=None, times=None, rows=1):
@@ -130,41 +142,47 @@ def captioned_strip(images, caption=None, times=None, rows=1):
130
 
131
  def run_inference(url, sampling_interval, search_query, bs=526):
132
  skip_frames, path_frames= vid2frames(url,sampling_interval)
133
- filenames = sorted(path_frames.glob('*.jpg'),key=lambda p: int(p.stem))
134
- n_frames = len(filenames)
135
- bs = min(n_frames,bs)
136
- print(f"extracted {n_frames} frames, now encoding images")
137
- # encoding images one batch at a time, combine all batch outputs -> image_features, size n_frames x 512
138
- image_features = torch.empty(size=(n_frames, 512)).to(device)
139
- print(f"batch size :{bs} ; number of batches: {len(range(0, n_frames,bs))}")
140
- for b in range(0, n_frames,bs):
141
- images = []
142
- # loop through all frames in the batch -> create batch_image_input, size bs x 3 x 224 x 224
143
- for filename in filenames[b:b+bs]:
144
- image = Image.open(filename).convert("RGB")
145
- images.append(preprocess(image))
146
- batch_image_input = torch.tensor(np.stack(images)).to(device)
147
- # encoding batch_image_input -> batch_image_features
148
- with torch.no_grad():
149
- batch_image_features = model.encode_image(batch_image_input)
150
- batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)
151
- # add encoded image embedding to image_features
152
- image_features[b:b+bs] = batch_image_features
153
- # encoding search query
154
- with torch.no_grad():
155
- text_features = model.encode_text(clip.tokenize(search_query).to(device))
156
- text_features /= text_features.norm(dim=-1, keepdim=True)
157
- print(image_features.dtype, text_features.dtype)
158
- similarity = (100.0 * image_features @ text_features.T)
159
- values, indices = similarity.topk(4, dim=0)
 
 
160
 
161
- best_frames = [Image.open(filenames[ind]).convert("RGB") for ind in indices]
162
- times = [f'{datetime.timedelta(seconds = ind[0].item() * sampling_interval)}' for ind in indices]
163
- image_output = captioned_strip(best_frames,search_query, times,2)
164
- title = search_query
 
 
 
 
165
  return(title, image_output)
166
 
167
- inputs = [gr.inputs.Textbox(label="Give us the link to your youtube video!"),
168
  gr.Number(5,label='sampling interval (seconds)'),
169
  gr.inputs.Textbox(label="What do you want to search?")]
170
  outputs = [
@@ -172,6 +190,8 @@ outputs = [
172
  gr.outputs.Image(label=""),
173
  ]
174
 
 
 
175
  gr.Interface(
176
  run_inference,
177
  inputs=inputs,
 
17
  model, preprocess = clip.load("ViT-B/32")
18
 
19
 
20
def select_video_format(url, format_note='240p', ext='mp4', max_size = 50000000):
    """Pick a downloadable format for the video at *url*.

    Only formats with container *ext*, a non-AV1 video codec and a known
    ``filesize`` of at most *max_size* bytes are considered.  If the requested
    *format_note* (resolution label such as ``'240p'``) is not among them, the
    first available entry of ``['480p', '360p', '240p', '144p']`` is used.

    Args:
        url: address of the video, passed to ``YoutubeDL.extract_info``.
        format_note: preferred resolution label.
        ext: required file extension of the format.
        max_size: upper bound on the reported file size, in bytes.

    Returns:
        ``(format, format_id, fps)`` where ``format`` is the selected format
        dict, or ``(None, None, None)`` when no suitable format exists.
    """
    defaults = ['480p', '360p', '240p', '144p']
    with youtube_dl.YoutubeDL({}) as ydl:
        info_dict = ydl.extract_info(url, download=False)
    # 'formats' may be missing entirely; treat that as "nothing available"
    formats = (info_dict or {}).get('formats') or []

    # filter out formats we can't process; use .get() because individual
    # format dicts do not always carry 'filesize' / 'format_note' / 'vcodec'
    candidates = [
        f for f in formats
        if f.get('ext') == ext
        and (f.get('vcodec') or '').split('.')[0] != 'av01'
        and f.get('filesize') is not None
        and f['filesize'] <= max_size
    ]
    available_format_notes = {f.get('format_note') for f in candidates}

    if format_note not in available_format_notes:
        # fall back to the first default resolution that is available
        format_note = next((d for d in defaults if d in available_format_notes), None)

    selected = None
    if format_note is not None:
        selected = next((f for f in candidates if f.get('format_note') == format_note), None)

    if selected is None:
        # report the limit that was actually applied rather than a fixed number
        print(f"can't find suitable video formats. we are not able to process video larger than {max_size / 2**20:.0f} MiB at the moment")
        return (None, None, None)

    format_id = selected.get('format_id', None)
    fps = selected.get('fps', None)
    print(f'format selected: {selected}')
    return (selected, format_id, fps)
44
 
45
def download_video(url):
    """Download the video at *url* into the local ``videos`` folder.

    Before downloading, the folder is pruned: once it holds more than 10
    entries, every file whose stem is not one of the pinned video ids is
    deleted.  The format to fetch is chosen by ``select_video_format`` using
    that function's defaults.

    Args:
        url: address of the video to download.

    Returns:
        ``(fps, save_location)``.  ``fps`` is whatever ``select_video_format``
        reported (possibly ``None``); ``save_location`` is the path of the
        saved file as a string, or ``None`` when no suitable format was found
        or the download failed.
    """
    # create the "videos" folder for saved videos
    path_videos = Path('videos')
    path_videos.mkdir(parents=True, exist_ok=True)

    # clear the "videos" folder, keeping the pinned videos
    videos_to_keep = {'v1rkzUIL8oc', 'k4R5wZs8cxI', '0diCvgWv_ng'}
    saved = list(path_videos.glob('*'))
    if len(saved) > 10:
        for path_video in saved:
            if path_video.is_file() and path_video.stem not in videos_to_keep:
                path_video.unlink()
                print(f'removed video {path_video}')

    # select the format to download for the given video
    _, format_id, fps = select_video_format(url)

    # stays None unless the download below succeeds, so the return is always bound
    save_location = None
    if format_id is not None:
        ydl_opts = {
            'format': format_id,
            'outtmpl': "videos/%(id)s.%(ext)s",
        }
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            try:
                ydl.cache.remove()
                meta = ydl.extract_info(url)
                save_location = 'videos/' + meta['id'] + '.' + meta['ext']
            except youtube_dl.DownloadError as error:
                print(f'error with download_video function: {error}')
    return (fps, save_location)
77
 
78
  def process_video_parallel(video, skip_frames, dest_path, num_processes, process_number):
79
  cap = cv2.VideoCapture(video)
 
93
  cap.release()
94
 
95
 
96
def vid2frames(url, sampling_interval=1):
    """Download a video and extract frames from it into the ``frames`` folder.

    The ``frames`` folder is recreated empty on every call.  Frame extraction
    is fanned out over a process pool, each worker running
    ``process_video_parallel`` with its own worker index.

    Args:
        url: address of the video to download.
        sampling_interval: seconds between sampled frames.

    Returns:
        ``(skip_frames, path_frames)`` where ``skip_frames`` is
        ``fps * sampling_interval`` as an int (at least 1) and ``path_frames``
        is the ``Path`` of the frames folder; ``(None, None)`` when the video
        could not be downloaded.
    """
    # create folder for extracted frames - if folder exists, delete and create a new one
    path_frames = Path('frames')
    if path_frames.exists():
        shutil.rmtree(path_frames)
    path_frames.mkdir(parents=True)

    # download the video
    fps, video = download_video(url)
    if video is None:
        return (None, None)

    # fall back to 30 fps when the format metadata carries no frame rate
    if not fps:
        fps = 30
    # NOTE(review): presumably process_video_parallel steps through the video
    # by skip_frames; clamp to 1 so a tiny or zero interval cannot yield a
    # zero step - confirm against that function
    skip_frames = max(1, int(fps * sampling_interval))
    print(f'video saved at: {video}, fps:{fps}, skip_frames: {skip_frames}')

    # extract video frames at given sampling interval with multiprocessing;
    # os.cpu_count() may return None, so treat that as a single CPU
    n_workers = min(os.cpu_count() or 1, 12)
    print(f'now extracting frames with {n_workers} process...')

    extract = partial(process_video_parallel, video, skip_frames, path_frames, n_workers)
    with Pool(n_workers) as pool:
        pool.map(extract, range(n_workers))
    return (skip_frames, path_frames)
120
 
121
 
122
  def captioned_strip(images, caption=None, times=None, rows=1):
 
142
 
143
def run_inference(url, sampling_interval, search_query, bs=526):
    """Find the video frames that best match a text query.

    Frames are extracted with ``vid2frames``, encoded in batches with the
    module-level ``model``, and ranked against the encoded *search_query*;
    up to four best frames are rendered with ``captioned_strip``.

    Args:
        url: address of the video to search.
        sampling_interval: seconds between sampled frames.
        search_query: text to search for.
        bs: maximum number of frames encoded per batch.

    Returns:
        ``(title, image_output)``.  On success ``title`` is the query and
        ``image_output`` is the result of ``captioned_strip``; when the video
        cannot be downloaded or yields no frames, ``title`` is an error
        message and ``image_output`` is ``None``.
    """
    skip_frames, path_frames = vid2frames(url, sampling_interval)
    if path_frames is None:
        return ("not able to download video", None)

    filenames = sorted(path_frames.glob('*.jpg'), key=lambda p: int(p.stem))
    n_frames = len(filenames)
    if n_frames == 0:
        # with zero frames the batch size would be 0 and range() would raise
        return ("not able to extract frames from video", None)

    bs = min(n_frames, bs)
    print(f"extracted {n_frames} frames, now encoding images")
    # encoding images one batch at a time, combine all batch outputs -> image_features, size n_frames x 512
    image_features = torch.empty(size=(n_frames, 512), dtype=torch.float32).to(device)
    print(f"encoding images, batch size :{bs} ; number of batches: {len(range(0, n_frames,bs))}")
    for b in range(0, n_frames, bs):
        # preprocess every frame in the batch and stack them into one input tensor
        images = [preprocess(Image.open(filename).convert("RGB"))
                  for filename in filenames[b:b+bs]]
        batch_image_input = torch.tensor(np.stack(images)).to(device)
        # encoding batch_image_input -> batch_image_features
        with torch.no_grad():
            batch_image_features = model.encode_image(batch_image_input)
            batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)
        # add encoded image embedding to image_features
        image_features[b:b+bs] = batch_image_features

    # encoding search query
    print(f'encoding search query')
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(search_query).to(device)).to(dtype=torch.float32)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = (100.0 * image_features @ text_features.T)
    # topk raises when k exceeds the number of rows, so cap it for short videos
    # NOTE(review): captioned_strip is called with rows=2; confirm it copes
    # with fewer than four images
    values, indices = similarity.topk(min(4, n_frames), dim=0)
    best = [ind[0].item() for ind in indices]

    best_frames = [Image.open(filenames[i]).convert("RGB") for i in best]
    # position in the sorted frame list times the interval gives the timestamp
    times = [f'{datetime.timedelta(seconds = i * sampling_interval)}' for i in best]
    image_output = captioned_strip(best_frames, search_query, times, 2)
    title = search_query
    print('task complete')
    return (title, image_output)
184
 
185
# Input widgets for the interface, in the order run_inference takes its
# arguments: video link, sampling interval in seconds, search text.
inputs = [
    gr.inputs.Textbox(label="Give us the link to your youtube video! (note that downloading might be slow, e.g. it will take a few minutes to process a 10-minute video)"),
    gr.Number(5, label='sampling interval (seconds)'),
    gr.inputs.Textbox(label="What do you want to search?"),
]
188
  outputs = [
 
190
  gr.outputs.Image(label=""),
191
  ]
192
 
193
+ example_videos = ['v1rkzUIL8oc', 'k4R5wZs8cxI','0diCvgWv_ng']
194
+
195
  gr.Interface(
196
  run_inference,
197
  inputs=inputs,