# Source: videocrafter/lvdm/data/webvid.py
# Uploaded by RamAnanth1 ("Upload 14 files", commit b6b5d48), 7.65 kB.
import os
import random
import bisect
import pandas as pd
import omegaconf
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from decord import VideoReader, cpu
import torchvision.transforms._transforms_video as transforms_video
class WebVid(Dataset):
    """
    WebVid Dataset.

    Returns fixed-length, strided video clips together with their captions.
    Samples are described by a metadata CSV that must contain a ``videoid``
    and a ``name`` column (``name`` becomes the caption) and, unless
    ``dataname == "loradata"``, a ``page_dir`` column. Rows with any missing
    value are dropped.

    Assumes webvid data is structured as follows.
    Webvid/
        videos/
            000001_000050/      ($page_dir)
                1.mp4           (videoid.mp4)
                ...
                5000.mp4
            ...

    With ``dataname == "loradata"`` videos are read from the flat layout
    ``data_dir/<videoid>.mp4`` instead.
    """

    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=(256, 512),
                 frame_stride=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fps_schedule=None,
                 fs_probs=None,
                 bs_per_gpu=None,
                 trigger_word='',
                 dataname='',
                 ):
        """
        Args:
            meta_path: metadata CSV, read with ``pd.read_csv``.
            data_dir: root directory of the videos.
            subsample: if set, number of rows sampled from the CSV (random_state=0).
            video_length: number of frames in each returned clip.
            resolution: int or [height, width]. Frames are decoded at this size
                unless ``load_raw_resolution`` is True; the final frames are
                asserted to have this spatial size. May be None only together
                with ``load_raw_resolution=True``.
            frame_stride: int, or a list/ListConfig of candidate strides from
                which one is chosen per sample.
            spatial_transform: None, "random_crop" or "resize_center_crop".
            crop_resolution: crop size for "random_crop".
            fps_max: optional cap on the reported clip fps.
            load_raw_resolution: decode at the file's native resolution.
            fps_schedule: sorted global-step boundaries (used with ``bisect``)
                selecting which entry of a list ``frame_stride`` to use; needs
                ``len(frame_stride) == len(fps_schedule) + 1`` and ``bs_per_gpu``.
            fs_probs: sampling weights for a list ``frame_stride``.
            bs_per_gpu: per-GPU batch size, converts the sample counter to a step.
            trigger_word: string appended to every caption.
            dataname: "loradata" selects the flat ``data_dir/<videoid>.mp4`` layout.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        # Always store a fresh list so instances never alias the default or the
        # caller's object.
        if isinstance(resolution, int):
            self.resolution = [resolution, resolution]
        elif resolution is None:
            self.resolution = None
        else:
            self.resolution = list(resolution)
        self.frame_stride = frame_stride
        self.fps_max = fps_max
        self.load_raw_resolution = load_raw_resolution
        self.fs_probs = fs_probs
        self.trigger_word = trigger_word
        self.dataname = dataname
        self._load_metadata()

        if spatial_transform is not None:
            if spatial_transform == "random_crop":
                self.spatial_transform = transforms_video.RandomCropVideo(crop_resolution)
            elif spatial_transform == "resize_center_crop":
                assert self.resolution[0] == self.resolution[1]
                self.spatial_transform = transforms.Compose([
                    transforms.Resize(resolution),
                    transforms_video.CenterCropVideo(resolution),
                ])
            else:
                raise NotImplementedError
        else:
            self.spatial_transform = None

        self.fps_schedule = fps_schedule
        self.bs_per_gpu = bs_per_gpu
        if self.fps_schedule is not None:
            assert self.bs_per_gpu is not None
            # Number of samples served so far; converted to a global step via
            # bs_per_gpu. NOTE(review): with DataLoader workers > 0 each worker
            # holds its own copy of this counter - confirm the schedule allows it.
            self.counter = 0
            self.stage_idx = 0

    def _load_metadata(self):
        """Read the metadata CSV, optionally subsample it, expose ``name`` as
        ``caption`` and drop rows containing missing values."""
        metadata = pd.read_csv(self.meta_path)
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)
        metadata['caption'] = metadata['name']
        del metadata['name']
        self.metadata = metadata
        self.metadata.dropna(inplace=True)

    def _get_video_path(self, sample):
        """Return ``(full_path, relative_path)`` of the video for a metadata row."""
        if self.dataname == "loradata":
            rel_video_fp = str(sample['videoid']) + '.mp4'
            full_video_fp = os.path.join(self.data_dir, rel_video_fp)
        else:
            rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
            full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp, rel_video_fp

    def get_fs_based_on_schedule(self, frame_strides, schedule):
        """Pick the stride of the stage that the current global step falls in.

        The stage index is ``bisect.bisect(schedule, global_step)``; a message
        is printed whenever the stage changes.
        """
        assert len(frame_strides) == len(schedule) + 1  # nstage = len(fps_schedule) + 1
        global_step = self.counter // self.bs_per_gpu  # TODO: support resume.
        stage_idx = bisect.bisect(schedule, global_step)
        frame_stride = frame_strides[stage_idx]
        # log stage change
        if stage_idx != self.stage_idx:
            print(f'fps stage: {stage_idx} start ... new frame stride = {frame_stride}')
            self.stage_idx = stage_idx
        return frame_stride

    def get_fs_based_on_probs(self, frame_strides, probs):
        """Draw one stride from ``frame_strides`` weighted by ``probs``."""
        assert len(frame_strides) == len(probs)
        return random.choices(frame_strides, weights=probs)[0]

    def get_fs_randomly(self, frame_strides):
        """Draw one stride uniformly from ``frame_strides``."""
        return random.choice(frame_strides)

    def _choose_frame_stride(self):
        """Return the requested stride for the next sample.

        A list of strides is resolved by schedule, then by probabilities, then
        uniformly; a scalar stride is returned as is.
        """
        if isinstance(self.frame_stride, list) or isinstance(self.frame_stride, omegaconf.listconfig.ListConfig):
            if self.fps_schedule is not None:
                return self.get_fs_based_on_schedule(self.frame_stride, self.fps_schedule)
            if self.fs_probs is not None:
                return self.get_fs_based_on_probs(self.frame_stride, self.fs_probs)
            return self.get_fs_randomly(self.frame_stride)
        return self.frame_stride

    def _clip_frame_indices(self, num_frames, frame_stride):
        """Choose ``video_length`` strided frame indices from a ``num_frames`` video.

        If the requested stride does not yield enough frames, the largest
        stride that still fits (``num_frames // video_length``) is used. The
        clip start is chosen uniformly at random.

        Returns:
            (frame_indices, stride_actually_used)
        """
        all_frames = list(range(0, num_frames, frame_stride))
        if len(all_frames) < self.video_length:  # recal a max fs
            frame_stride = num_frames // self.video_length
            assert frame_stride != 0
            all_frames = list(range(0, num_frames, frame_stride))
        # select a random clip
        rand_idx = random.randint(0, len(all_frames) - self.video_length)
        return all_frames[rand_idx:rand_idx + self.video_length], frame_stride

    def __getitem__(self, index):
        """Return a dict with 'video', 'caption', 'path', 'fps' and 'frame_stride'.

        Unreadable or too-short videos are skipped by moving on to the next row
        (wrapping around), so this keeps retrying until some video loads.
        """
        requested_stride = self._choose_frame_stride()
        assert isinstance(requested_stride, int), type(requested_stride)

        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path, _ = self._get_video_path(sample)
            caption = sample['caption'] + self.trigger_word

            # make reader
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
                if len(video_reader) < self.video_length:
                    print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})")
                    index += 1
                    continue
            except Exception:  # not bare: keep Ctrl-C / SystemExit working in this retry loop
                index += 1
                print(f"Load video failed! path = {video_path}")
                continue

            # The stride is computed per video so a fallback for one short
            # video never carries over to the next retry.
            frame_indices, frame_stride = self._clip_frame_indices(len(video_reader), requested_stride)
            try:
                frames = video_reader.get_batch(frame_indices)
                break
            except Exception:
                print(f"Get frames failed! path = {video_path}")
                index += 1
                continue

        assert frames.shape[0] == self.video_length, f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
        if self.resolution is not None:
            assert frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1], f'frames={frames.shape}, self.resolution={self.resolution}'
        # Map pixel values 0..255 to [-1, 1].
        frames = (frames / 255 - 0.5) * 2

        fps_ori = video_reader.get_avg_fps()
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
        if self.fps_schedule is not None:
            self.counter += 1
        return data

    def __len__(self):
        return len(self.metadata)