import json
import os
import random
import gc

import numpy as np
import torch
import torchvision
from einops import rearrange
from os.path import join as opj
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from PIL import Image

from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing


def random_video_noise(t, c, h, w):
    """Return a (t, c, h, w) uint8 tensor of uniform noise, useful for smoke tests."""
    vid = torch.rand(t, c, h, w) * 255.0
    vid = vid.to(torch.uint8)
    return vid


class T2V_dataset(Dataset):
    def __init__(self, args, transform, temporal_sample, tokenizer):
        self.image_data = args.image_data
        self.video_data = args.video_data
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.model_max_length = args.model_max_length
        self.v_decoder = DecordInit()
        self.vid_cap_list = self.get_vid_cap_list()
        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        # A separate image list is only needed when joint image-video training
        # does not reuse frames drawn from the sampled video clips.
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def __len__(self):
        return len(self.vid_cap_list)

    def __getitem__(self, idx):
        try:
            video_data = self.get_video(idx)
            image_data = {}
            if self.use_image_num != 0 and self.use_img_from_vid:
                image_data = self.get_image_from_video(video_data)
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                image_data = self.get_image(idx)
            else:
                # Pure-video training (use_image_num == 0) is not supported here.
                raise NotImplementedError
            gc.collect()
            return dict(video_data=video_data, image_data=image_data)
        except Exception as e:
            print(f'Error with {e}, {self.vid_cap_list[idx]}')
            # Corrupted pre-resized copies are deleted so they can be regenerated.
            if os.path.exists(self.vid_cap_list[idx]['path']) and '_resize_1080p' in self.vid_cap_list[idx]['path']:
                os.remove(self.vid_cap_list[idx]['path'])
                print('remove:', self.vid_cap_list[idx]['path'])
            # Fall back to a random sample so one bad file does not stall training.
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def get_video(self, idx):
        video_path = self.vid_cap_list[idx]['path']
        frame_idx = self.vid_cap_list[idx]['frame_idx']
        video = self.decord_read(video_path, frame_idx)
        video = self.transform(video)  # T C H W -> T C H W
        video = video.transpose(0, 1)  # T C H W -> C T H W
        text = self.vid_cap_list[idx]['cap']
        text = text_preprocessing(text)
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids = text_tokens_and_mask['input_ids']
        cond_mask = text_tokens_and_mask['attention_mask']
        return dict(video=video, input_ids=input_ids, cond_mask=cond_mask)
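    # The two helpers below build the image branch of a joint image-video batch.
    # get_image_from_video re-samples use_image_num frames evenly from the clip
    # just decoded by get_video; get_image loads them from a separate image list.
    # Both return num_img tensors of shape (C, 1, H, W) together with tokenized
    # captions of shape (use_image_num, model_max_length).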
    def get_image_from_video(self, video_data):
        assert self.num_frames >= self.use_image_num
        select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
        image = [video_data['video'][:, i:i + 1] for i in select_image_idx]  # num_img [c, 1, h, w]
        input_ids = video_data['input_ids'].repeat(self.use_image_num, 1)  # self.use_image_num, l
        cond_mask = video_data['cond_mask'].repeat(self.use_image_num, 1)  # self.use_image_num, l
        return dict(image=image, input_ids=input_ids, cond_mask=cond_mask)

    def get_image(self, idx):
        idx = idx % len(self.img_cap_list)  # wrap around so the index stays in range
        image_data = self.img_cap_list[idx]  # [{'path': path, 'cap': cap}, ...]
        image = [Image.open(i['path']).convert('RGB') for i in image_data]  # num_img [h, w, c]
        image = [torch.from_numpy(np.array(i)) for i in image]  # num_img [h, w, c]
        image = [rearrange(i, 'h w c -> c h w').unsqueeze(0) for i in image]  # num_img [1 c h w]
        image = [self.transform(i) for i in image]  # num_img [1 C H W] -> num_img [1 C H W]
        image = [i.transpose(0, 1) for i in image]  # num_img [1 C H W] -> num_img [C 1 H W]
        caps = [i['cap'] for i in image_data]
        text = [text_preprocessing(cap) for cap in caps]
        input_ids, cond_mask = [], []
        for t in text:
            text_tokens_and_mask = self.tokenizer(
                t,
                max_length=self.model_max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
            input_ids.append(text_tokens_and_mask['input_ids'])
            cond_mask.append(text_tokens_and_mask['attention_mask'])
        input_ids = torch.cat(input_ids)  # self.use_image_num, l
        cond_mask = torch.cat(cond_mask)  # self.use_image_num, l
        return dict(image=image, input_ids=input_ids, cond_mask=cond_mask)

    def tv_read(self, path, frame_idx=None):
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        if frame_idx is None:
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        else:
            start_frame_ind, end_frame_ind = frame_idx.split(':')
            start_frame_ind, end_frame_ind = int(start_frame_ind), int(end_frame_ind)
        # linspace also upsamples clips shorter than num_frames by repeating indices.
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video = vframes[frame_indice]  # (T, C, H, W)
        return video

    def decord_read(self, path, frame_idx=None):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sample a frame window, either from temporal_sample or from the
        # annotated start index. The annotated end index is ignored in favor of
        # a fixed num_frames-long window.
        if frame_idx is None:
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        else:
            start_frame_ind = int(frame_idx.split(':')[0])
            end_frame_ind = start_frame_ind + self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data
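    # Annotation layout, as inferred from the parsing below: args.video_data and
    # args.image_data are plain-text index files with one "folder,annotation.json"
    # pair per line. Each JSON file holds a list of records such as
    # {"path": "clip.mp4", "cap": "a caption", "frame_idx": "start:end"}
    # (frame_idx is present for videos only); paths are resolved against folder.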
    def get_vid_cap_list(self):
        vid_cap_lists = []
        with open(self.video_data, 'r') as f:
            folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
        for folder, anno in folder_anno:
            with open(anno, 'r') as f:
                vid_cap_list = json.load(f)
            print(f'Building {anno}...')
            for i in tqdm(range(len(vid_cap_list))):
                path = opj(folder, vid_cap_list[i]['path'])
                # Prefer a pre-resized 1080p copy of the clip when one exists.
                if os.path.exists(path.replace('.mp4', '_resize_1080p.mp4')):
                    path = path.replace('.mp4', '_resize_1080p.mp4')
                vid_cap_list[i]['path'] = path
            vid_cap_lists += vid_cap_list
        return vid_cap_lists

    def get_img_cap_list(self):
        img_cap_lists = []
        with open(self.image_data, 'r') as f:
            folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
        for folder, anno in folder_anno:
            with open(anno, 'r') as f:
                img_cap_list = json.load(f)
            print(f'Building {anno}...')
            for i in tqdm(range(len(img_cap_list))):
                img_cap_list[i]['path'] = opj(folder, img_cap_list[i]['path'])
            img_cap_lists += img_cap_list
        # Group the flat image list into chunks of use_image_num per video sample.
        img_cap_lists = [img_cap_lists[i: i + self.use_image_num] for i in range(0, len(img_cap_lists), self.use_image_num)]
        return img_cap_lists[:-1]  # drop the last, possibly short, group so every sample has exactly use_image_num images
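

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original training pipeline). The
    # index-file paths, the tokenizer checkpoint, and the transform and
    # temporal_sample stand-ins below are illustrative placeholders; the real
    # entry point wires in its own resize/crop/normalize pipeline and cropper.
    from argparse import Namespace
    from transformers import AutoTokenizer

    args = Namespace(
        video_data='scripts/train_data/video_data.txt',  # hypothetical "folder,anno.json" index
        image_data='scripts/train_data/image_data.txt',  # hypothetical "folder,anno.json" index
        num_frames=65,
        model_max_length=300,
        use_image_num=4,
        use_img_from_vid=True,
    )

    def temporal_sample(total_frames):
        # Stand-in temporal cropper: pick a random num_frames-long window,
        # clamped so short clips still yield a valid (start, end) pair.
        start = random.randint(0, max(0, total_frames - args.num_frames))
        return start, min(total_frames, start + args.num_frames)

    def transform(video):
        # Stand-in transform: scale uint8 (T, C, H, W) frames to [-1, 1].
        return video.float() / 127.5 - 1.0

    tokenizer = AutoTokenizer.from_pretrained('DeepFloyd/t5-v1_1-xxl')  # example text-encoder tokenizer
    dataset = T2V_dataset(args, transform, temporal_sample, tokenizer)
    sample = dataset[0]
    print(sample['video_data']['video'].shape)  # expect (C, num_frames, H, W)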