# Spaces: Running on A10G
""" | |
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
""" | |
import torch | |
from video_llama.common.registry import registry | |
from decord import VideoReader | |
import decord | |
import numpy as np | |
from video_llama.processors import transforms_video | |
from video_llama.processors.base_processor import BaseProcessor | |
from video_llama.processors.randaugment import VideoRandomAugment | |
from video_llama.processors import functional_video as F | |
from omegaconf import OmegaConf | |
from torchvision import transforms | |
import random as rnd | |
# Default frame budget for the loader and processors in this module, read from
# the shared registry under the "MAX_INT" key.
MAX_INT = registry.get("MAX_INT")
# Select decord's "torch" bridge at import time (presumably so decoded frame
# batches come back as torch tensors -- confirm against the decord docs).
decord.bridge.set_bridge("torch")
def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg=False):
    """
    Decode a subset of frames from a video file into a float tensor.

    Args:
        video_path (str): Location of the video, passed to decord's VideoReader.
        n_frms (int): Maximum number of frames to sample. It is capped at the
            number of frames in the video.
        height (int): Forwarded to VideoReader as the decode height.
        width (int): Forwarded to VideoReader as the decode width.
        sampling (str): "uniform" picks evenly spaced frames over the whole
            video; "headtail" picks random frames, half from the first half of
            the video and half from the second half.
        return_msg (bool): When True, also return a sentence describing the
            sampled timestamps.
    Returns:
        torch.tensor: sampled frames as float, size is (C, T, H, W).
        If ``return_msg`` is True, a tuple ``(frames, msg)`` is returned instead.
    Raises:
        ValueError: if no frame can be sampled (empty video or n_frms < 1).
        NotImplementedError: if ``sampling`` is not a supported strategy.
    """
    decord.bridge.set_bridge("torch")
    vr = VideoReader(uri=video_path, height=height, width=width)

    vlen = len(vr)
    n_frms = min(n_frms, vlen)
    if n_frms < 1:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError(
            f"Cannot sample frames: video has {vlen} frames and n_frms={n_frms}."
        )

    if sampling == "uniform":
        # linspace(endpoint=False) yields the same values as
        # arange(0, vlen, vlen / n_frms) but always exactly n_frms of them;
        # arange with a float step can emit one extra (possibly out-of-range)
        # element because of floating-point rounding.
        indices = np.linspace(0, vlen, num=n_frms, endpoint=False).astype(int).tolist()
    elif sampling == "headtail":
        # Give the tail the extra frame when n_frms is odd so the requested
        # number of frames is honoured. The tail range holds ceil(vlen / 2)
        # frames, so it can always supply ceil(n_frms / 2) samples.
        n_head = n_frms // 2
        n_tail = n_frms - n_head
        midpoint = vlen // 2
        indices_h = sorted(rnd.sample(range(midpoint), n_head))
        indices_t = sorted(rnd.sample(range(midpoint, vlen), n_tail))
        indices = indices_h + indices_t
    else:
        raise NotImplementedError(f"Unsupported sampling strategy: {sampling!r}")

    # get_batch -> T, H, W, C
    temp_frms = vr.get_batch(indices)
    if isinstance(temp_frms, torch.Tensor):
        tensor_frms = temp_frms
    else:
        tensor_frms = torch.from_numpy(temp_frms)
    frms = tensor_frms.permute(3, 0, 1, 2).float()  # (C, T, H, W)

    if not return_msg:
        return frms

    # Convert each sampled frame index to a timestamp using the average fps.
    fps = float(vr.get_avg_fps())
    sec = ", ".join([str(round(f / fps, 1)) for f in indices])
    # " " should be added in the start and end
    msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. "
    return frms, msg
class AlproVideoBaseProcessor(BaseProcessor):
    """
    Shared base for the Alpro video processors.

    Stores the frame budget (``n_frms``) and builds the ``NormalizeVideo``
    transform used by subclasses.
    """

    def __init__(self, mean=None, std=None, n_frms=MAX_INT):
        # Fallback statistics used when the caller supplies none
        # (presumably per-channel RGB mean/std -- confirm with the model config).
        default_mean = (0.48145466, 0.4578275, 0.40821073)
        default_std = (0.26862954, 0.26130258, 0.27577711)

        chosen_mean = default_mean if mean is None else mean
        chosen_std = default_std if std is None else std

        self.normalize = transforms_video.NormalizeVideo(chosen_mean, chosen_std)
        self.n_frms = n_frms
class ToUint8(object):
    """Callable transform that casts a tensor to ``torch.uint8``."""

    def __call__(self, tensor):
        """
        Args:
            tensor (torch.tensor): Input tensor of any shape.
        Returns:
            torch.tensor: The input converted to dtype torch.uint8.
        """
        as_uint8 = tensor.to(dtype=torch.uint8)
        return as_uint8

    def __repr__(self):
        return type(self).__name__
