Video-LLaMA / video_llama /processors /video_processor.py
θˆŸε‹€
v1
45d16e9
raw
history blame
7 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
from video_llama.common.registry import registry
from decord import VideoReader
import decord
import numpy as np
from video_llama.processors import transforms_video
from video_llama.processors.base_processor import BaseProcessor
from video_llama.processors.randaugment import VideoRandomAugment
from video_llama.processors import functional_video as F
from omegaconf import OmegaConf
from torchvision import transforms
import random as rnd
# Default upper bound on the number of frames to sample. It is looked up from the
# project registry, where it is registered elsewhere (presumably a large int such
# as sys.maxsize, meaning "all frames"; confirm against the registry setup).
MAX_INT = registry.get("MAX_INT")
# Select decord's torch bridge at import time so decoded frame batches come back
# as torch tensors (load_video also re-selects it on every call).
decord.bridge.set_bridge("torch")
def _sample_frame_indices(vlen, n_frms, sampling):
    """Return ``n_frms`` distinct, ascending frame indices in ``range(vlen)``.

    Args:
        vlen (int): number of frames in the video; must be >= 1.
        n_frms (int): number of indices to draw; must satisfy 1 <= n_frms <= vlen.
        sampling (str): "uniform" for evenly spaced indices starting at frame 0,
            or "headtail" for random indices, ``n_frms // 2`` drawn from the first
            half of the video and the remainder from the second half.

    Returns:
        list[int]: the selected frame indices.

    Raises:
        NotImplementedError: if ``sampling`` is not a supported strategy.
    """
    if sampling == "uniform":
        # Exact integer floor(i * vlen / n_frms). A float-step np.arange can round
        # to an extra element equal to vlen (out of range) or truncate one too low.
        return [i * vlen // n_frms for i in range(n_frms)]
    if sampling == "headtail":
        half = vlen // 2
        n_head = n_frms // 2
        # The tail takes the remainder, so an odd n_frms still yields n_frms
        # indices. It always fits: len(range(half, vlen)) == ceil(vlen / 2)
        # >= ceil(n_frms / 2) because n_frms <= vlen.
        n_tail = n_frms - n_head
        head = sorted(rnd.sample(range(half), n_head))
        tail = sorted(rnd.sample(range(half, vlen), n_tail))
        return head + tail
    raise NotImplementedError(f"Unsupported sampling strategy: {sampling!r}")


def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform", return_msg=False):
    """Decode a subset of frames from a video file with decord.

    Args:
        video_path: path/URI passed to ``decord.VideoReader``.
        n_frms (int): maximum number of frames to return; clipped to the video
            length. Must be >= 1.
        height (int): decode height passed to ``VideoReader`` (-1 presumably keeps
            the native size; confirm against decord).
        width (int): decode width passed to ``VideoReader`` (same convention).
        sampling (str): frame-selection strategy, "uniform" or "headtail".
        return_msg (bool): if True, also return a text description of the
            timestamps (in seconds) of the sampled frames.

    Returns:
        torch.Tensor: float32 clip of shape (C, T, H, W) with
        T == min(n_frms, number of frames). When ``return_msg`` is True, a
        ``(clip, message)`` tuple is returned instead.

    Raises:
        ValueError: if ``n_frms`` < 1 or the video has no frames.
        NotImplementedError: if ``sampling`` is not supported.
    """
    if n_frms < 1:
        raise ValueError(f"n_frms must be >= 1, got {n_frms}")
    decord.bridge.set_bridge("torch")
    vr = VideoReader(uri=video_path, height=height, width=width)
    vlen = len(vr)
    if vlen < 1:
        # Would otherwise surface as a ZeroDivisionError / empty batch below.
        raise ValueError(f"No frames could be read from video: {video_path}")
    n_frms = min(n_frms, vlen)
    indices = _sample_frame_indices(vlen, n_frms, sampling)

    # get_batch -> (T, H, W, C)
    batch = vr.get_batch(indices)
    # With the torch bridge this is already a tensor; otherwise assume numpy.
    frames = batch if isinstance(batch, torch.Tensor) else torch.from_numpy(batch)
    frms = frames.permute(3, 0, 1, 2).float()  # (C, T, H, W)
    if not return_msg:
        return frms

    fps = float(vr.get_avg_fps())
    sec = ", ".join(str(round(f / fps, 1)) for f in indices)
    # The message keeps its trailing space; the original note said a space
    # should be added at the start and end, presumably so callers can
    # concatenate it into a prompt.
    msg = f"The video contains {len(indices)} frames sampled at {sec} seconds. "
    return frms, msg
class AlproVideoBaseProcessor(BaseProcessor):
    """Common setup shared by the ALPRO video processors.

    Builds the per-channel normalization transform and stores the frame budget.
    Subclasses are responsible for assigning ``self.transform``;
    ``BaseProcessor.__init__`` is not invoked here.

    Args:
        mean: per-channel normalization mean; ``None`` selects the built-in
            default (these values appear to be the OpenAI CLIP image statistics;
            verify).
        std: per-channel normalization std; ``None`` selects the built-in default.
        n_frms (int): maximum number of frames to sample per video.
    """

    def __init__(self, mean=None, std=None, n_frms=MAX_INT):
        channel_mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        channel_std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms_video.NormalizeVideo(channel_mean, channel_std)
        self.n_frms = n_frms
class ToUint8:
    """Callable transform that casts its input tensor to ``torch.uint8``.

    No explicit rounding or clamping is applied beyond what ``Tensor.to`` does.
    """

    def __call__(self, tensor):
        return tensor.to(dtype=torch.uint8)

    def __repr__(self):
        return type(self).__name__
class ToTHWC:
    """Reorder a clip's axes from (C, T, H, W) to (T, H, W, C).

    This is a pure ``permute``: the dtype is left unchanged (the input may be
    float or uint8) and the result is a view of the input, not a copy.

    Args:
        clip (torch.tensor): 4-D tensor of size (C, T, H, W).
    Return:
        clip (torch.tensor): same dtype, size (T, H, W, C).
    """

    def __call__(self, tensor):
        return tensor.permute(1, 2, 3, 0)

    def __repr__(self):
        return self.__class__.__name__
class ResizeVideo:
    """Callable transform that resizes every frame of a video clip.

    Args:
        target_size: output spatial size, forwarded unchanged to
            ``functional_video.resize`` (presumably an (H, W) pair; confirm
            against that helper).
        interpolation_mode (str): interpolation mode forwarded to the same
            helper; defaults to "bilinear".
    """

    def __init__(self, target_size, interpolation_mode="bilinear"):
        self.target_size = target_size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Resize ``clip`` to ``self.target_size``.

        Args:
            clip (torch.tensor): video clip to be resized. Size is (C, T, H, W).
        Returns:
            torch.tensor: the resized clip as returned by ``F.resize``
            (presumably of size (C, T, *target_size); confirm).
        """
        return F.resize(clip, self.target_size, self.interpolation_mode)

    def __repr__(self):
        return self.__class__.__name__ + "(resize_size={0})".format(self.target_size)
@registry.register_processor("alpro_video_train")
class AlproVideoTrainProcessor(AlproVideoBaseProcessor):
    """Training-time video processor.

    Loads a video with "headtail" frame sampling (random frames, half from each
    half of the video), requesting decode at ``image_size`` x ``image_size``. It
    then applies a random resized crop, converts to a float (C, T, H, W) clip and
    normalizes it.

    Args:
        image_size (int): decode size requested from the loader and crop size.
        mean: per-channel normalization mean; ``None`` uses the base-class default.
        std: per-channel normalization std; ``None`` uses the base-class default.
        min_scale (float): lower bound of the random-crop scale range.
        max_scale (float): upper bound of the random-crop scale range.
        n_frms (int): maximum number of frames sampled per video.
    """

    # NOTE(review): the constructor defaults image_size to 384, while from_config
    # defaults it to 256. Kept as-is for backward compatibility; confirm which
    # value is intended.
    def __init__(
        self,
        image_size=384,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        n_frms=MAX_INT,
    ):
        super().__init__(mean=mean, std=std, n_frms=n_frms)
        self.image_size = image_size
        self.transform = transforms.Compose(
            [
                # Video size is (C, T, H, W), float, as produced by load_video
                transforms_video.RandomResizedCropVideo(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation_mode="bicubic",
                ),
                ToTHWC(),  # C, T, H, W -> T, H, W, C
                ToUint8(),
                transforms_video.ToTensorVideo(),  # T, H, W, C -> C, T, H, W
                self.normalize,
            ]
        )

    def __call__(self, vpath):
        """Load the video at ``vpath`` and apply the training transforms.

        Args:
            vpath (str): path to the video file, passed to ``load_video``.
        Returns:
            torch.tensor: video clip after transforms. Size is (C, T, size, size).
        """
        clip = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="headtail",
        )
        return self.transform(clip)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf config.

        Missing keys fall back to: image_size=256, mean/std=None, min_scale=0.5,
        max_scale=1.0, n_frms=MAX_INT. A ``None`` cfg is treated as empty.
        """
        if cfg is None:
            cfg = OmegaConf.create()
        image_size = cfg.get("image_size", 256)
        mean = cfg.get("mean", None)
        std = cfg.get("std", None)
        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)
        n_frms = cfg.get("n_frms", MAX_INT)
        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
            n_frms=n_frms,
        )
