import decord
decord.bridge.set_bridge('torch')  # make decord's get_batch return torch tensors
import os, io, csv, math, random
import numpy as np
from einops import rearrange
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print
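
# Dataset wrappers over a WebVid-style CSV (videoid / name / page_dir columns):
# WebVid10M samples a random strided clip per video, while ImgSeqDataset below
# always starts from the first frame and carries per-caption prompt_ids slots.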
class WebVid10M(Dataset):
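    """
    Video dataset over a WebVid-style CSV (columns: videoid, name, page_dir).

    Each item is a dict with `pixel_values` (a normalized clip of
    `sample_n_frames` frames, or a single frame when `is_image=True`) and
    `text` (the caption from the CSV's `name` column).
    """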
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
    ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = decord.VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            # Sample a clip spanning (sample_n_frames - 1) * sample_stride + 1 frames
            # (clamped to the video length) at a random offset, then pick
            # sample_n_frames evenly spaced indices within it.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            # Image mode: a single randomly chosen frame.
            batch_index = [random.randint(0, video_length - 1)]

        # With decord's torch bridge set above, get_batch already returns a
        # torch.Tensor, so no numpy round-trip is needed.
        # (f, h, w, c) -> (f, c, h, w), scaled to [0, 1].
        pixel_values = video_reader.get_batch(batch_index).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        # Retry with a fresh random index if the sampled video is missing or corrupt.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
# Variant of the dataset above that samples the first frames of each video
# instead of a random clip.
class ImgSeqDataset(Dataset):
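    """
    Same CSV/video layout as WebVid10M, but clips always start at frame 0.

    In video mode, __getitem__ returns `pixel_values` in [-1, 1] plus
    `prompt_ids`; this class only allocates None placeholders for the latter,
    so the caller is expected to fill `self.prompt_ids` (e.g. by tokenizing
    `self.prompt`) before training.
    """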
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
    ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        self.prompt = [video_dict['name'] for video_dict in self.dataset]
        # One placeholder per sample; filled by the caller's tokenizer before
        # __getitem__ is used in video mode.
        self.prompt_ids = [None] * self.length

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        # decord.VideoReader expects integer width/height, so take them from the
        # normalized (height, width) tuple.
        self.height, self.width = sample_size

        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = decord.VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            # Same clip sampling as WebVid10M, but always anchored at frame 0.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        # With decord's torch bridge set above, get_batch returns a torch.Tensor directly.
        pixel_values = video_reader.get_batch(batch_index).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        if not self.is_image:
            video_dict = self.dataset[idx]
            videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
            video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")

            # Load the video resized to (width, height) and take the first
            # sample_n_frames consecutive frames. Note the stride is fixed at 1
            # on this path; self.sample_stride is not applied.
            vr = decord.VideoReader(video_dir, width=self.width, height=self.height)
            sample_index = list(range(len(vr)))[:self.sample_n_frames]
            video = vr.get_batch(sample_index)  # torch.Tensor thanks to the bridge
            video = rearrange(video, "f h w c -> f c h w")

            example = {
                "pixel_values": (video / 127.5 - 1.0),  # scale to [-1, 1]
                "prompt_ids": self.prompt_ids[idx],
            }
            return example

        # Image mode: fall back to the retry loop used by WebVid10M.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
if __name__ == "__main__":
    from animatediff.utils.util import save_videos_grid

    dataset = WebVid10M(
        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
        sample_size=256,
        sample_stride=4, sample_n_frames=16,
        is_image=True,
    )
    # import pdb; pdb.set_trace()  # uncomment to inspect the dataset interactively

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
        # for i in range(batch["pixel_values"].shape[0]):
        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
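
    # A minimal sketch (not part of the original script) of how ImgSeqDataset's
    # placeholder prompt_ids might be filled before use; the CLIP tokenizer and
    # checkpoint name are assumptions, not something this file prescribes.
    #
    #   from transformers import CLIPTokenizer
    #   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    #   img_dataset = ImgSeqDataset(
    #       csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
    #       video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
    #       sample_size=256, sample_n_frames=16,
    #   )
    #   img_dataset.prompt_ids = tokenizer(
    #       img_dataset.prompt, max_length=tokenizer.model_max_length,
    #       padding="max_length", truncation=True, return_tensors="pt",
    #   ).input_ids
    #   example = img_dataset[0]  # {"pixel_values": (f, c, h, w) in [-1, 1], "prompt_ids": ...}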