COVER / cover /datasets /cover_datasets.py
nanushio
- [MINOR] [SOURCE] [UPDATE] 1. update app.py
0c18aca
raw
history blame
16.6 kB
import copy
import glob
import os
import os.path as osp
import random
from functools import lru_cache
import cv2
import decord
import numpy as np
import skvideo.io
import torch
import torchvision
from decord import VideoReader, cpu, gpu
from tqdm import tqdm
# Seed Python's global ``random`` RNG at import time. Helpers below draw from
# it (random.random, random.randrange, random.sample). NOTE: numpy's and
# torch's RNGs (np.random.randint / torch.randint below) are NOT seeded here.
random.seed(42)
# Ask decord to hand frames back as torch tensors; decoded frames are later
# combined with torch.stack.
decord.bridge.set_bridge("torch")
def get_spatial_fragments(
    video,
    fragments_h=7,
    fragments_w=7,
    fsize_h=32,
    fsize_w=32,
    aligned=32,
    nfrags=1,
    random=False,
    random_upsample=False,
    fallback_type="upsample",
    upsample=-1,
    **kwargs,
):
    """Sample a grid of spatial fragments from a video and splice them together.

    The frame is split into a ``fragments_h x fragments_w`` grid. From each
    cell a ``fsize_h x fsize_w`` patch is cropped at a random offset inside the
    cell; the offset is redrawn for every run of ``aligned`` frames. The patches
    are tiled into an output of spatial size
    ``(fragments_h * fsize_h, fragments_w * fsize_w)``.

    Args:
        video: Tensor laid out as ``[C, T, H, W]``. A single-frame input
            (``T == 1``) is treated as an image and ``aligned`` becomes 1.
        fragments_h, fragments_w: Number of grid cells along height / width.
        fsize_h, fsize_w: Height / width of each cropped patch.
        aligned: Number of consecutive frames sharing one crop offset; ``T``
            must be divisible by it.
        nfrags: Unused; kept for interface compatibility.
        random: Deprecated. If true, offsets are drawn over the whole frame
            instead of inside each grid cell.
        random_upsample: If true, additionally upscale by a random factor in
            ``[1, 1.5)`` before sampling.
        fallback_type: With ``"upsample"``, frames smaller than the output size
            are bilinearly upscaled first.
        upsample: If positive, first resize so the short edge equals this value
            (aspect ratio kept).
        **kwargs: Ignored extra options.

    Returns:
        Tensor of the default float dtype, shaped
        ``[C, T, fragments_h * fsize_h, fragments_w * fsize_w]``, on the same
        device as ``video``.

    Raises:
        AssertionError: If ``T`` is not a multiple of ``aligned``.
    """
    # The ``random`` parameter shadows the stdlib module of the same name, so
    # use a private alias to draw the random-upsample factor.
    import random as _random_module

    if upsample > 0:
        old_h, old_w = video.shape[-2], video.shape[-1]
        if old_h >= old_w:
            w = upsample
            h = int(upsample * old_h / old_w)
        else:
            h = upsample
            w = int(upsample * old_w / old_h)
        video = get_resized_video(video, h, w)

    size_h = fragments_h * fsize_h
    size_w = fragments_w * fsize_w
    ## video: [C,T,H,W]
    ## situation for images
    if video.shape[1] == 1:
        aligned = 1

    dur_t, res_h, res_w = video.shape[-3:]
    ratio = min(res_h / size_h, res_w / size_w)
    # Keep a handle on the input so rescaled frames can be cast back to its
    # dtype; it must exist for both resize branches below.
    ovideo = video
    if fallback_type == "upsample" and ratio < 1:
        video = torch.nn.functional.interpolate(
            video / 255.0, scale_factor=1 / ratio, mode="bilinear"
        )
        video = (video * 255.0).type_as(ovideo)

    if random_upsample:
        randratio = _random_module.random() * 0.5 + 1
        video = torch.nn.functional.interpolate(
            video / 255.0, scale_factor=randratio, mode="bilinear"
        )
        video = (video * 255.0).type_as(ovideo)

    # The grid must be laid out over the frames actually being sampled, so
    # re-read the spatial size after any resizing above.
    res_h, res_w = video.shape[-2:]

    assert dur_t % aligned == 0, "Please provide match vclip and align index"
    size = size_h, size_w
    num_t = dur_t // aligned

    ## make sure that sampling will not run out of the picture
    hgrids = torch.LongTensor(
        [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)]
    )
    wgrids = torch.LongTensor(
        [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)]
    )
    hlength, wlength = res_h // fragments_h, res_w // fragments_w
    offset_shape = (len(hgrids), len(wgrids), num_t)

    if random:
        print("This part is deprecated. Please remind that.")
        h_range, w_range = res_h - fsize_h, res_w - fsize_w
    else:
        h_range, w_range = hlength - fsize_h, wlength - fsize_w

    # Random offsets only when there is slack to move the patch; otherwise 0.
    if h_range > 0:
        rnd_h = torch.randint(h_range, offset_shape)
    else:
        rnd_h = torch.zeros(offset_shape).int()
    if w_range > 0:
        rnd_w = torch.randint(w_range, offset_shape)
    else:
        rnd_w = torch.zeros(offset_shape).int()

    target_video = torch.zeros(video.shape[:-2] + size, device=video.device)
    for i, hs in enumerate(hgrids):
        for j, ws in enumerate(wgrids):
            for t in range(num_t):
                t_s, t_e = t * aligned, (t + 1) * aligned
                h_s, h_e = i * fsize_h, (i + 1) * fsize_h
                w_s, w_e = j * fsize_w, (j + 1) * fsize_w
                if random:
                    h_so, w_so = rnd_h[i][j][t], rnd_w[i][j][t]
                else:
                    h_so, w_so = hs + rnd_h[i][j][t], ws + rnd_w[i][j][t]
                target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
                    :, t_s:t_e, h_so : h_so + fsize_h, w_so : w_so + fsize_w
                ]
    return target_video
