yiyixuxu committed on
Commit
c97026d
1 Parent(s): 12f763a

changed sampling algorithm

Files changed (1): app.py +22 -24
app.py CHANGED
@@ -40,30 +40,21 @@ def download_video(url,format_id):
     save_location = meta['id'] + '.' + meta['ext']
     return(save_location)
 
-def read_frames(dest_path):
-    original_images = []
-    images = []
-    for filename in sorted(dest_path.glob('*.jpg'),key=lambda p: int(p.stem)):
-        image = Image.open(filename).convert("RGB")
-        original_images.append(image)
-        images.append(preprocess(image))
-    return original_images, images
-
 def process_video_parallel(video, skip_frames, dest_path, num_processes, process_number):
     cap = cv2.VideoCapture(video)
-    chunks_per_process = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) // (num_processes * skip_frames)
-    count = skip_frames * chunks_per_process * process_number
-    print(f"worker: {process_number}, process frames {count} ~ {skip_frames * chunks_per_process * (process_number + 1)} \n total number of frames: {cap.get(cv2.CAP_PROP_FRAME_COUNT)} \n video: {video}; isOpen? : {cap.isOpened()}")
-    while count < skip_frames * chunks_per_process * (process_number + 1) :
-        if skip_frames > 1:
-            cap.set(cv2.CAP_PROP_POS_FRAMES, count)
+    frames_per_process = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) // (num_processes)
+    count = frames_per_process * process_number
+    print(f"worker: {process_number}, process frames {count} ~ {frames_per_process * (process_number + 1)} \n total number of frames: {cap.get(cv2.CAP_PROP_FRAME_COUNT)} \n video: {video}; isOpen? : {cap.isOpened()}")
+    while count < frames_per_process * (process_number + 1) :
         ret, frame = cap.read()
         if not ret:
             break
-        filename =f"{dest_path}/{count}.jpg"
-        cv2.imwrite(filename, frame)
-        print(f"saved {filename}")
-        count += skip_frames # Skip 300 frames i.e. 10 seconds for 30 fps
+        count += 1
+        if (count - frames_per_process * process_number) % skip_frames ==0:
+            filename =f"{dest_path}/{count}.jpg"
+            cv2.imwrite(filename, frame)
+            #print(f"saved {filename}")
+
     cap.release()
 
 
@@ -87,8 +78,7 @@ def vid2frames(url, sampling_interval=1, ext='mp4'):
     except:
         skip_frames = int(30 * sampling_interval)
 
-    # testing
-    skip_frames = 1
+
     print(f'video saved at: {video}, fps:{fps}, skip_frames: {skip_frames}')
     # extract video frames at given sampling interval with multiprocessing -
     print('extracting frames...')
@@ -99,7 +89,16 @@ def vid2frames(url, sampling_interval=1, ext='mp4'):
     print(f'n_workers: {n_workers}')
     with Pool(n_workers) as pool:
         pool.map(partial(process_video_parallel, video, skip_frames, dest_path, n_workers), range(n_workers))
-    return dest_path
+    # read frames
+    original_images = []
+    images = []
+    filenames = sorted(dest_path.glob('*.jpg'),key=lambda p: int(p.stem))
+    print(f"extracted {len(filenames)} frames")
+    for filename in filenames:
+        image = Image.open(filename).convert("RGB")
+        original_images.append(image)
+        images.append(preprocess(image))
+    return original_images, images
 
 
 def captioned_strip(images, caption=None, times=None, rows=1):
@@ -126,8 +125,7 @@ def captioned_strip(images, caption=None, times=None, rows=1):
     return img
 
 def run_inference(url, sampling_interval, search_query):
-    path_frames = vid2frames(url,sampling_interval)
-    original_images, images = read_frames(path_frames)
+    original_images, images = vid2frames(url,sampling_interval)
     image_input = torch.tensor(np.stack(images)).to(device)
     with torch.no_grad():
         image_features = model.encode_image(image_input)
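In short: the old sampling seeked directly to every skip_frames-th frame with cap.set(cv2.CAP_PROP_POS_FRAMES, ...), one seek per saved frame, while the new version decodes the stream sequentially and keeps one frame out of every skip_frames via a modulo test. A minimal single-worker sketch of the two strategies, with the multiprocessing bookkeeping stripped out; the function names and output directory are illustrative, not taken from app.py:

import cv2

def sample_by_seeking(video_path, skip_frames, out_dir):
    # old strategy (sketch): jump straight to each sampled frame
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for pos in range(0, total, skip_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, pos)  # one seek per sample
        ret, frame = cap.read()
        if not ret:
            break
        cv2.imwrite(f"{out_dir}/{pos}.jpg", frame)
    cap.release()

def sample_sequentially(video_path, skip_frames, out_dir):
    # new strategy (sketch): decode every frame, save every skip_frames-th one
    cap = cv2.VideoCapture(video_path)
    count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        count += 1
        if count % skip_frames == 0:
            cv2.imwrite(f"{out_dir}/{count}.jpg", frame)
    cap.release()

Sequential reading avoids the per-sample seeks, at the cost of decoding every frame in the worker's range.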
 
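And the per-worker bookkeeping in the new process_video_parallel, as a worked example; the concrete numbers (3000 total frames, 4 workers, skip_frames = 30, i.e. one frame per second of 30 fps video) are assumed for illustration only:

total_frames = 3000                                  # cap.get(cv2.CAP_PROP_FRAME_COUNT)
num_processes = 4
skip_frames = 30                                     # int(fps * sampling_interval)

frames_per_process = total_frames // num_processes   # 750
for process_number in range(num_processes):
    start = frames_per_process * process_number      # 0, 750, 1500, 2250
    end = frames_per_process * (process_number + 1)  # 750, 1500, 2250, 3000
    # each worker saves every 30th count in its range: 25 frames apiece
    kept = [c for c in range(start + 1, end + 1) if (c - start) % skip_frames == 0]
    print(f"worker {process_number}: counts {start}~{end}, saves {len(kept)} frames")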