PIA / animatediff /data /dataset.py
LeoXing1996
init repo for fg
a001281
raw
history blame
No virus
4.66 kB
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from PIA.utils.util import zero_rank_print, detect_edges
import cv2
def _frame_channels(frame, use_edge):
    """Split one RGB frame into its H, S, V planes (plus an edge map).

    Args:
        frame: a single ``(h, w, 3)`` RGB frame as ``np.ndarray``.
        use_edge: when true, append ``detect_edges`` of the V plane as an
            extra channel.

    Returns:
        A list ``[H, S, V]`` or ``[H, S, V, edges]``.
    """
    channels = list(
        cv2.split(cv2.cvtColor(frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
    if use_edge:
        # The V plane is the last HSV channel; it is cast to uint8 before
        # being handed to the edge detector, as the original code did.
        channels.append(detect_edges(channels[-1].astype(np.uint8)))
    return channels


def get_score(video_data,
              cond_frame_idx,
              weight=(1.0, 1.0, 1.0, 1.0),
              use_edge=True):
    """Score every frame of a clip by its distance to a conditioning frame.

    Each frame is converted to HSV (optionally extended with an edge map of
    its V plane) and compared channel by channel with the conditioning frame.
    The score of a frame is the weighted sum of absolute channel differences,
    divided by the pixel count and by the sum of the weights that were used.

    According to the original note, this mirrors the scoring logic that lives
    next to ``detect_edges`` in ``utils/util.py``.

    Args:
        video_data: ``np.ndarray`` laid out as ``(f, h, w, c)`` with RGB
            channels last. Height and width are read from axes 1 and 2 and
            every frame is passed to ``cv2.cvtColor``. Values are expected in
            the 0-255 range, since the V plane is cast to ``uint8`` for edge
            detection.
        cond_frame_idx: index of the conditioning (reference) frame inside
            ``video_data``.
        weight: per-channel weights in the order H, S, V, edge. Only as many
            entries as there are channels are used, so the 4-entry default
            also works with ``use_edge=False``.
        use_edge: whether to add the edge-map channel to the comparison.

    Returns:
        A list with one score per frame; the conditioning frame scores 0.

    Raises:
        IndexError: if ``cond_frame_idx`` is outside the clip.
    """
    h, w = video_data.shape[1], video_data.shape[2]
    cond_channels = _frame_channels(video_data[cond_frame_idx], use_edge)

    # Pair each weight with the channel it applies to. zip() stops at the
    # shorter sequence, so surplus weights (e.g. the edge weight when
    # use_edge is False) are ignored instead of raising IndexError.
    used_weights = [wgt for wgt, _ in zip(weight, cond_channels)]
    weight_total = sum(used_weights)

    score_sum = []
    for frame_idx in range(video_data.shape[0]):
        frame_channels = _frame_channels(video_data[frame_idx], use_edge)
        weighted_diffs = [
            # Subtract in float32: the H, S, V planes are already float32, so
            # this is a no-op for them, but it keeps an integer-typed edge map
            # (detect_edges is defined elsewhere -- TODO confirm its dtype)
            # from wrapping around on negative differences.
            np.sum(np.abs(cur.astype(np.float32) - ref.astype(np.float32))) * wgt
            for wgt, cur, ref in zip(used_weights, frame_channels, cond_channels)
        ]
        score_sum.append(sum(weighted_diffs) / (h * w) / weight_total)
    return score_sum
class WebVid10M(Dataset):
    """Video/caption dataset read from a CSV index and a folder of ``.mp4`` files.

    Each CSV row must provide at least the ``videoid`` and ``name`` columns;
    the video for a row is ``<video_folder>/<videoid>.mp4``. Every sample is
    a clip of ``sample_n_frames`` frames (or a single frame when ``is_image``
    is set) together with its caption, a randomly chosen conditioning-frame
    index and the per-frame scores computed by ``get_score``.
    """

    def __init__(
            self,
            csv_path, video_folder,
            sample_size=256, sample_stride=4, sample_n_frames=16,
            is_image=False,
    ):
        """Load the annotation CSV and build the pixel transform pipeline.

        Args:
            csv_path: path of the CSV file listing the videos.
            video_folder: directory containing ``<videoid>.mp4`` files.
            sample_size: output size; an int is expanded to ``(size, size)``,
                any other value is converted with ``tuple()``.
            sample_stride: frame stride used to size the sampled clip window.
            sample_n_frames: number of frames per clip.
            is_image: return a single frame instead of a clip.
        """
        zero_rank_print(f"loading annotations from {csv_path} ...")
        # newline='' lets the csv module handle line endings itself, which
        # matters for quoted fields that contain embedded newlines.
        with open(csv_path, 'r', newline='') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            # Maps values from [0, 1] to [-1, 1]; __getitem__ undoes this
            # before scoring.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    @staticmethod
    def _sample_clip_indices(video_length, n_frames, stride):
        """Pick ``n_frames`` evenly spaced frame indices from a random window.

        The window spans ``(n_frames - 1) * stride + 1`` frames, clamped to
        the video length; for short videos indices may therefore repeat.
        """
        clip_length = min(video_length, (n_frames - 1) * stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int)

    def get_batch(self, idx):
        """Decode the frames for dataset entry ``idx``.

        Returns:
            A tuple ``(pixel_values, name, cond_frames, videoid)`` where
            ``pixel_values`` is a float tensor scaled to ``[0, 1]`` with shape
            ``(f, c, h, w)`` (or ``(c, h, w)`` when ``is_image`` is set) and
            ``cond_frames`` is the index of the conditioning frame inside
            ``pixel_values``.
        """
        video_dict = self.dataset[idx]
        videoid, name = video_dict['videoid'], video_dict['name']

        video_path = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_path)
        video_length = len(video_reader)

        if self.is_image:
            # A single random frame; it is necessarily its own condition.
            frame_indices = [random.randint(0, video_length - 1)]
            cond_frames = 0
        else:
            # Decode the whole sampled clip so that cond_frames (drawn from
            # [0, sample_n_frames - 1]) always addresses a frame that exists.
            frame_indices = self._sample_clip_indices(
                video_length, self.sample_n_frames, self.sample_stride)
            cond_frames = random.randint(0, self.sample_n_frames - 1)

        pixel_values_np = video_reader.get_batch(frame_indices).asnumpy()
        # Drop the reader as soon as the frames have been copied out.
        del video_reader

        # f h w c -> f c h w
        pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name, cond_frames, videoid

    def __len__(self):
        """Return the number of rows loaded from the annotation CSV."""
        return self.length

    def __getitem__(self, idx):
        """Return one training sample as a dict.

        Keys: ``pixel_values`` (normalised frames), ``text`` (caption),
        ``score`` (per-frame scores from ``get_score``), ``cond_frames``
        (conditioning-frame index) and ``vid`` (video id).

        Entries that cannot be loaded are skipped silently and replaced by a
        randomly chosen one; note that this never terminates if no entry at
        all can be read.
        """
        while True:
            try:
                video, name, cond_frames, videoid = self.get_batch(idx)
                break
            except Exception:
                # Best-effort: unreadable or missing videos are replaced by a
                # random other entry.
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)

        # Scoring needs a frame axis; image mode yields a bare (c, h, w) tensor.
        frames = video.unsqueeze(0) if video.dim() == 3 else video
        # Undo the [-1, 1] normalisation and go back to channel-last 0-255
        # values, which is the layout get_score works on. The division
        # allocates a new array, so the returned tensor is left untouched.
        frames_np = frames.permute(0, 2, 3, 1).numpy() / 2 + 0.5
        frames_np = frames_np * 255
        score = get_score(frames_np, cond_frame_idx=cond_frames)
        del frames_np

        sample = dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid)
        return sample