@lru_cache
def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False):
    """Return a cached torchvision transform that resizes ``[T, C, H, W]`` frames.

    Args:
        size_h, size_w: Target output size. When ``target_ratio > 1``,
            ``size_h`` is replaced by ``int(target_ratio * size_w)``. When
            ``target_ratio < 1``, ``size_w`` is replaced by
            ``int(size_h / target_ratio)``.
        target_ratio: Desired output height / width ratio.
        random_crop: If true, return a ``RandomResizedCrop`` to
            ``(size_h, size_w)`` with scale range ``(0.40, 1.0)``; in that case
            ``target_ratio`` is ignored.

    Returns:
        A ``torchvision.transforms`` callable. Results are memoised by
        ``lru_cache``, so identical arguments reuse the same transform object.
    """
    if random_crop:
        return torchvision.transforms.RandomResizedCrop(
            (size_h, size_w), scale=(0.40, 1.0)
        )
    # No strict-inequality assertion here: for nearly-square inputs
    # (e.g. ratio 1.001 with a 224 side) ``int`` truncation legitimately
    # yields equal sides, which is still a valid resize target.
    if target_ratio > 1:
        size_h = int(target_ratio * size_w)
    elif target_ratio < 1:
        size_w = int(size_h / target_ratio)
    return torchvision.transforms.Resize((size_h, size_w))