@registry.register_processor("alpro_video_eval")
class AlproVideoEvalProcessor(AlproVideoBaseProcessor):
    """Evaluation-time video processor (no random augmentation).

    Loads a video with "uniform" frame sampling, requesting decode at
    ``image_size`` x ``image_size``. It then converts to a float (C, T, H, W)
    clip and normalizes it.

    Args:
        image_size (int): decode size requested from the loader.
        mean: per-channel normalization mean; ``None`` uses the base-class default.
        std: per-channel normalization std; ``None`` uses the base-class default.
        n_frms (int): maximum number of frames sampled per video.
    """

    def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT):
        super().__init__(mean=mean, std=std, n_frms=n_frms)
        self.image_size = image_size
        # Input video size is (C, T, H, W)
        self.transform = transforms.Compose(
            [
                # frames will be resized during decord loading.
                ToUint8(),  # C, T, H, W
                ToTHWC(),  # T, H, W, C
                transforms_video.ToTensorVideo(),  # C, T, H, W
                self.normalize,  # C, T, H, W
            ]
        )

    def __call__(self, vpath):
        """Load the video at ``vpath`` and apply the evaluation transforms.

        Args:
            vpath (str): path to the video file, passed to ``load_video``.
        Returns:
            torch.tensor: video clip after transforms. Size is (C, T, size, size).
        """
        clip = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="uniform",
        )
        return self.transform(clip)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf config.

        Missing keys fall back to: image_size=256, mean/std=None, n_frms=MAX_INT.
        A ``None`` cfg is treated as empty.
        """
        if cfg is None:
            cfg = OmegaConf.create()
        image_size = cfg.get("image_size", 256)
        mean = cfg.get("mean", None)
        std = cfg.get("std", None)
        n_frms = cfg.get("n_frms", MAX_INT)
        return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)