# NOTE(review): removed non-Python scrape artifacts ("Spaces:" / "Runtime error")
# that preceded the module docstring and would break parsing.
""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import re | |
import torch | |
from lavis.processors import transforms_video | |
from lavis.common.registry import registry | |
from lavis.processors.base_processor import BaseProcessor | |
from lavis.datasets.data_utils import load_video | |
from lavis.processors.randaugment import RandomAugment | |
from omegaconf import OmegaConf | |
from torchvision import transforms | |
from torchvision.transforms.functional import InterpolationMode | |
# Sentinel meaning "no frame limit" for video processors; resolved from the
# global LAVIS registry at import time (see n_frms defaults below).
MAX_INT = registry.get("MAX_INT")
class ToUint8(object):
    """Transform that casts a tensor to ``torch.uint8``.

    The conversion truncates values; shape and device are untouched.
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        # Pure dtype cast — no rescaling is performed.
        return clip.to(torch.uint8)

    def __repr__(self):
        return self.__class__.__name__
class ToTHWC(object):
    """Reorder a video tensor from (C, T, H, W) to (T, H, W, C).

    Args:
        clip (torch.Tensor): size (C, T, H, W); dtype is preserved
            (``permute`` never changes dtype).

    Returns:
        torch.Tensor: a view of size (T, H, W, C).
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        return clip.permute(1, 2, 3, 0)

    def __repr__(self):
        return self.__class__.__name__
class BlipImageBaseProcessor(BaseProcessor):
    """Base class for BLIP image processors.

    Holds the channel normalization transform; when ``mean``/``std`` are not
    supplied, the CLIP statistics are used.
    """

    def __init__(self, mean=None, std=None):
        # Fall back to the CLIP normalization constants.
        mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        std = (0.26862954, 0.26130258, 0.27577711) if std is None else std

        self.normalize = transforms.Normalize(mean, std)
class BlipVideoBaseProcessor(BaseProcessor):
    """Base class for BLIP video processors.

    Holds the per-frame normalization transform (CLIP statistics by default)
    and the number of frames to sample (``n_frms``; MAX_INT means no limit).
    """

    def __init__(self, mean=None, std=None, n_frms=MAX_INT):
        # Fall back to the CLIP normalization constants.
        mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        std = (0.26862954, 0.26130258, 0.27577711) if std is None else std

        self.normalize = transforms_video.NormalizeVideo(mean, std)
        self.n_frms = n_frms
class BlipCaptionProcessor(BaseProcessor):
    """Text processor for captions.

    Lowercases, strips a fixed punctuation set, collapses whitespace,
    truncates to ``max_words`` words, and prepends an optional prompt.
    """

    def __init__(self, prompt="", max_words=50):
        self.prompt = prompt
        self.max_words = max_words

    def __call__(self, caption):
        caption = self.prompt + self.pre_caption(caption)

        return caption

    # FIX: was a plain method; without @classmethod, calling
    # BlipCaptionProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        prompt = cfg.get("prompt", "")
        max_words = cfg.get("max_words", 50)

        return cls(prompt=prompt, max_words=max_words)

    def pre_caption(self, caption):
        """Normalize a raw caption string and truncate it to ``max_words``."""
        # Replace punctuation with spaces so adjoining words stay separated.
        caption = re.sub(
            r"([.!\"()*#:;~])",
            " ",
            caption.lower(),
        )
        # Collapse the whitespace runs introduced above.
        caption = re.sub(
            r"\s{2,}",
            " ",
            caption,
        )
        caption = caption.rstrip("\n")
        caption = caption.strip(" ")

        # truncate caption
        caption_words = caption.split(" ")
        if len(caption_words) > self.max_words:
            caption = " ".join(caption_words[: self.max_words])

        return caption
class BlipQuestionProcessor(BaseProcessor):
    """Text processor for questions.

    Lowercases, removes a fixed punctuation set, and truncates to
    ``max_words`` words.
    """

    def __init__(self, max_words=50):
        self.max_words = max_words

    def __call__(self, question):
        return self.pre_question(question)

    # FIX: was a plain method; without @classmethod, calling
    # BlipQuestionProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        max_words = cfg.get("max_words", 50)

        return cls(max_words=max_words)

    def pre_question(self, question):
        """Normalize a raw question string and truncate it to ``max_words``."""
        # Unlike pre_caption, punctuation is deleted (not replaced by spaces).
        question = re.sub(
            r"([.!\"()*#:;~])",
            "",
            question.lower(),
        )
        question = question.rstrip(" ")

        # truncate question
        question_words = question.split(" ")
        if len(question_words) > self.max_words:
            question = " ".join(question_words[: self.max_words])

        return question