def get_resized_video(
    video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs,
):
    """Resize every frame of a ``[C, T, H, W]`` video.

    Frames are moved to ``[T, C, H, W]`` for the torchvision transform and
    moved back afterwards.

    Args:
        video: Tensor laid out as ``[C, T, H, W]``.
        size_h, size_w: Output size. With ``arp``, one side is enlarged so the
            input's aspect ratio is kept.
        random_crop: Use a random resized crop instead of a plain resize.
        arp: Keep the input height / width ratio.
        **kwargs: Ignored extra options.

    Returns:
        The resized video, laid out as ``[C, T, H', W']``.
    """
    frames_first = video.permute(1, 0, 2, 3)
    if arp:
        aspect = frames_first.shape[-2] / frames_first.shape[-1]
    else:
        aspect = 1
    transform = get_resize_function(size_h, size_w, aspect, random_crop)
    resized = transform(frames_first)
    return resized.permute(1, 0, 2, 3)
def get_arp_resized_video(
    video, short_edge=224, train=False, **kwargs,
):
    """Rescale a video so its short edge becomes ``short_edge``, keeping aspect ratio.

    During training the frames are first randomly cropped to a square whose
    side is the original short edge.

    Args:
        video: Tensor whose last two dims are height and width (e.g.
            ``[C, T, H, W]``).
        short_edge: Target length of the short edge after rescaling.
        train: If true, random-crop to a square before rescaling.
        **kwargs: Ignored extra options.

    Returns:
        Bilinearly rescaled tensor with the same dtype as ``video``.
    """
    if train:  ## if during training, will random crop into square and then resize
        res_h, res_w = video.shape[-2:]
        ori_short_edge = min(res_h, res_w)
        # ``randrange(n + 1)``: every valid offset, including the last, can be drawn.
        if res_h > ori_short_edge:
            rnd_h = random.randrange(res_h - ori_short_edge + 1)
            video = video[..., rnd_h : rnd_h + ori_short_edge, :]
        elif res_w > ori_short_edge:
            rnd_w = random.randrange(res_w - ori_short_edge + 1)
            video = video[..., :, rnd_w : rnd_w + ori_short_edge]
    ori_short_edge = min(video.shape[-2:])
    scale_factor = short_edge / ori_short_edge
    ovideo = video
    video = torch.nn.functional.interpolate(
        video / 255.0, scale_factor=scale_factor, mode="bilinear"
    )
    video = (video * 255.0).type_as(ovideo)
    return video
def get_arp_fragment_video(
    video, short_fragments=7, fsize=32, train=False, **kwargs,
):
    """Fragment-sample a video with a grid that follows its aspect ratio.

    The short side gets ``short_fragments`` grid cells and the long side
    proportionally more (truncated to int). Each cell yields a square
    ``fsize x fsize`` patch, and sampling is delegated to
    ``get_spatial_fragments``.

    Args:
        video: Tensor whose last two dims are height and width (e.g.
            ``[C, T, H, W]``).
        short_fragments: Number of grid cells along the short side.
        fsize: Side length of each square patch.
        train: If true, first random-crop to a square whose side is the
            original short edge, which makes the grid square as well.
        **kwargs: Forwarded to ``get_spatial_fragments``. The fragment count
            and fragment size keys are overwritten here.

    Returns:
        The result of ``get_spatial_fragments`` for the computed grid.
    """
    if train:  ## if during training, will random crop into square and then get fragments
        res_h, res_w = video.shape[-2:]
        ori_short_edge = min(res_h, res_w)
        # ``randrange(n + 1)``: every valid offset, including the last, can be drawn.
        if res_h > ori_short_edge:
            rnd_h = random.randrange(res_h - ori_short_edge + 1)
            video = video[..., rnd_h : rnd_h + ori_short_edge, :]
        elif res_w > ori_short_edge:
            rnd_w = random.randrange(res_w - ori_short_edge + 1)
            video = video[..., :, rnd_w : rnd_w + ori_short_edge]
    kwargs["fsize_h"], kwargs["fsize_w"] = fsize, fsize
    res_h, res_w = video.shape[-2:]
    if res_h > res_w:
        kwargs["fragments_w"] = short_fragments
        kwargs["fragments_h"] = int(short_fragments * res_h / res_w)
    else:
        kwargs["fragments_h"] = short_fragments
        kwargs["fragments_w"] = int(short_fragments * res_w / res_h)
    return get_spatial_fragments(video, **kwargs)
