# SeViLA/lavis/processors/blip_processors.py
# (author: shoubin, commit 7e8784c "upload_demo", 10.9 kB)
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import re
import torch
from lavis.processors import transforms_video
from lavis.common.registry import registry
from lavis.processors.base_processor import BaseProcessor
from lavis.datasets.data_utils import load_video
from lavis.processors.randaugment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
# Default for the ``n_frms`` parameters of the video processors below.
# Fetched from the LAVIS registry once, when this module is imported.
# NOTE(review): presumably a large sentinel meaning "do not cap the number of
# frames"; confirm where "MAX_INT" is registered and how load_video treats it.
MAX_INT = registry.get("MAX_INT")
class ToUint8(object):
    """Cast a tensor to ``torch.uint8``, saturating out-of-range values.

    Floating-point inputs are clamped to [0, 255] before the cast. Resampling
    steps such as bicubic interpolation can produce values slightly below 0 or
    above 255. A plain float -> uint8 cast of such values is not saturating:
    the result is implementation-defined and commonly wraps around, e.g. a
    slight overshoot past 255 turns a bright pixel black.

    In-range values are cast exactly as before, with the fractional part
    truncated toward zero. Non-floating inputs are cast without clamping.
    """

    def __init__(self):
        pass

    def __call__(self, tensor):
        if tensor.is_floating_point():
            # Saturate instead of letting the dtype conversion wrap around.
            tensor = tensor.clamp(0, 255)
        return tensor.to(torch.uint8)

    def __repr__(self):
        return self.__class__.__name__
class ToTHWC(object):
    """Permute a video tensor from (C, T, H, W) to (T, H, W, C).

    Only the dimension order changes. The dtype and values are preserved:
    a uint8 input stays uint8 and a float input stays float. The input does
    not have to be uint8. The result is a permuted view of the input
    (``Tensor.permute``), not a copy.

    Args:
        clip (torch.Tensor): 4-D tensor laid out as (C, T, H, W).
    Return:
        clip (torch.Tensor): the same data laid out as (T, H, W, C).
    """

    def __init__(self):
        pass

    def __call__(self, tensor):
        # Move the channel axis (dim 0) to the end: (C, T, H, W) -> (T, H, W, C).
        return tensor.permute(1, 2, 3, 0)

    def __repr__(self):
        return self.__class__.__name__
class BlipImageBaseProcessor(BaseProcessor):
    """Common base for the BLIP image processors.

    Builds ``self.normalize``, a ``transforms.Normalize`` step that subclasses
    append to their transform pipelines.

    Args:
        mean: per-channel mean; ``None`` selects the built-in default triple.
        std: per-channel standard deviation; ``None`` selects the built-in
            default triple.
    """

    def __init__(self, mean=None, std=None):
        channel_mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        channel_std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(channel_mean, channel_std)
class BlipVideoBaseProcessor(BaseProcessor):
    """Common base for the BLIP video processors.

    Builds ``self.normalize`` (a ``transforms_video.NormalizeVideo`` step) and
    stores ``self.n_frms``, which subclasses pass to ``load_video``.

    Args:
        mean: per-channel mean; ``None`` selects the built-in default triple.
        std: per-channel standard deviation; ``None`` selects the built-in
            default triple.
        n_frms: frame count stored on the instance (default ``MAX_INT``).
    """

    def __init__(self, mean=None, std=None, n_frms=MAX_INT):
        channel_mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        channel_std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms_video.NormalizeVideo(channel_mean, channel_std)
        self.n_frms = n_frms
