shaoyent commited on
Commit
d37aaef
1 Parent(s): d597aa1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  import torch
9
  from torch.nn.utils.rnn import pad_sequence
10
  from transformers import BridgeTowerProcessor
 
11
 
12
  from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
13
 
@@ -134,7 +135,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
134
  batch_list = []
135
  vtt = webvtt.read(subtitles)
136
 
137
- for idx, caption in enumerate(progress.tqdm(webvtt.read(subtitles), total=vtt.total_length, desc="Generating embeddings")):
138
  st_time = str2time(caption.start)
139
  ed_time = str2time(caption.end)
140
 
@@ -286,7 +287,7 @@ def get_video_id_from_url(video_url):
286
  return None
287
 
288
 
289
- def process(video_url, text_query, progress=gr.Progress()):
290
  tmp_dir = os.environ.get('TMPDIR', '/tmp')
291
  video_id = get_video_id_from_url(video_url)
292
  output_dir = os.path.join(tmp_dir, video_id)
@@ -298,7 +299,7 @@ def process(video_url, text_query, progress=gr.Progress()):
298
  output=output_dir,
299
  expanded=False,
300
  batch_size=8,
301
- progress=gr.Progress(),
302
  )
303
  frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
304
  return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
 
8
  import torch
9
  from torch.nn.utils.rnn import pad_sequence
10
  from transformers import BridgeTowerProcessor
11
+ from tqdm import tqdm
12
 
13
  from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
14
 
 
135
  batch_list = []
136
  vtt = webvtt.read(subtitles)
137
 
138
+ for idx, caption in enumerate(tqdm(vtt, total=vtt.total_length, desc="Generating embeddings")):
139
  st_time = str2time(caption.start)
140
  ed_time = str2time(caption.end)
141
 
 
287
  return None
288
 
289
 
290
+ def process(video_url, text_query, progress=gr.Progress(track_tqdm=True)):
291
  tmp_dir = os.environ.get('TMPDIR', '/tmp')
292
  video_id = get_video_id_from_url(video_url)
293
  output_dir = os.path.join(tmp_dir, video_id)
 
299
  output=output_dir,
300
  expanded=False,
301
  batch_size=8,
302
+ progress=progress,
303
  )
304
  frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
305
  return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]