def get_cropped_video(
    video, size_h=224, size_w=224, **kwargs,
):
    """Crop a ``size_h x size_w`` window out of a video.

    This is fragment sampling with a 1x1 grid whose single patch has the full
    target size. All other options are forwarded to
    ``get_spatial_fragments``, and the grid/patch-size keys are overwritten.
    """
    kwargs.update(
        fragments_h=1,
        fragments_w=1,
        fsize_h=size_h,
        fsize_w=size_w,
    )
    return get_spatial_fragments(video, **kwargs)
def get_single_view(
    video, sample_type="aesthetic", **kwargs,
):
    """Produce one view of ``video`` according to the prefix of ``sample_type``.

    - ``aesthetic*`` / ``semantic*``: whole-frame resize (``get_resized_video``).
    - ``technical*``: spatial fragments (``get_spatial_fragments``).
    - Anything else (including ``"original"``): ``video`` is returned unchanged.

    ``kwargs`` are forwarded to the selected helper.
    """
    if sample_type.startswith(("aesthetic", "semantic")):
        return get_resized_video(video, **kwargs)
    if sample_type.startswith("technical"):
        return get_spatial_fragments(video, **kwargs)
    return video
def _decode_unique_frames(read_frame, num_frames, samplers, is_train):
    """Sample frame indices for every sampler and decode each distinct frame once.

    Args:
        read_frame: Callable mapping a frame index to a channels-last
            ``[H, W, C]`` frame tensor.
        num_frames: Total number of frames available in the source.
        samplers: Mapping of view name -> callable invoked as
            ``sampler(num_frames, is_train)`` that returns an index array.
        is_train: Forwarded to every sampler.

    Returns:
        ``(video, frame_inds)``. ``video[name]`` is the ``[C, T, H, W]`` stack
        of the frames chosen for that view, and ``frame_inds[name]`` is the
        index array its sampler returned.
    """
    frame_inds = {
        stype: sampler(num_frames, is_train) for stype, sampler in samplers.items()
    }
    if not frame_inds:
        return {}, frame_inds
    ### Each frame is only decoded one time!!!
    unique_inds = np.unique(np.concatenate(list(frame_inds.values()), 0))
    frame_dict = {idx: read_frame(idx) for idx in unique_inds}
    video = {
        stype: torch.stack([frame_dict[idx] for idx in inds], 0).permute(3, 0, 1, 2)
        for stype, inds in frame_inds.items()
    }
    return video, frame_inds


def spatial_temporal_view_decomposition(
    video_path, sample_types, samplers, is_train=False, augment=False,
):
    """Decode the frames every view needs and build each sampled view.

    Args:
        video_path: One of:
            - an already-decoded tensor indexed by frame along dim 0, where
              each frame goes through ``permute(1, 2, 0)``, i.e. is presumably
              ``[C, H, W]`` (TODO confirm against callers);
            - a ``.yuv`` path, read with scikit-video as 1080x1920 ``yuvj420p``;
            - any other path, read with decord.
        sample_types: Mapping of view name -> keyword options forwarded to
            ``get_single_view``. Every key must also be in ``samplers``.
        samplers: Mapping of view name -> temporal sampler.
        is_train: Forwarded to the samplers.
        augment: Unused; kept for interface compatibility.

    Returns:
        ``(sampled_video, frame_inds)``. ``sampled_video`` maps each name in
        ``sample_types`` to its processed view. ``frame_inds`` is always a
        dict mapping each sampler name to the frame indices it chose.
    """
    if torch.is_tensor(video_path):
        video, frame_inds = _decode_unique_frames(
            lambda idx: video_path[idx].permute(1, 2, 0),
            video_path.shape[0],
            samplers,
            is_train,
        )
    elif video_path.endswith(".yuv"):
        print("This part will be deprecated due to large memory cost.")
        ## This is only an adaptation to LIVE-Qualcomm
        ovideo = skvideo.io.vread(
            video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
        )
        video, frame_inds = _decode_unique_frames(
            lambda idx: torch.from_numpy(ovideo[idx]),
            ovideo.shape[0],
            samplers,
            is_train,
        )
        # torch.stack above copied the frames; drop the full decoded clip.
        del ovideo
    else:
        decord.bridge.set_bridge("torch")
        vreader = VideoReader(video_path)
        ### Avoid duplicated video decoding!!! Important!!!!
        video, frame_inds = _decode_unique_frames(
            lambda idx: vreader[idx], len(vreader), samplers, is_train
        )

    sampled_video = {
        stype: get_single_view(video[stype], stype, **sopt)
        for stype, sopt in sample_types.items()
    }
    return sampled_video, frame_inds
