import csv
import gc
import json
import os
import random
from contextlib import contextmanager

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from func_timeout import FunctionTimedOut, func_timeout
from PIL import Image
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset

# Maximum time in seconds to wait for a decord read before giving up on a sample.
VIDEO_READER_TIMEOUT = 20

def get_random_mask(shape):
    """Generate a random binary mask of shape (f, 1, h, w) for inpainting training.

    Videos (f > 1) draw from five mask patterns; single images (f == 1) use
    only the block and full-frame patterns.
    """
    f, c, h, w = shape

    if f != 1:
        mask_index = np.random.choice([0, 1, 2, 3, 4], p=[0.05, 0.3, 0.3, 0.3, 0.05])
    else:
        mask_index = np.random.choice([0, 1], p=[0.2, 0.8])
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if mask_index == 0:
        # Random rectangular block, shared across all frames.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        # Mask every pixel of every frame.
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        # Keep the first few frames, mask the rest (image-to-video style).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        # Keep the first and last few frames, mask the middle (interpolation style).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        # Random rectangular block over a random contiguous span of frames.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)

        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    else:
        raise ValueError(f"The mask_index {mask_index} is not defined.")
    return mask
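
# A minimal sketch of how these masks are consumed downstream (shapes are
# illustrative; ``pixel_values`` is assumed to already be normalized to
# [-1, 1], as in ``__getitem__`` below):
#
#     pixel_values = torch.randn(16, 3, 256, 256)      # (f, c, h, w)
#     mask = get_random_mask(pixel_values.size())      # (16, 1, 256, 256)
#     masked = pixel_values * (1 - mask) - 1 * mask    # masked pixels -> -1
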
class ImageVideoSampler(BatchSampler):
    """A batch sampler that groups image samples and video samples into
    separate batches, so each mini-batch contains only one content type.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # One bucket per content type; a bucket is flushed as soon as it fills.
        self.bucket = {'image': [], 'video': []}

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            self.bucket[content_type].append(idx)

            # Yield a full batch as soon as either bucket reaches batch_size.
            # Note that partially filled buckets left over at the end of the
            # epoch are discarded, regardless of ``drop_last``.
            if len(self.bucket['video']) == self.batch_size:
                bucket = self.bucket['video']
                yield bucket[:]
                del bucket[:]
            elif len(self.bucket['image']) == self.batch_size:
                bucket = self.bucket['image']
                yield bucket[:]
                del bucket[:]
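
# A minimal sketch of wiring the sampler into a DataLoader ("metadata.json" and
# the data_root are hypothetical placeholders; ImageVideoDataset is defined below):
#
#     from torch.utils.data import DataLoader, RandomSampler
#
#     dataset = ImageVideoDataset("metadata.json", data_root="datasets/train")
#     batch_sampler = ImageVideoSampler(RandomSampler(dataset), dataset, batch_size=4)
#     loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=8)
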
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Context manager that releases the decord VideoReader on exit."""
    vr = VideoReader(*args, **kwargs)
    try:
        yield vr
    finally:
        # decord keeps file handles and decode buffers alive until the reader
        # is garbage-collected, so free them eagerly.
        del vr
        gc.collect()


def get_video_reader_batch(video_reader, batch_index):
    """Decode the frames at ``batch_index`` into a uint8 numpy array."""
    frames = video_reader.get_batch(batch_index).asnumpy()
    return frames


def resize_frame(frame, target_short_side):
    """Downscale ``frame`` so its short side equals ``target_short_side``.

    Frames whose short side is already smaller are returned unchanged.
    """
    h, w, _ = frame.shape
    if h < w:
        if target_short_side > h:
            return frame
        new_h = target_short_side
        new_w = int(target_short_side * w / h)
    else:
        if target_short_side > w:
            return frame
        new_w = target_short_side
        new_h = int(target_short_side * h / w)

    resized_frame = cv2.resize(frame, (new_w, new_h))
    return resized_frame
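
# A minimal sketch of the guarded-read pattern both datasets below are built
# on ("sample.mp4" is a hypothetical path):
#
#     with VideoReader_contextmanager("sample.mp4", num_threads=2) as vr:
#         batch_index = np.linspace(0, len(vr) - 1, 16, dtype=int)
#         try:
#             frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch,
#                                   args=(vr, batch_index))
#         except FunctionTimedOut:
#             frames = None  # caller is expected to retry with another sample
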
class ImageVideoDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=-1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
    ):
        # Load annotations: a list of dicts, each with at least 'file_path' and
        # 'text', plus an optional 'type' ('image' or 'video'; default 'image').
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            with open(ann_path, 'r') as jsonfile:
                dataset = json.load(jsonfile)
        else:
            raise ValueError(f"Unsupported annotation format: {ann_path}")

        self.data_root = data_root

        # Keep every image sample once; repeat each video sample
        # ``video_repeat`` times to rebalance the image/video mix.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        # Fractions of each video excluded from sampling at the start and end.
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video sampling parameters and transforms.
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image sampling parameters and transforms.
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # How many frames can actually be sampled at the requested
                # stride inside the usable portion of the video.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No frames in video.")

                # Choose a random clip inside the usable range and take evenly
                # spaced frame indices from it.
                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    # (f, h, w, c) uint8 -> (f, c, h, w) float in [0, 1].
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader

            if not self.enable_bucket:
                pixel_values = self.video_transforms(pixel_values)

            # Randomly drop the caption (e.g. for classifier-free guidance).
            if random.random() < self.text_drop_ratio:
                text = ''
            return pixel_values, text, 'video'
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return image, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Retries must stay within the same content type so a batch
                # never mixes images and videos.
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                break
            except Exception as e:
                # Fall back to a random sample instead of failing the epoch.
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            # Masked regions are filled with -1 (black in [-1, 1] space).
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, de-normalized back to [0, 255] for a CLIP-style encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked out when the entire clip is masked.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
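
# A minimal sketch of reading one sample directly ("metadata.csv" and the
# data_root are hypothetical placeholders; each row needs 'file_path', 'text',
# and an optional 'type' column):
#
#     dataset = ImageVideoDataset("metadata.csv", data_root="datasets/train",
#                                 video_sample_n_frames=16, enable_inpaint=True)
#     sample = dataset[0]
#     # -> keys: pixel_values, text, data_type, idx, plus the inpaint extras
#     #    (mask, mask_pixel_values, clip_pixel_values, ref_pixel_values)
#     print(sample["data_type"], sample["pixel_values"].shape)
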
class ImageVideoControlDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=-1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
    ):
        # Load annotations. On top of the fields used by ImageVideoDataset,
        # each row needs a 'control_file_path' pointing at the control signal
        # aligned with 'file_path'.
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            with open(ann_path, 'r') as jsonfile:
                dataset = json.load(jsonfile)
        else:
            raise ValueError(f"Unsupported annotation format: {ann_path}")

        self.data_root = data_root

        # Keep every image sample once; repeat each video sample
        # ``video_repeat`` times to rebalance the image/video mix.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        # Fractions of each video excluded from sampling at the start and end.
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video sampling parameters and transforms.
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image sampling parameters and transforms.
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, control_video_id, text = data_info['file_path'], data_info['control_file_path'], data_info['text']

        if data_info.get('type', 'image') == 'video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # How many frames can actually be sampled at the requested
                # stride inside the usable portion of the video.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No frames in video.")

                # Choose a random clip inside the usable range and take evenly
                # spaced frame indices from it.
                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader

            if not self.enable_bucket:
                pixel_values = self.video_transforms(pixel_values)

            # Randomly drop the caption (e.g. for classifier-free guidance).
            if random.random() < self.text_drop_ratio:
                text = ''

            if self.data_root is not None:
                control_video_id = os.path.join(self.data_root, control_video_id)

            # Read the control video at the same frame indices so the control
            # signal stays frame-aligned with the sampled clip.
            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                try:
                    sample_args = (control_video_reader, batch_index)
                    control_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(control_pixel_values)):
                        frame = control_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    control_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    del control_video_reader

            if not self.enable_bucket:
                control_pixel_values = self.video_transforms(control_pixel_values)
            return pixel_values, control_pixel_values, text, "video"
        else:
            image_path = video_id
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            # For image samples the control signal is a single image located
            # at 'control_file_path'.
            control_image_id = control_video_id
            if self.data_root is not None:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)
            return image, control_image, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Retries must stay within the same content type so a batch
                # never mixes images and videos.
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                break
            except Exception as e:
                # Fall back to a random sample instead of failing the epoch.
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            # Masked regions are filled with -1 (black in [-1, 1] space).
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, de-normalized back to [0, 255] for a CLIP-style encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked out when the entire clip is masked.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
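
# A minimal sketch for the control variant (paths are hypothetical; every
# annotation row additionally needs a 'control_file_path' frame-aligned with
# its 'file_path'):
#
#     dataset = ImageVideoControlDataset("metadata.csv", data_root="datasets/train")
#     sample = dataset[0]
#     print(sample["pixel_values"].shape, sample["control_pixel_values"].shape)
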