nanushio committed
Commit · 0c18aca
Parent(s): c45a2ea

- [MINOR] [SOURCE] [UPDATE] 1. update app.py

Files changed:
- app.py +3 -1
- cover/datasets/cover_datasets.py +32 -17
app.py
CHANGED
@@ -66,8 +66,10 @@ def inference_one_video(input_video):
     """
     TESTING
     """
+    # Convert input video to tensor and adjust dimensions
+    input_video_tensor = torch.from_numpy(input_video).permute(0, 3, 1, 2)
     views, _ = spatial_temporal_view_decomposition(
-
+        input_video_tensor, dopt["sample_types"], temporal_samplers
     )
 
     for k, v in views.items():
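For context, a minimal sketch of the conversion this change introduces, assuming input_video arrives as a (T, H, W, C) uint8 NumPy array (the layout Gradio-style video inputs typically produce); the data here is synthetic:

import numpy as np
import torch

# Synthetic stand-in for the decoded input video: 8 RGB frames of 64x64.
input_video = np.random.randint(0, 256, size=(8, 64, 64, 3), dtype=np.uint8)

# The conversion added above: (T, H, W, C) -> (T, C, H, W), i.e. the per-frame
# CHW layout that the updated spatial_temporal_view_decomposition expects
# when it is handed a tensor instead of a file path.
input_video_tensor = torch.from_numpy(input_video).permute(0, 3, 1, 2)
assert input_video_tensor.shape == (8, 3, 64, 64)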
cover/datasets/cover_datasets.py
CHANGED
@@ -232,34 +232,49 @@ def spatial_temporal_view_decomposition(
     video_path, sample_types, samplers, is_train=False, augment=False,
 ):
     video = {}
-    if video_path.endswith(".yuv"):
-        print("This part will be deprecated due to large memory cost.")
-        ## This is only an adaptation to LIVE-Qualcomm
-        ovideo = skvideo.io.vread(
-            video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
-        )
-        for stype in samplers:
-            frame_inds = samplers[stype](ovideo.shape[0], is_train)
-            imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
-            video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
-        del ovideo
-    else:
-        decord.bridge.set_bridge("torch")
-        vreader = VideoReader(video_path)
-        ### Avoid duplicated video decoding!!! Important!!!!
+    if torch.is_tensor(video_path):
         all_frame_inds = []
         frame_inds = {}
         for stype in samplers:
-            frame_inds[stype] = samplers[stype](len(vreader), is_train)
+            frame_inds[stype] = samplers[stype](video_path.shape[0], is_train)
             all_frame_inds.append(frame_inds[stype])
 
         ### Each frame is only decoded one time!!!
         all_frame_inds = np.concatenate(all_frame_inds, 0)
-        frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+        frame_dict = {idx: video_path[idx].permute(1, 2, 0) for idx in np.unique(all_frame_inds)}
 
         for stype in samplers:
             imgs = [frame_dict[idx] for idx in frame_inds[stype]]
             video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+    else:
+        if video_path.endswith(".yuv"):
+            print("This part will be deprecated due to large memory cost.")
+            ## This is only an adaptation to LIVE-Qualcomm
+            ovideo = skvideo.io.vread(
+                video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
+            )
+            for stype in samplers:
+                frame_inds = samplers[stype](ovideo.shape[0], is_train)
+                imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
+                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+            del ovideo
+        else:
+            decord.bridge.set_bridge("torch")
+            vreader = VideoReader(video_path)
+            ### Avoid duplicated video decoding!!! Important!!!!
+            all_frame_inds = []
+            frame_inds = {}
+            for stype in samplers:
+                frame_inds[stype] = samplers[stype](len(vreader), is_train)
+                all_frame_inds.append(frame_inds[stype])
+
+            ### Each frame is only decoded one time!!!
+            all_frame_inds = np.concatenate(all_frame_inds, 0)
+            frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+
+            for stype in samplers:
+                imgs = [frame_dict[idx] for idx in frame_inds[stype]]
+                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
 
     sampled_video = {}
     for stype, sopt in sample_types.items():
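To illustrate the tensor branch added above, here is a self-contained sketch of the "index each frame only once" pattern it reuses; decompose_views and the two toy samplers are hypothetical stand-ins, not the repository's actual sampler API:

import numpy as np
import torch

def decompose_views(frames, samplers, is_train=False):
    # frames: (T, C, H, W) tensor; samplers: name -> callable(num_frames, is_train) -> index array.
    video = {}
    all_frame_inds, frame_inds = [], {}
    for stype in samplers:
        frame_inds[stype] = samplers[stype](frames.shape[0], is_train)
        all_frame_inds.append(frame_inds[stype])

    # Gather each required frame at most once, even when samplers overlap.
    all_frame_inds = np.concatenate(all_frame_inds, 0)
    frame_dict = {idx: frames[idx].permute(1, 2, 0) for idx in np.unique(all_frame_inds)}

    for stype in samplers:
        imgs = [frame_dict[idx] for idx in frame_inds[stype]]
        video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)  # (C, T, H, W)
    return video

# Usage with two overlapping toy samplers.
frames = torch.rand(16, 3, 32, 32)
samplers = {
    "technical": lambda n, train: np.arange(0, n, 2),
    "aesthetic": lambda n, train: np.arange(0, n, 4),
}
views = decompose_views(frames, samplers)
for k, v in views.items():
    print(k, tuple(v.shape))  # technical (3, 8, 32, 32), aesthetic (3, 4, 32, 32)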