
nanushio committed
Commit 0c18aca · 1 Parent(s): c45a2ea

- [MINOR] [SOURCE] [UPDATE] 1. update app.py

Files changed (2):
  1. app.py +3 -1
  2. cover/datasets/cover_datasets.py +32 -17
app.py CHANGED
@@ -66,8 +66,10 @@ def inference_one_video(input_video):
     """
     TESTING
     """
+    # Convert input video to tensor and adjust dimensions
+    input_video_tensor = torch.from_numpy(input_video).permute(0, 3, 1, 2)
     views, _ = spatial_temporal_view_decomposition(
-        input_video, dopt["sample_types"], temporal_samplers
+        input_video_tensor, dopt["sample_types"], temporal_samplers
     )
 
     for k, v in views.items():
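
For context, a minimal sketch of the layout conversion this hunk adds (the dummy shapes are assumptions, not from the Space): the handler receives frames as a channels-last NumPy array, and permute(0, 3, 1, 2) moves them to the (T, C, H, W) layout that spatial_temporal_view_decomposition expects.

import numpy as np
import torch

# Hypothetical input: a 16-frame clip in (T, H, W, C) order, as a NumPy array.
input_video = np.zeros((16, 360, 640, 3), dtype=np.uint8)
# The commit's conversion: channels-last -> channels-first per frame.
input_video_tensor = torch.from_numpy(input_video).permute(0, 3, 1, 2)
print(input_video_tensor.shape)  # torch.Size([16, 3, 360, 640])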
cover/datasets/cover_datasets.py CHANGED
@@ -232,34 +232,49 @@ def spatial_temporal_view_decomposition(
     video_path, sample_types, samplers, is_train=False, augment=False,
 ):
     video = {}
-    if video_path.endswith(".yuv"):
-        print("This part will be deprecated due to large memory cost.")
-        ## This is only an adaptation to LIVE-Qualcomm
-        ovideo = skvideo.io.vread(
-            video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
-        )
-        for stype in samplers:
-            frame_inds = samplers[stype](ovideo.shape[0], is_train)
-            imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
-            video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
-        del ovideo
-    else:
-        decord.bridge.set_bridge("torch")
-        vreader = VideoReader(video_path)
-        ### Avoid duplicated video decoding!!! Important!!!!
+    if torch.is_tensor(video_path):
         all_frame_inds = []
         frame_inds = {}
         for stype in samplers:
-            frame_inds[stype] = samplers[stype](len(vreader), is_train)
+            frame_inds[stype] = samplers[stype](video_path.shape[0], is_train)
             all_frame_inds.append(frame_inds[stype])
 
         ### Each frame is only decoded one time!!!
         all_frame_inds = np.concatenate(all_frame_inds, 0)
-        frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+        frame_dict = {idx: video_path[idx].permute(1, 2, 0) for idx in np.unique(all_frame_inds)}
 
         for stype in samplers:
             imgs = [frame_dict[idx] for idx in frame_inds[stype]]
             video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+    else:
+        if video_path.endswith(".yuv"):
+            print("This part will be deprecated due to large memory cost.")
+            ## This is only an adaptation to LIVE-Qualcomm
+            ovideo = skvideo.io.vread(
+                video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
+            )
+            for stype in samplers:
+                frame_inds = samplers[stype](ovideo.shape[0], is_train)
+                imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
+                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+            del ovideo
+        else:
+            decord.bridge.set_bridge("torch")
+            vreader = VideoReader(video_path)
+            ### Avoid duplicated video decoding!!! Important!!!!
+            all_frame_inds = []
+            frame_inds = {}
+            for stype in samplers:
+                frame_inds[stype] = samplers[stype](len(vreader), is_train)
+                all_frame_inds.append(frame_inds[stype])
+
+            ### Each frame is only decoded one time!!!
+            all_frame_inds = np.concatenate(all_frame_inds, 0)
+            frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+
+            for stype in samplers:
+                imgs = [frame_dict[idx] for idx in frame_inds[stype]]
+                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
 
     sampled_video = {}
     for stype, sopt in sample_types.items():
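
To illustrate the new tensor branch, here is a self-contained sketch of its indexing logic (uniform_sampler and the shapes are stand-ins, not part of the repo): a sampler maps a frame count to frame indices, each selected frame is permuted back to (H, W, C) to mirror decord's output layout, and the stacked clip ends up as (C, T, H, W).

import numpy as np
import torch

# Stand-in sampler: picks n evenly spaced frame indices.
def uniform_sampler(num_frames, is_train, n=4):
    return np.linspace(0, num_frames - 1, n).astype(int)

# A preloaded clip in (T, C, H, W), as produced by the app.py change above.
video_tensor = torch.randint(0, 256, (16, 3, 64, 64), dtype=torch.uint8)
inds = uniform_sampler(video_tensor.shape[0], is_train=False)
# Index each unique frame once and restore (H, W, C), mirroring vreader[idx].
frame_dict = {int(i): video_tensor[i].permute(1, 2, 0) for i in np.unique(inds)}
imgs = [frame_dict[int(i)] for i in inds]
clip = torch.stack(imgs, 0).permute(3, 0, 1, 2)
print(clip.shape)  # torch.Size([3, 4, 64, 64])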