import json
import random

import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor


class HumanDanceDataset(Dataset):
    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=["./data/fahsion_meta.json"],
        sample_margin=30,
    ):
        super().__init__()

        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin

        # -----
        # vid_meta format:
        # [{'video_path': , 'kps_path': , 'other':},
        #  {'video_path': , 'kps_path': , 'other':}]
        # -----
        vid_meta = []
        for data_meta_path in data_meta_paths:
            vid_meta.extend(json.load(open(data_meta_path, "r")))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # Transform for target/reference RGB frames (normalized to [-1, 1]).
        self.transform = transforms.Compose(
            [
                # transforms.RandomResizedCrop(
                #     self.img_size,
                #     scale=self.img_scale,
                #     ratio=self.img_ratio,
                #     interpolation=transforms.InterpolationMode.BILINEAR,
                # ),
                transforms.Resize(
                    self.img_size,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Transform for pose condition frames (kept in [0, 1], no normalization).
        self.cond_transform = transforms.Compose(
            [
                # transforms.RandomResizedCrop(
                #     self.img_size,
                #     scale=self.img_scale,
                #     ratio=self.img_ratio,
                #     interpolation=transforms.InterpolationMode.BILINEAR,
                # ),
                transforms.Resize(
                    self.img_size,
                ),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        # Restore the RNG state so paired images (target, pose, reference)
        # receive identical random augmentations.
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def __getitem__(self, index):
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)

        # Sample a reference frame and a target frame at least `margin` frames
        # apart whenever the clip is long enough; otherwise fall back to any frame.
        margin = min(self.sample_margin, video_length)

        ref_img_idx = random.randint(0, video_length - 1)
        if ref_img_idx + margin < video_length:
            tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
        elif ref_img_idx - margin > 0:
            tgt_img_idx = random.randint(0, ref_img_idx - margin)
        else:
            tgt_img_idx = random.randint(0, video_length - 1)

        ref_img = video_reader[ref_img_idx]
        ref_img_pil = Image.fromarray(ref_img.asnumpy())
        tgt_img = video_reader[tgt_img_idx]
        tgt_img_pil = Image.fromarray(tgt_img.asnumpy())

        tgt_pose = kps_reader[tgt_img_idx]
        tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())

        # Share one RNG state so target, pose, and reference get the same augmentation.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
        )

        return sample

    def __len__(self):
        return len(self.vid_meta)
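

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wraps the dataset in
# a standard PyTorch DataLoader to inspect one batch. The image size and the
# batch/worker settings below are assumptions for illustration; point
# `data_meta_paths` at a JSON file whose entries reference videos and pose
# clips that actually exist on disk.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = HumanDanceDataset(
        img_size=(768, 512),  # (height, width); assumed value
        data_meta_paths=["./data/fahsion_meta.json"],  # default path from this file
        sample_margin=30,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    # Each batch carries normalized target frames, pose conditions, reference
    # frames for the VAE branch, and CLIP-preprocessed reference images.
    print(batch["img"].shape)          # [B, 3, H, W]
    print(batch["tgt_pose"].shape)     # [B, 3, H, W]
    print(batch["ref_img"].shape)      # [B, 3, H, W]
    print(batch["clip_images"].shape)  # [B, 3, 224, 224]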