pix2pix-zero-demo / lavis /processors /transforms_video.py
John6666's picture
Upload 351 files
e84842d verified
raw
history blame
5.01 kB
#!/usr/bin/env python3
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import numbers
import random
from torchvision.transforms import (
RandomCrop,
RandomResizedCrop,
)
import lavis.processors.functional_video as F
# Public transforms exported by a star import of this module.
__all__ = [
    "RandomCropVideo", "RandomResizedCropVideo", "CenterCropVideo",
    "NormalizeVideo", "ToTensorVideo", "RandomHorizontalFlipVideo",
]
class RandomCropVideo(RandomCrop):
    """Randomly crop a fixed-size spatial window out of a video clip.

    Subclasses torchvision's ``RandomCrop`` to reuse its ``get_params``
    sampler. The parent ``__init__`` is not called, so ``size`` is the only
    attribute this transform configures.

    Args:
        size (int or (int, int)): output ``(height, width)``. A single number
            is cast to ``int`` and used for both sides (square crop); any
            other value is stored as given.
    """

    def __init__(self, size):
        if not isinstance(size, numbers.Number):
            self.size = size
            return
        side = int(size)
        self.size = (side, side)

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to crop, of size (C, T, H, W).
        Returns:
            torch.tensor: the randomly cropped clip (no resizing), of size
                (C, T, OH, OW) where (OH, OW) is ``self.size``.
        """
        top, left, height, width = self.get_params(clip, self.size)
        return F.crop(clip, top, left, height, width)

    def __repr__(self) -> str:
        return "{}(size={})".format(type(self).__name__, self.size)
class RandomResizedCropVideo(RandomResizedCrop):
    """Crop a random region of a video clip and resize it to ``size``.

    Subclasses torchvision's ``RandomResizedCrop`` to reuse its
    ``get_params`` sampler. The parent ``__init__`` is not called, so only
    the attributes assigned below are configured.

    Args:
        size (int or (int, int)): output ``(height, width)``. A single number
            is cast to ``int`` and used for both sides. A tuple or list must
            have exactly two entries.
        scale ((float, float)): forwarded unchanged to ``get_params``
            (torchvision's crop-area scale range).
        ratio ((float, float)): forwarded unchanged to ``get_params``
            (torchvision's aspect-ratio range).
        interpolation_mode (str): forwarded unchanged to ``F.resized_crop``.

    Raises:
        ValueError: if ``size`` is a tuple or list whose length is not 2.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        if isinstance(size, numbers.Number):
            # Scalar -> square output. Cast to int (as RandomCropVideo does) so
            # a float size such as 224.0 does not reach the resize call.
            self.size = (int(size), int(size))
        elif isinstance(size, (tuple, list)):
            # Lists must be accepted too. Previously only tuples were
            # recognised, so [224, 224] fell through to the branch below and
            # became ([224, 224], [224, 224]).
            if len(size) != 2:
                raise ValueError(
                    f"size should be tuple (height, width), instead got {size}"
                )
            self.size = tuple(size)
        else:
            # Kept for backward compatibility with any other value type.
            self.size = (size, size)
        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip of size
                (C, T, OH, OW), where (OH, OW) is ``self.size`` (the target
                size handed to ``F.resized_crop``).
        """
        i, j, h, w = self.get_params(clip, self.scale, self.ratio)
        return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(size={self.size}, "
            f"interpolation_mode={self.interpolation_mode}, "
            f"scale={self.scale}, ratio={self.ratio})"
        )
class CenterCropVideo:
    """Crop the spatial centre of a video clip to a fixed size.

    Args:
        crop_size (int or (int, int)): target ``(height, width)``. A single
            number is cast to ``int`` and used for both sides; any other
            value is stored as given.
    """

    def __init__(self, crop_size):
        # A scalar size means a square crop.
        self.crop_size = (
            (int(crop_size),) * 2
            if isinstance(crop_size, numbers.Number)
            else crop_size
        )

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to crop, of size (C, T, H, W).
        Returns:
            torch.tensor: the centrally cropped clip, of size
                (C, T, crop_height, crop_width) given by ``self.crop_size``.
        """
        return F.center_crop(clip, self.crop_size)

    def __repr__(self) -> str:
        return "{}(crop_size={})".format(type(self).__name__, self.crop_size)
class NormalizeVideo:
    """Normalize a video clip: subtract ``mean`` and divide by ``std``.

    Args:
        mean (3-tuple): per-channel (RGB) pixel mean.
        std (3-tuple): per-channel (RGB) pixel standard deviation.
        inplace (bool): forwarded to ``F.normalize`` to request in-place
            normalization.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to normalize, of size (C, T, H, W).
        Returns:
            The normalized clip produced by ``F.normalize``.
        """
        return F.normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
    """Turn a uint8 video clip into a float tensor with channels first.

    Delegates to ``F.to_tensor``. The documented contract is: divide values
    by 255.0 and permute the clip from (T, H, W, C) to (C, T, H, W).
    """

    def __init__(self):
        """Take no arguments; this transform has no configuration."""

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): clip of size (T, H, W, C).
        Returns:
            clip (torch.tensor, dtype=torch.float): clip of size (C, T, H, W).
        """
        return F.to_tensor(clip)

    def __repr__(self) -> str:
        return type(self).__name__
class RandomHorizontalFlipVideo:
    """Flip a video clip horizontally with probability ``p``.

    Args:
        p (float): probability that the clip is flipped. Default: 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip of size (C, T, H, W).
        Returns:
            clip (torch.tensor): the flipped clip (with probability ``p``)
                or the input clip itself, of size (C, T, H, W).
        """
        # A single uniform draw per call decides whether to flip.
        should_flip = random.random() < self.p
        return F.hflip(clip) if should_flip else clip

    def __repr__(self) -> str:
        return "{}(p={})".format(type(self).__name__, self.p)