import random
import numpy as np
class UnifiedFrameSampler:
    """Temporal sampler producing evenly spread runs of frame indices.

    The frame range is cut into ``fragments_t`` equal segments. Each segment
    contributes ``fsize_t`` indices spaced ``frame_interval`` apart. The run
    starts at a random offset when the segment is longer than the run,
    otherwise at the segment start. ``int(fragments_t * drop_rate)`` segments
    are discarded at random. This is repeated ``num_clips`` times and the
    results are concatenated.
    """

    def __init__(
        self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0,
    ):
        self.fsize_t = fsize_t
        self.fragments_t = fragments_t
        self.size_t = fsize_t * fragments_t
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.drop_rate = drop_rate

    def get_frame_indices(self, num_frames, train=False):
        """Return one clip of indices for a source with ``num_frames`` frames.

        ``train`` is accepted but unused. Indices may exceed
        ``num_frames - 1`` for short sources; ``__call__`` wraps them.
        """
        segment = num_frames // self.fragments_t
        span = self.fsize_t * self.frame_interval
        starts = np.array(
            [segment * k for k in range(self.fragments_t)], dtype=np.int32
        )

        # Random start inside each segment only when there is slack for it.
        if segment <= span:
            offsets = np.zeros(self.fragments_t, dtype=np.int32)
        else:
            offsets = np.random.randint(0, segment - span, size=self.fragments_t)

        steps = np.arange(self.fsize_t) * self.frame_interval
        segment_runs = steps[None, :] + offsets[:, None] + starts[:, None]

        dropped = random.sample(
            list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate)
        )
        kept = [run for k, run in enumerate(segment_runs) if k not in dropped]
        return np.concatenate(kept)

    def __call__(self, total_frames, train=False, start_index=0):
        """Sample ``num_clips`` clips and return int32 indices in ``[0, total_frames)``."""
        clips = [self.get_frame_indices(total_frames) for _ in range(self.num_clips)]
        shifted = np.concatenate(clips) + start_index
        return np.mod(shifted, total_frames).astype(np.int32)
