Cinemo / datasets /video_transforms.py
maxin-cn's picture
Upload folder using huggingface_hub
be791d6 verified
raw
history blame
26.1 kB
import torch
import random
import numbers
from torchvision.transforms import RandomCrop, RandomResizedCrop
from PIL import Image
from torchvision.utils import _log_api_usage_once
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
def resize_scale_with_height(clip, target_size, interpolation_mode):
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size / H
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resize_scale_with_weight(clip, target_size, interpolation_mode):
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size / W
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
# print(clip.shape)
th, tw = crop_size
if h < th or w < tw:
# print(h, w)
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw), i, j
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def random_shift_crop(clip):
'''
Slide along the long edge, with the short edge as crop size
'''
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
long_edge = w
short_edge = h
else:
long_edge = h
short_edge =w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw), i, j
def random_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size[-2], crop_size[-1]
if h < th or w < tw:
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
clip_crop = crop(clip, i, j, th, tw)
return clip_crop, i, j
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# print(clip.shape)
clip_center_crop = center_crop_using_short_edge(clip)
# print(clip_center_crop.shape) 320 512
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class SDXL:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# add aditional one pixel for avoiding error in center crop
ori_h, ori_w = clip.size(-2), clip.size(-1)
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1
# if ori_h >= tar_h and ori_w >= tar_w:
# clip_tar_crop, i, j = random_crop(clip=clip, crop_size=self.size)
# else:
# tar_h_div_ori_h = tar_h / ori_h
# tar_w_div_ori_w = tar_w / ori_w
# if tar_h_div_ori_h > tar_w_div_ori_w:
# clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
# else:
# clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
# clip_tar_crop, i, j = random_crop(clip, self.size)
if ori_h >= tar_h and ori_w >= tar_w:
tar_h_div_ori_h = tar_h / ori_h
tar_w_div_ori_w = tar_w / ori_w
if tar_h_div_ori_h > tar_w_div_ori_w:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
else:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
ori_h, ori_w = clip.size(-2), clip.size(-1)
clip_tar_crop, i, j = random_crop(clip, self.size)
else:
tar_h_div_ori_h = tar_h / ori_h
tar_w_div_ori_w = tar_w / ori_w
if tar_h_div_ori_h > tar_w_div_ori_w:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
else:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
clip_tar_crop, i, j = random_crop(clip, self.size)
return clip_tar_crop, ori_h, ori_w, i, j
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class SDXLCenterCrop:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# add aditional one pixel for avoiding error in center crop
ori_h, ori_w = clip.size(-2), clip.size(-1)
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1
tar_h_div_ori_h = tar_h / ori_h
tar_w_div_ori_w = tar_w / ori_w
# print('before resize', clip.shape)
if tar_h_div_ori_h > tar_w_div_ori_w:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
# print('after h resize', clip.shape)
else:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
# print('after resize', clip.shape)
# print(clip.shape)
# clip_tar_crop, i, j = random_crop(clip, self.size)
clip_tar_crop, i, j = center_crop(clip, self.size)
# print('after crop', clip_tar_crop.shape)
return clip_tar_crop, ori_h, ori_w, i, j
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class InternVideo320512:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# add aditional one pixel for avoiding error in center crop
h, w = clip.size(-2), clip.size(-1)
# print('before resize', clip.shape)
if h < 320:
clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
# print('after h resize', clip.shape)
if w < 512:
clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)
# print('after w resize', clip.shape)
# print(clip.shape)
clip_center_crop = center_crop(clip, self.size)
clip_center_crop_no_subtitles = center_crop(clip, (220, 352))
clip_center_resize = resize(clip_center_crop_no_subtitles, target_size=self.size, interpolation_mode=self.interpolation_mode)
# print(clip_center_crop.shape)
return clip_center_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class CenterCropVideo:
'''
First scale to the specified size in equal proportion to the short edge,
then center cropping
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
'''
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class ResizeVideo():
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
return clip_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(self)
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
if isinstance(t, SDXLCenterCrop) or isinstance(t, SDXL):
img, ori_h, ori_w, crops_coords_top, crops_coords_left = t(img)
else:
img = t(img)
return img, ori_h, ori_w, crops_coords_top, crops_coords_left
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += f" {t}"
format_string += "\n)"
return format_string
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
if __name__ == '__main__':
from torchvision import transforms
import torchvision.io as io
import numpy as np
from torchvision.utils import save_image
import os
vframes, aframes, info = io.read_video(
filename='./v_Archery_g01_c03.avi',
pts_unit='sec',
output_format='TCHW'
)
trans = transforms.Compose([
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))