@registry.register_processor("blip_caption")
class BlipCaptionProcessor(BaseProcessor):
    """Normalize a caption string and prepend an optional prompt.

    Args:
        prompt: text prepended verbatim to every processed caption.
        max_words: the normalized caption is cut to at most this many
            space-separated words.
    """

    def __init__(self, prompt="", max_words=50):
        self.prompt = prompt
        self.max_words = max_words

    def __call__(self, caption):
        cleaned = self.pre_caption(caption)
        return self.prompt + cleaned

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config (keys ``prompt`` and ``max_words``)."""
        config = OmegaConf.create() if cfg is None else cfg
        return cls(
            prompt=config.get("prompt", ""),
            max_words=config.get("max_words", 50),
        )

    def pre_caption(self, caption):
        """Lower-case, de-punctuate, collapse whitespace, trim and truncate.

        Each of the characters . ! " ( ) * # : ; ~ is replaced by a space
        (not deleted), so words joined only by punctuation stay separated.
        """
        text = caption.lower()
        text = re.sub(r"([.!\"()*#:;~])", " ", text)
        # Any run of two or more whitespace characters becomes one space.
        text = re.sub(r"\s{2,}", " ", text)
        text = text.rstrip("\n").strip(" ")
        words = text.split(" ")
        if len(words) <= self.max_words:
            return text
        return " ".join(words[: self.max_words])
@registry.register_processor("blip_question")
class BlipQuestionProcessor(BaseProcessor):
    """Normalize a question string.

    Args:
        max_words: the normalized question is cut to at most this many
            space-separated tokens.
    """

    def __init__(self, max_words=50):
        self.max_words = max_words

    def __call__(self, question):
        return self.pre_question(question)

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config (key ``max_words``)."""
        config = OmegaConf.create() if cfg is None else cfg
        return cls(max_words=config.get("max_words", 50))

    def pre_question(self, question):
        """Lower-case, delete punctuation, strip trailing spaces and truncate.

        The characters . ! " ( ) * # : ; ~ are deleted outright; "?" is kept.
        Internal runs of spaces are NOT collapsed, so repeated spaces yield
        empty tokens that count toward ``max_words``.
        """
        text = re.sub(r"([.!\"()*#:;~])", "", question.lower()).rstrip(" ")
        tokens = text.split(" ")
        if len(tokens) > self.max_words:
            text = " ".join(tokens[: self.max_words])
        return text