class ViewDecompositionDataset(torch.utils.data.Dataset):
    """Dataset that decodes videos and decomposes them into sampled views.

    ``opt`` keys read here:
        - required: ``anno_file``, ``data_prefix``, ``sample_types``, ``phase``;
        - optional: ``weight``, ``fully_supervised``, ``data_backend``,
          ``augment``, ``random_crop``.

    Every entry of ``sample_types`` needs ``clip_len``, ``num_clips`` and
    ``frame_interval``, and may set ``t_frag``.
    """

    def __init__(self, opt):
        ## opt is a dictionary that includes options for video sampling
        super().__init__()
        self.weight = opt.get("weight", 0.5)
        self.fully_supervised = opt.get("fully_supervised", False)
        print("Fully supervised:", self.fully_supervised)
        self.video_infos = []
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.sample_types = opt["sample_types"]
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)
        if self.data_backend == "petrel":
            from petrel_client import client

            self.client = client.Client(enable_mc=True)

        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)
        # Channel-wise statistics used in __getitem__ to normalise the
        # technical/aesthetic views (mean/std) and the semantic view.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.mean_semantic = torch.FloatTensor([122.77, 116.75, 104.09])
        self.std_semantic = torch.FloatTensor([68.50, 66.63, 70.32])
        self.samplers = {}
        for stype, sopt in opt["sample_types"].items():
            if "t_frag" not in sopt:
                # resized temporal sampling for TQE in COVER
                self.samplers[stype] = UnifiedFrameSampler(
                    sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"]
                )
            else:
                # temporal sampling for AQE in COVER
                self.samplers[stype] = UnifiedFrameSampler(
                    sopt["clip_len"] // sopt["t_frag"],
                    sopt["t_frag"],
                    sopt["frame_interval"],
                    sopt["num_clips"],
                )
            print(
                stype + " branch sampled frames:",
                self.samplers[stype](240, self.phase == "train"),
            )

        if isinstance(self.ann_file, list):
            self.video_infos = self.ann_file
        else:
            self.video_infos = self._load_video_infos()

    def _load_video_infos(self):
        """Parse the annotation file, falling back to an unlabeled directory scan.

        The fallback is used when the file is missing/unopenable (OSError),
        its path is not a valid path (TypeError), or a line is malformed
        (ValueError). Parsing goes into a fresh list, so a failure part-way
        through never leaves half-parsed entries mixed with scanned ones.
        """
        try:
            return self._parse_annotation_file()
        except (OSError, TypeError, ValueError):
            #### No Label Testing
            return self._scan_unlabeled_videos()

    def _parse_annotation_file(self):
        """Read ``filename,a,t,label`` lines from ``self.ann_file``; blank lines are skipped."""
        infos = []
        with open(self.ann_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                filename, a, t, label = line.split(",")
                if self.fully_supervised:
                    label = float(a), float(t), float(label)
                else:
                    label = float(label)
                filename = osp.join(self.data_prefix, filename)
                infos.append(dict(filename=filename, label=label))
        return infos

    def _scan_unlabeled_videos(self):
        """Collect every ``.mp4`` under ``self.data_prefix`` (sorted) with label -1."""
        video_filenames = []
        for root, _dirs, files in os.walk(self.data_prefix, topdown=True):
            video_filenames += [
                os.path.join(root, file) for file in files if file.endswith(".mp4")
            ]
        print(len(video_filenames))
        return [dict(filename=filename, label=-1) for filename in sorted(video_filenames)]

    def __getitem__(self, index):
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]

        try:
            ## Read Original Frames
            ## Process Frames
            data, frame_inds = spatial_temporal_view_decomposition(
                filename,
                self.sample_types,
                self.samplers,
                self.phase == "train",
                self.augment and (self.phase == "train"),
            )
            for k, v in data.items():
                # Views are [C, T, H, W]; move C last to broadcast the stats.
                if k in ("technical", "aesthetic"):
                    data[k] = ((v.permute(1, 2, 3, 0) - self.mean) / self.std).permute(
                        3, 0, 1, 2
                    )
                elif k == "semantic":
                    data[k] = (
                        (v.permute(1, 2, 3, 0) - self.mean_semantic) / self.std_semantic
                    ).permute(3, 0, 1, 2)

            data["num_clips"] = {
                stype: sopt["num_clips"] for stype, sopt in self.sample_types.items()
            }
            data["frame_inds"] = frame_inds
            data["gt_label"] = label
            data["name"] = filename
        except Exception as err:
            # Best effort: an unreadable video yields a placeholder sample so
            # the caller can skip it. Report the cause instead of hiding it;
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print(f"Failed to load {filename}: {err!r}")
            return {"name": filename}
        return data

    def __len__(self):
        return len(self.video_infos)