import decord
import cv2
import os
import io
import csv
import torch
import math
import random
from typing import Optional
from einops import rearrange
import numpy as np
from decord import VideoReader
from petrel_client.client import Client
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
from torch.utils.data.distributed import DistributedSampler

import animatediff.data.video_transformer as video_transforms
from animatediff.utils.util import zero_rank_print, detect_edges, prepare_mask_coef_by_score


def _frame_channels(frame, use_edge):
    """Split one RGB frame into its H, S and V planes, plus an edge map if requested."""
    channels = list(
        cv2.split(cv2.cvtColor(frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
    if use_edge:
        # The edge map is computed from the last plane of the HSV split (V),
        # cast to uint8 before being handed to detect_edges.
        channels.append(detect_edges(lum=channels[-1].astype(np.uint8)))
    return channels


def get_score(video_data, cond_frame_idx, weight=(1.0, 1.0, 1.0, 1.0), use_edge=True):
    """Score how much every frame differs from the condition frame.

    Each frame is converted to HSV (and, when ``use_edge`` is true, extended
    with an edge map). The per-channel absolute differences against the
    condition frame are summed, weighted, and normalised by the number of
    pixels and by the sum of the weights in use.

    Args:
        video_data: ``np.ndarray`` laid out as ``(f, h, w, c)`` in RGB order.
            Height and width are read from axes 1 and 2. The caller in this
            file passes float values scaled to the 0-255 range.
        cond_frame_idx: index (along axis 0) of the reference frame.
        weight: one weight per compared channel, in the order H, S, V, edge.
            Weights beyond the number of available channels are ignored, so
            the default four weights also work with ``use_edge=False``.
        use_edge: whether to add the edge map as a fourth channel.

    Returns:
        A list with one score per frame; the condition frame itself scores 0.
    """
    height, width = video_data.shape[1], video_data.shape[2]
    reference = _frame_channels(video_data[cond_frame_idx], use_edge)

    # Only as many weights as there are channels can be applied; without this
    # the default 4 weights would index past the 3 HSV planes when the edge
    # channel is disabled.
    active_weights = list(weight)[:len(reference)]
    weight_total = sum(active_weights)
    pixel_count = height * width

    scores = []
    for frame in video_data:
        channels = _frame_channels(frame, use_edge)
        weighted_diffs = [
            np.sum(np.abs(current - ref)) * w
            for current, ref, w in zip(channels, reference, active_weights)
        ]
        scores.append(sum(weighted_diffs) / pixel_count / weight_total)
    return scores


class WebVid10M(Dataset):
    """WebVid10M clips fetched through a petrel (ceph) client.

    Every item is a dict with the normalised pixel tensor, the caption, the
    per-frame difference scores from ``get_score``, the index of the
    condition frame and the video id.
    """

    def __init__(
            self,
            csv_path,
            sample_n_frames,
            sample_stride,
            sample_size=(320, 512),
            conf_path="~/petreloss.conf",
            static_video=False,
            is_image=False,
        ):
        """Create the storage client, load the annotations and build the transforms.

        Args:
            csv_path: annotation CSV; rows must provide the ``videoid``,
                ``name`` and ``page_dir`` columns read by ``get_batch``.
            sample_n_frames: number of frames returned per video sample.
            sample_stride: multiplied by ``sample_n_frames`` to size the
                temporal crop window.
            sample_size: an int (square) or an ``(h, w)`` pair used for the
                resize / centre-crop transforms.
            conf_path: configuration file passed to the petrel ``Client``.
            static_video: if true, one random frame is repeated
                ``sample_n_frames`` times instead of sampling a clip.
            is_image: if true, a single random frame is returned.
        """
        zero_rank_print("initializing ceph client ...")
        self._client = Client(conf_path=conf_path, enable_mc=True)

        self.sample_n_frames = sample_n_frames
        self.sample_stride = sample_stride
        self.temporal_sampler = \
            video_transforms.TemporalRandomCrop(sample_n_frames * sample_stride)
        self.static_video = static_video
        self.is_image = is_image

        zero_rank_print(f"(~1 mins) loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        else:
            sample_size = tuple(sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Download and decode the frames for annotation row ``idx``.

        Returns:
            ``(pixel_values, name, cond_frames, videoid)`` where
            ``pixel_values`` is a float tensor in ``[0, 1]`` shaped
            ``(f, c, h, w)`` (or ``(c, h, w)`` in image mode) and
            ``cond_frames`` is the index of the condition frame.

        Raises:
            AssertionError: if the download is empty or the sampled window is
                shorter than ``sample_n_frames``.
        """
        video_dict = self.dataset[idx]
        videoid = video_dict['videoid']
        name = video_dict['name']
        page_dir = video_dict['page_dir']

        ceph_dir = f"webvideo:s3://WebVid10M/{page_dir}/{videoid}.mp4"
        video_bytes = io.BytesIO(self._client.Get(ceph_dir))
        # ensure not reading zero byte
        assert video_bytes.getbuffer().nbytes != 0
        video_reader = VideoReader(video_bytes)

        total_frames = len(video_reader)
        if self.is_image:
            frame_indice = [random.randint(0, total_frames - 1)]
        elif self.static_video:
            still_idx = random.randint(0, total_frames - 1)
            frame_indice = np.linspace(still_idx, still_idx, self.sample_n_frames, dtype=int)
        else:
            start_frame_ind, end_frame_ind = self.temporal_sampler(total_frames)
            assert end_frame_ind - start_frame_ind >= self.sample_n_frames
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        pixel_values_np = video_reader.get_batch(frame_indice).asnumpy()
        if self.is_image:
            # Only one frame was decoded, so it is the only valid condition frame.
            cond_frames = 0
        else:
            cond_frames = random.randint(0, self.sample_n_frames - 1)

        # f h w c -> f c h w
        pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name, cond_frames, videoid

    def __len__(self):
        """Return the number of annotation rows."""
        return self.length

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``, retrying with random rows on failure."""
        while True:
            try:
                video, name, cond_frames, videoid = self.get_batch(idx)
                break
            except Exception:
                # Best effort: an unreadable or too-short video is replaced by
                # another randomly chosen one.
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)

        # get_score needs (f, h, w, c); image mode yields (c, h, w), so add a
        # frame axis there instead of failing in permute.
        frames = video if video.dim() == 4 else video.unsqueeze(0)
        # Undo the [-1, 1] normalisation and rescale to 0-255 for scoring.
        frames_np = frames.permute(0, 2, 3, 1).numpy() / 2 + 0.5
        frames_np = frames_np * 255
        score = get_score(frames_np, cond_frame_idx=cond_frames)
        del frames_np

        return dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid)


if __name__ == "__main__":
    dataset = WebVid10M(
        csv_path="results_10M_train.csv",
        sample_size=(320, 512),
        sample_n_frames=16,
        sample_stride=4,
        static_video=False,
        is_image=False,
    )
    distributed_sampler = DistributedSampler(
        dataset,
        num_replicas=1,
        rank=0,
        shuffle=True,
        seed=5,
    )
    batch_size = 1
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, sampler=distributed_sampler)

    STATISTIC = [[0., 0.],
                 [0.3535855, 24.23687346],
                 [0.91609545, 30.65091947],
                 [1.41165152, 34.40093286],
                 [1.56943881, 36.99639585],
                 [1.73182842, 39.42044163],
                 [1.82733002, 40.94703526],
                 [1.88060527, 42.66233244],
                 [1.96208071, 43.73070788],
                 [2.02723091, 44.25965378],
                 [2.10820894, 45.66120213],
                 [2.21115041, 46.29561324],
                 [2.23412351, 47.08810863],
                 [2.29430165, 47.9515062],
                 [2.32986362, 48.69085638],
                 [2.37310751, 49.19931439]]

    for idx, batch in enumerate(dataloader):
        # Map the normalised pixels from [-1, 1] back to 0-255.
        pixel_values = batch['pixel_values'].clone() / 2. + 0.5
        pixel_values *= 255
        # One condition-frame index (the first frame) per sample in the batch.
        cond_frames = [0] * batch_size
        score = prepare_mask_coef_by_score(pixel_values, cond_frames, statistic=STATISTIC)
        print(f'num: {idx}, diff: {score}')