|
"""
|
|
Copyright (c) 2022, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""
|
|
|
|
import re
|
|
|
|
from minigpt4.common.registry import registry
|
|
from minigpt4.processors.base_processor import BaseProcessor
|
|
from minigpt4.processors.randaugment import RandomAugment
|
|
from omegaconf import OmegaConf
|
|
from torchvision import transforms
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
|
|
class BlipImageBaseProcessor(BaseProcessor):
    """Common base for the image processors; owns the normalization step.

    Subclasses place ``self.normalize`` at the end of their transform
    pipeline.
    """

    def __init__(self, mean=None, std=None, do_normalize=True):
        """Create ``self.normalize``.

        Args:
            mean: Per-channel mean handed to ``transforms.Normalize``. When
                ``None``, ``(0.48145466, 0.4578275, 0.40821073)`` is used.
            std: Per-channel standard deviation handed to
                ``transforms.Normalize``. When ``None``,
                ``(0.26862954, 0.26130258, 0.27577711)`` is used.
            do_normalize: If true, ``self.normalize`` is a
                ``transforms.Normalize``; otherwise it is an identity
                ``transforms.Lambda`` that returns its input unchanged.
        """
        channel_mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        channel_std = (0.26862954, 0.26130258, 0.27577711) if std is None else std

        # Only one of the two transforms is ever constructed.
        self.normalize = (
            transforms.Normalize(channel_mean, channel_std)
            if do_normalize
            else transforms.Lambda(lambda img: img)
        )
|
|
|
|
|
|
@registry.register_processor("blip_caption")
class BlipCaptionProcessor(BaseProcessor):
    """Caption text processor registered under the name ``blip_caption``.

    Cleans a caption with :meth:`pre_caption` and prepends a fixed prompt.
    """

    def __init__(self, prompt="", max_words=50):
        """Store the prompt prefix and the word limit.

        Args:
            prompt: String prepended to every processed caption.
            max_words: Maximum number of space-separated words kept by
                :meth:`pre_caption`.
        """
        self.max_words = max_words
        self.prompt = prompt

    def __call__(self, caption):
        """Return ``self.prompt`` followed by the cleaned ``caption``."""
        return self.prompt + self.pre_caption(caption)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Object exposing ``get(key, default)``. When ``None``, an
                empty ``OmegaConf`` config is used, so every option falls
                back to its default (``prompt=""``, ``max_words=50``).

        Returns:
            A new instance of ``cls``.
        """
        cfg = OmegaConf.create() if cfg is None else cfg
        return cls(
            prompt=cfg.get("prompt", ""),
            max_words=cfg.get("max_words", 50),
        )

    def pre_caption(self, caption):
        """Lower-case, strip punctuation from, and truncate ``caption``.

        Args:
            caption: Raw caption string.

        Returns:
            The cleaned caption, limited to ``self.max_words`` words.
        """
        lowered = caption.lower()
        # Each of these punctuation characters becomes a single space.
        without_punct = re.sub(r"([.!\"()*#:;~])", " ", lowered)
        # Runs of two or more whitespace characters collapse to one space.
        collapsed = re.sub(r"\s{2,}", " ", without_punct)
        cleaned = collapsed.rstrip("\n").strip(" ")

        words = cleaned.split(" ")
        if len(words) <= self.max_words:
            return cleaned
        # Too long: keep only the leading ``max_words`` words.
        return " ".join(words[: self.max_words])