class BlipImageTrainProcessor(BlipImageBaseProcessor):
    """Training-time image processor for BLIP.

    Random resized crop + horizontal flip + RandAugment, then tensor
    conversion and CLIP normalization.
    """

    def __init__(
        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                # 2 ops per image, magnitude 5, from a geometry/photometric mix.
                RandomAugment(
                    2,
                    5,
                    isPIL=True,
                    augs=[
                        "Identity",
                        "AutoContrast",
                        "Brightness",
                        "Sharpness",
                        "Equalize",
                        "ShearX",
                        "ShearY",
                        "TranslateX",
                        "TranslateY",
                        "Rotate",
                    ],
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    # FIX: was a plain method; without @classmethod, calling
    # BlipImageTrainProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 384)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
        )
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation-time image processor for BLIP.

    Deterministic: bicubic resize to a square ``image_size``, tensor
    conversion, and CLIP normalization (no augmentation).
    """

    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    # FIX: was a plain method; without @classmethod, calling
    # BlipImageEvalProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 384)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(image_size=image_size, mean=mean, std=std)
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    """Training-time image processor for BLIP-2.

    Like BlipImageTrainProcessor but without RandAugment: random resized
    crop + horizontal flip, then tensor conversion and CLIP normalization.
    """

    def __init__(
        self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    # FIX: was a plain method; without @classmethod, calling
    # Blip2ImageTrainProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 364)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
        )
class Blip2VideoTrainProcessor(BlipVideoBaseProcessor):
    """Training-time video processor for BLIP-2.

    Loads ``n_frms`` randomly-sampled frames from a video path, applies a
    random resized crop, and normalizes. Returns (clip, frame_indices, fps).
    """

    def __init__(
        self,
        image_size=384,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        n_frms=MAX_INT,
    ):
        super().__init__(mean=mean, std=std, n_frms=n_frms)

        self.image_size = image_size

        self.transform = transforms.Compose(
            [
                # Video size is (C, T, H, W)
                transforms_video.RandomResizedCropVideo(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation_mode="bicubic",
                ),
                ToTHWC(),  # C, T, H, W -> T, H, W, C
                ToUint8(),
                transforms_video.ToTensorVideo(),  # T, H, W, C -> C, T, H, W
                self.normalize,
            ]
        )

    def __call__(self, vpath, clip_proposal=None):
        """Load and transform the video at ``vpath``.

        Returns:
            tuple: (transformed clip, sampled frame indices, video fps).
        """
        clip, indices, fps = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="random",
            clip_proposal=clip_proposal,
        )

        return self.transform(clip), indices, fps

    # FIX: was a plain method; without @classmethod, calling
    # Blip2VideoTrainProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        # NOTE(review): config default (364) differs from __init__'s default
        # (384) — kept as-is for backward compatibility; confirm intended.
        image_size = cfg.get("image_size", 364)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        n_frms = cfg.get("n_frms", MAX_INT)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
            n_frms=n_frms,
        )
class BlipVideoEvalProcessor(BlipVideoBaseProcessor):
    """Evaluation-time video processor for BLIP.

    Loads ``n_frms`` uniformly-sampled frames from a video path (no crop /
    flip) and normalizes. Returns (clip, frame_indices, fps).
    """

    def __init__(self, image_size=384, mean=None, std=None, n_frms=MAX_INT):
        super().__init__(mean=mean, std=std, n_frms=n_frms)

        self.image_size = image_size

        self.transform = transforms.Compose(
            [
                ToUint8(),  # C, T, H, W
                ToTHWC(),  # T, H, W, C
                transforms_video.ToTensorVideo(),  # C, T, H, W
                self.normalize,  # C, T, H, W
            ]
        )
        # FIX: dropped redundant `self.n_frms = n_frms` — the base-class
        # __init__ above already sets it.

    def __call__(self, vpath, clip_proposal=None):
        """Load and transform the video at ``vpath``.

        Returns:
            tuple: (transformed clip, sampled frame indices, video fps).
        """
        clip, indices, fps = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="uniform",
            clip_proposal=clip_proposal,
        )

        return self.transform(clip), indices, fps

    # FIX: was a plain method; without @classmethod, calling
    # BlipVideoEvalProcessor.from_config(cfg) binds cfg to `cls` and fails.
    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf node (missing keys -> defaults)."""
        if cfg is None:
            cfg = OmegaConf.create()

        # NOTE(review): config default (256) differs from __init__'s default
        # (384) — kept as-is for backward compatibility; confirm intended.
        image_size = cfg.get("image_size", 256)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        n_frms = cfg.get("n_frms", MAX_INT)

        return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)