@registry.register_processor("blip_image_train")
class BlipImageTrainProcessor(BlipImageBaseProcessor):
    """Training-time augmentation pipeline for BLIP images.

    The pipeline runs these steps in order:
    1. Random resized crop to ``image_size`` (bicubic), with scale drawn
       from ``(min_scale, max_scale)``.
    2. Random horizontal flip.
    3. ``RandomAugment(2, 5, isPIL=True, ...)`` over the listed operations.
    4. ``ToTensor``.
    5. The base-class normalization.
    """

    def __init__(
        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)
        crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        allowed_ops = [
            "Identity",
            "AutoContrast",
            "Brightness",
            "Sharpness",
            "Equalize",
            "ShearX",
            "ShearY",
            "TranslateX",
            "TranslateY",
            "Rotate",
        ]
        steps = [
            crop,
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 5, isPIL=True, augs=allowed_ops),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        pipeline = self.transform
        return pipeline(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config; missing keys fall back to defaults."""
        config = OmegaConf.create() if cfg is None else cfg
        defaults = {
            "image_size": 384,
            "mean": None,
            "std": None,
            "min_scale": 0.5,
            "max_scale": 1.0,
        }
        kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
        return cls(**kwargs)
@registry.register_processor("blip_image_eval")
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Deterministic evaluation pipeline for BLIP images.

    The image is resized to exactly ``image_size`` x ``image_size`` with
    bicubic interpolation, so aspect ratio is not preserved. It is then
    converted with ``ToTensor`` and normalized by the base class.
    """

    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)
        square = (image_size, image_size)
        steps = [
            transforms.Resize(square, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        pipeline = self.transform
        return pipeline(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config; missing keys fall back to defaults."""
        config = OmegaConf.create() if cfg is None else cfg
        defaults = {"image_size": 384, "mean": None, "std": None}
        kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
        return cls(**kwargs)
@registry.register_processor("blip2_image_train")
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    """Training-time augmentation pipeline for BLIP-2 images.

    The pipeline runs these steps in order:
    1. Random resized crop to ``image_size`` (bicubic), with scale drawn
       from ``(min_scale, max_scale)``.
    2. Random horizontal flip.
    3. ``ToTensor``.
    4. The base-class normalization.

    Unlike the BLIP training pipeline, no ``RandomAugment`` step is applied.
    """

    def __init__(
        self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)
        crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        steps = [
            crop,
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        pipeline = self.transform
        return pipeline(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config; missing keys fall back to defaults."""
        config = OmegaConf.create() if cfg is None else cfg
        defaults = {
            "image_size": 364,
            "mean": None,
            "std": None,
            "min_scale": 0.5,
            "max_scale": 1.0,
        }
        kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
        return cls(**kwargs)
@registry.register_processor("blip2_video_train")
class Blip2VideoTrainProcessor(BlipVideoBaseProcessor):
    """Load a randomly sampled video clip and apply training augmentation.

    Args:
        image_size: height/width passed to ``load_video`` and crop size of the
            random resized crop.
        mean, std: normalization statistics (``None`` -> base-class defaults).
        min_scale, max_scale: scale range passed to ``RandomResizedCropVideo``.
        n_frms: frame count passed to ``load_video``.
    """

    def __init__(
        self,
        image_size=384,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        n_frms=MAX_INT,
    ):
        super().__init__(mean=mean, std=std, n_frms=n_frms)
        self.image_size = image_size
        stages = [
            # Incoming clip layout is (C, T, H, W).
            transforms_video.RandomResizedCropVideo(
                image_size,
                scale=(min_scale, max_scale),
                interpolation_mode="bicubic",
            ),
            ToTHWC(),  # (C, T, H, W) -> (T, H, W, C)
            ToUint8(),
            transforms_video.ToTensorVideo(),  # (T, H, W, C) -> (C, T, H, W)
            self.normalize,
        ]
        self.transform = transforms.Compose(stages)

    def __call__(self, vpath, clip_proposal=None):
        """Load ``vpath`` with random frame sampling and transform the frames.

        Returns:
            ``(transformed_clip, indices, fps)``; ``indices`` and ``fps`` are
            passed through unchanged from ``load_video``.
        """
        frames, indices, fps = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="random",
            clip_proposal=clip_proposal,
        )
        transformed = self.transform(frames)
        return transformed, indices, fps

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config; missing keys fall back to defaults."""
        config = OmegaConf.create() if cfg is None else cfg
        # NOTE(review): the config default image_size (364) differs from the
        # constructor default (384). Kept as-is because changing either side
        # would alter existing behavior; confirm which one is intended.
        defaults = {
            "image_size": 364,
            "mean": None,
            "std": None,
            "min_scale": 0.5,
            "max_scale": 1.0,
            "n_frms": MAX_INT,
        }
        kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
        return cls(**kwargs)
@registry.register_processor("blip_video_eval")
class BlipVideoEvalProcessor(BlipVideoBaseProcessor):
    """Load a uniformly sampled video clip and apply deterministic preprocessing.

    No spatial augmentation is applied. The frames returned by ``load_video``
    (requested at ``image_size`` x ``image_size``) are cast to uint8,
    permuted, converted with ``ToTensorVideo`` and normalized.
    """

    def __init__(self, image_size=384, mean=None, std=None, n_frms=MAX_INT):
        # The base class stores self.n_frms and builds self.normalize.
        super().__init__(mean=mean, std=std, n_frms=n_frms)
        self.image_size = image_size
        stages = [
            ToUint8(),  # layout unchanged: (C, T, H, W)
            ToTHWC(),  # (C, T, H, W) -> (T, H, W, C)
            transforms_video.ToTensorVideo(),  # (T, H, W, C) -> (C, T, H, W)
            self.normalize,  # (C, T, H, W)
        ]
        self.transform = transforms.Compose(stages)

    def __call__(self, vpath, clip_proposal=None):
        """Load ``vpath`` with uniform frame sampling and transform the frames.

        Returns:
            ``(transformed_clip, indices, fps)``; ``indices`` and ``fps`` are
            passed through unchanged from ``load_video``.
        """
        frames, indices, fps = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="uniform",
            clip_proposal=clip_proposal,
        )
        transformed = self.transform(frames)
        return transformed, indices, fps

    @classmethod
    def from_config(cls, cfg=None):
        """Build from an OmegaConf config; missing keys fall back to defaults."""
        config = OmegaConf.create() if cfg is None else cfg
        # NOTE(review): the config default image_size (256) differs from the
        # constructor default (384). Kept as-is to preserve behavior; confirm.
        defaults = {"image_size": 256, "mean": None, "std": None, "n_frms": MAX_INT}
        kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
        return cls(**kwargs)