|
|
|
|
|
|
@registry.register_processor("blip_image_train")
class BlipImageTrainProcessor(BlipImageBaseProcessor):
    """Training-time image processor registered as ``blip_image_train``.

    The pipeline is: random resized crop (bicubic), random horizontal flip,
    ``RandomAugment``, tensor conversion, then the base class's
    ``self.normalize``.
    """

    def __init__(
        self,
        image_size=384,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        do_normalize=True,
    ):
        """Assemble the training transform.

        Args:
            image_size: Output size passed to ``RandomResizedCrop``.
            mean: Forwarded to ``BlipImageBaseProcessor``.
            std: Forwarded to ``BlipImageBaseProcessor``.
            min_scale: Lower bound of the ``scale`` range for the crop.
            max_scale: Upper bound of the ``scale`` range for the crop.
            do_normalize: Forwarded to ``BlipImageBaseProcessor``.
        """
        super().__init__(mean=mean, std=std, do_normalize=do_normalize)

        crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        flip = transforms.RandomHorizontalFlip()
        # The positional 2 and 5 are interpreted by RandomAugment itself;
        # presumably op count and magnitude -- confirm against its definition.
        augment = RandomAugment(
            2,
            5,
            isPIL=True,
            augs=[
                "Identity",
                "AutoContrast",
                "Brightness",
                "Sharpness",
                "Equalize",
                "ShearX",
                "ShearY",
                "TranslateX",
                "TranslateY",
                "Rotate",
            ],
        )
        to_tensor = transforms.ToTensor()

        self.transform = transforms.Compose(
            [crop, flip, augment, to_tensor, self.normalize]
        )

    def __call__(self, item):
        """Run ``item`` through ``self.transform`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Object exposing ``get(key, default)``. When ``None``, an
                empty ``OmegaConf`` config is used, so all defaults apply.

        Returns:
            A new instance of ``cls``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        # (key, fallback) pairs, read from cfg in this order.
        option_defaults = (
            ("image_size", 384),
            ("mean", None),
            ("std", None),
            ("min_scale", 0.5),
            ("max_scale", 1.0),
            ("do_normalize", True),
        )
        kwargs = {key: cfg.get(key, fallback) for key, fallback in option_defaults}
        return cls(**kwargs)
|
|
|
|
|
|
@registry.register_processor("blip_image_eval")
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation-time image processor registered as ``blip_image_eval``.

    The pipeline is: bicubic resize to a square, tensor conversion, then the
    base class's ``self.normalize``.
    """

    def __init__(self, image_size=384, mean=None, std=None, do_normalize=True):
        """Assemble the evaluation transform.

        Args:
            image_size: Side length; images are resized to
                ``(image_size, image_size)``.
            mean: Forwarded to ``BlipImageBaseProcessor``.
            std: Forwarded to ``BlipImageBaseProcessor``.
            do_normalize: Forwarded to ``BlipImageBaseProcessor``.
        """
        super().__init__(mean=mean, std=std, do_normalize=do_normalize)

        resize = transforms.Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        to_tensor = transforms.ToTensor()
        self.transform = transforms.Compose([resize, to_tensor, self.normalize])

    def __call__(self, item):
        """Run ``item`` through ``self.transform`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Object exposing ``get(key, default)``. When ``None``, an
                empty ``OmegaConf`` config is used, so all defaults apply.

        Returns:
            A new instance of ``cls``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        # (key, fallback) pairs, read from cfg in this order.
        option_defaults = (
            ("image_size", 384),
            ("mean", None),
            ("std", None),
            ("do_normalize", True),
        )
        kwargs = {key: cfg.get(key, fallback) for key, fallback in option_defaults}
        return cls(**kwargs)
|
|
|
|
|
|
@registry.register_processor("blip2_image_train")
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    """Training-time image processor registered as ``blip2_image_train``.

    The pipeline is: random resized crop (bicubic), tensor conversion, then
    the base class's ``self.normalize``.
    """

    def __init__(
        self,
        image_size=224,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        do_normalize=True,
    ):
        """Assemble the training transform.

        Args:
            image_size: Output size passed to ``RandomResizedCrop``.
            mean: Forwarded to ``BlipImageBaseProcessor``.
            std: Forwarded to ``BlipImageBaseProcessor``.
            min_scale: Lower bound of the ``scale`` range for the crop.
            max_scale: Upper bound of the ``scale`` range for the crop.
            do_normalize: Forwarded to ``BlipImageBaseProcessor``.
        """
        super().__init__(mean=mean, std=std, do_normalize=do_normalize)

        crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        to_tensor = transforms.ToTensor()
        self.transform = transforms.Compose([crop, to_tensor, self.normalize])

    def __call__(self, item):
        """Run ``item`` through ``self.transform`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Object exposing ``get(key, default)``. When ``None``, an
                empty ``OmegaConf`` config is used, so all defaults apply.

        Returns:
            A new instance of ``cls``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        # (key, fallback) pairs, read from cfg in this order.
        option_defaults = (
            ("image_size", 224),
            ("mean", None),
            ("std", None),
            ("min_scale", 0.5),
            ("max_scale", 1.0),
            ("do_normalize", True),
        )
        kwargs = {key: cfg.get(key, fallback) for key, fallback in option_defaults}
        return cls(**kwargs)
|
|
|
|
|
|
@registry.register_processor("blip2_image_eval")
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation-time image processor registered as ``blip2_image_eval``.

    The pipeline is: bicubic resize to a square, tensor conversion, then the
    base class's ``self.normalize``.
    """

    def __init__(self, image_size=224, mean=None, std=None, do_normalize=True):
        """Assemble the evaluation transform.

        Args:
            image_size: Side length; images are resized to
                ``(image_size, image_size)``.
            mean: Forwarded to ``BlipImageBaseProcessor``.
            std: Forwarded to ``BlipImageBaseProcessor``.
            do_normalize: Forwarded to ``BlipImageBaseProcessor``.
        """
        super().__init__(mean=mean, std=std, do_normalize=do_normalize)

        resize = transforms.Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        to_tensor = transforms.ToTensor()
        self.transform = transforms.Compose([resize, to_tensor, self.normalize])

    def __call__(self, item):
        """Run ``item`` through ``self.transform`` and return the result."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config object.

        Args:
            cfg: Object exposing ``get(key, default)``. When ``None``, an
                empty ``OmegaConf`` config is used, so all defaults apply.

        Returns:
            A new instance of ``cls``.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        # (key, fallback) pairs, read from cfg in this order.
        option_defaults = (
            ("image_size", 224),
            ("mean", None),
            ("std", None),
            ("do_normalize", True),
        )
        kwargs = {key: cfg.get(key, fallback) for key, fallback in option_defaults}
        return cls(**kwargs)