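# musepose/dataset/dance_image.py
# Author: jhj0517 -- commit 7c3ff16 ("initial commit"), 4.1 kB.
# (Header recovered from a web file-viewer page; its "raw", "history blame"
# and "No virus" entries were page UI labels, not part of the source.)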
import json
import random
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
class HumanDanceDataset(Dataset):
    """Image-level dance dataset of (reference frame, target frame, target pose).

    Each metadata JSON file holds a list of entries, one per video::

        [{'video_path': ..., 'kps_path': ..., 'other': ...},
         {'video_path': ..., 'kps_path': ..., 'other': ...}]

    ``video_path`` is the RGB video and ``kps_path`` a frame-aligned pose
    (keypoint) video.  For every item, a reference frame and a target frame
    are drawn from the same video, preferably at least ``sample_margin``
    frames apart.  The pose frame at the target index is also read.

    Args:
        img_size: Size forwarded to ``torchvision.transforms.Resize``.
        img_scale: Stored only; unused by the current transforms (it was the
            ``scale`` argument of a now-disabled ``RandomResizedCrop``).
        img_ratio: Stored only; likewise the ``ratio`` of that disabled crop.
        drop_ratio: Stored on the instance; not read inside this class
            (presumably consumed by the training code -- confirm).
        data_meta_paths: Iterable of metadata JSON paths whose lists are
            concatenated in order.  NOTE(review): the default filename
            "fahsion" looks like a typo of "fashion"; it is kept verbatim
            because it is a file path.
        sample_margin: Desired minimum frame distance between the reference
            and target frames.
    """

    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=("./data/fahsion_meta.json",),
        sample_margin=30,
    ):
        super().__init__()
        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin

        vid_meta = []
        for data_meta_path in data_meta_paths:
            # Context manager so the handle is closed; JSON text is UTF-8.
            with open(data_meta_path, "r", encoding="utf-8") as meta_file:
                vid_meta.extend(json.load(meta_file))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # NOTE: both pipelines used to start with
        # RandomResizedCrop(img_size, scale=img_scale, ratio=img_ratio,
        # interpolation=BILINEAR).  That crop is disabled, so the pipelines
        # below contain only deterministic transforms.
        # RGB frames: resize, convert to tensor, normalize with mean/std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # Pose frames: same resize and tensor conversion, but no Normalize.
        self.cond_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )
        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        """Apply ``transform`` to ``image``.

        If ``state`` is given, torch's global CPU RNG is first reset to it.
        Calls that share one ``state`` therefore draw identical random
        parameters, which keeps random transforms aligned across the image
        and pose frames.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    @staticmethod
    def _sample_frame_indices(video_length, sample_margin):
        """Draw ``(ref_idx, tgt_idx)`` from ``range(video_length)``.

        Let ``margin = min(sample_margin, video_length)``.  The target is
        placed at least ``margin`` frames after the reference if that fits.
        Otherwise it is placed at least ``margin`` frames before the
        reference.  Only if neither fits is the target drawn uniformly from
        the whole video.

        Raises:
            ValueError: if ``video_length`` is not positive.
        """
        if video_length <= 0:
            raise ValueError(
                f"cannot sample frames from a video of length {video_length}"
            )
        margin = min(sample_margin, video_length)
        ref_idx = random.randint(0, video_length - 1)
        if ref_idx + margin < video_length:
            tgt_idx = random.randint(ref_idx + margin, video_length - 1)
        elif ref_idx - margin >= 0:
            # ``>=`` (not ``>``): index 0 is exactly ``margin`` frames away
            # and is valid, mirroring the forward branch above.
            tgt_idx = random.randint(0, ref_idx - margin)
        else:
            tgt_idx = random.randint(0, video_length - 1)
        return ref_idx, tgt_idx

    def __getitem__(self, index):
        """Build one training sample from the video at ``index``.

        Returns:
            dict with keys
              - ``video_dir``: the entry's ``video_path``.
              - ``img``: target frame after ``self.transform``.
              - ``tgt_pose``: pose frame at the target index after
                ``self.cond_transform``.
              - ``ref_img``: reference frame after ``self.transform``.
              - ``clip_images``: reference frame run through the CLIP image
                processor (``pixel_values[0]``).
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]
        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)
        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
        video_length = len(video_reader)

        ref_img_idx, tgt_img_idx = self._sample_frame_indices(
            video_length, self.sample_margin
        )

        # Assumes decord frames convert to HxWxC uint8 arrays -- TODO confirm.
        ref_img_pil = Image.fromarray(video_reader[ref_img_idx].asnumpy())
        tgt_img_pil = Image.fromarray(video_reader[tgt_img_idx].asnumpy())
        tgt_pose_pil = Image.fromarray(kps_reader[tgt_img_idx].asnumpy())

        # One RNG snapshot shared by all three calls.  Any random transform
        # would then treat target, pose and reference identically.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        return dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
        )

    def __len__(self):
        """Return the number of videos listed across all metadata files."""
        return len(self.vid_meta)