class ToTHWC(object):
    """
    Reorder a clip from channel-first to channel-last layout.

    Args:
        clip (torch.tensor): Size is (C, T, H, W)
    Return:
        clip (torch.tensor): Size is (T, H, W, C). Only the dimension order
            changes; no dtype conversion is performed here.
    """

    def __call__(self, tensor):
        # Source dims in output order: T (1), H (2), W (3), then C (0).
        thwc_order = (1, 2, 3, 0)
        return tensor.permute(*thwc_order)

    def __repr__(self):
        return type(self).__name__
class ResizeVideo(object):
    """Callable transform that resizes a video clip via ``F.resize``."""

    def __init__(self, target_size, interpolation_mode="bilinear"):
        self.target_size, self.interpolation_mode = target_size, interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be resized. Size is (C, T, H, W)
        Returns:
            torch.tensor: result of ``F.resize`` applied to the clip with the
                configured target size and interpolation mode.
        """
        resized = F.resize(clip, self.target_size, self.interpolation_mode)
        return resized

    def __repr__(self):
        suffix = "(resize_size={0})".format(self.target_size)
        return type(self).__name__ + suffix
class AlproVideoTrainProcessor(AlproVideoBaseProcessor):
    """
    Training-time video processor.

    Loads frames from a video path with "headtail" sampling, then applies a
    random resized crop, a uint8 cast, tensor conversion and normalization.
    """

    def __init__(
        self,
        image_size=384,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        n_frms=MAX_INT,
    ):
        super().__init__(mean=mean, std=std, n_frms=n_frms)

        self.image_size = image_size

        self.transform = transforms.Compose(
            [
                # Video size is (C, T, H, W)
                transforms_video.RandomResizedCropVideo(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation_mode="bicubic",
                ),
                ToTHWC(),  # C, T, H, W -> T, H, W, C
                ToUint8(),
                transforms_video.ToTensorVideo(),  # T, H, W, C -> C, T, H, W
                self.normalize,
            ]
        )

    def __call__(self, vpath):
        """
        Args:
            vpath (str): Path of the video to load and transform.
        Returns:
            torch.tensor: video clip after transforms. Size is (C, T, size, size).
        """
        clip = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="headtail",
        )

        return self.transform(clip)

    # Alternate constructor: it must be a classmethod so that
    # ``AlproVideoTrainProcessor.from_config(cfg)`` binds the class to ``cls``
    # instead of binding ``cfg`` to it.
    @classmethod
    def from_config(cls, cfg=None):
        """
        Build a processor from a config object.

        Args:
            cfg: Config exposing ``get(key, default)``; an empty OmegaConf
                config is used when None. Keys read: "image_size", "mean",
                "std", "min_scale", "max_scale", "n_frms".
        Returns:
            AlproVideoTrainProcessor: a new instance of ``cls``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        # NOTE(review): the fallback image_size here (256) differs from the
        # __init__ default (384) -- confirm which one is intended.
        image_size = cfg.get("image_size", 256)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        n_frms = cfg.get("n_frms", MAX_INT)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
            n_frms=n_frms,
        )
class AlproVideoEvalProcessor(AlproVideoBaseProcessor):
    """
    Evaluation-time video processor.

    Loads frames from a video path with the default sampling of
    ``load_video``, resized to ``image_size`` during decoding, then applies a
    uint8 cast, tensor conversion and normalization (no random augmentation).
    """

    def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT):
        super().__init__(mean=mean, std=std, n_frms=n_frms)

        self.image_size = image_size

        # Input video size is (C, T, H, W)
        self.transform = transforms.Compose(
            [
                # frames will be resized during decord loading.
                ToUint8(),  # C, T, H, W
                ToTHWC(),  # T, H, W, C
                transforms_video.ToTensorVideo(),  # C, T, H, W
                self.normalize,  # C, T, H, W
            ]
        )

    def __call__(self, vpath):
        """
        Args:
            vpath (str): Path of the video to load and transform.
        Returns:
            torch.tensor: video clip after transforms. Size is (C, T, size, size).
        """
        clip = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
        )

        return self.transform(clip)

    # Alternate constructor: it must be a classmethod so that
    # ``AlproVideoEvalProcessor.from_config(cfg)`` binds the class to ``cls``
    # instead of binding ``cfg`` to it.
    @classmethod
    def from_config(cls, cfg=None):
        """
        Build a processor from a config object.

        Args:
            cfg: Config exposing ``get(key, default)``; an empty OmegaConf
                config is used when None. Keys read: "image_size", "mean",
                "std", "n_frms".
        Returns:
            AlproVideoEvalProcessor: a new instance of ``cls``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 256)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        n_frms = cfg.get("n_frms", MAX_INT)

        return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)