Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import re | |
from minigpt4.common.registry import registry | |
from minigpt4.processors.base_processor import BaseProcessor | |
from minigpt4.processors.randaugment import RandomAugment | |
from omegaconf import OmegaConf | |
from torchvision import transforms | |
from torchvision.transforms.functional import InterpolationMode | |
class BlipImageBaseProcessor(BaseProcessor): | |
def __init__(self, mean=None, std=None): | |
if mean is None: | |
mean = (0.48145466, 0.4578275, 0.40821073) | |
if std is None: | |
std = (0.26862954, 0.26130258, 0.27577711) | |
self.normalize = transforms.Normalize(mean, std) | |
class BlipCaptionProcessor(BaseProcessor): | |
def __init__(self, prompt="", max_words=50): | |
self.prompt = prompt | |
self.max_words = max_words | |
def __call__(self, caption): | |
caption = self.prompt + self.pre_caption(caption) | |
return caption | |
def from_config(cls, cfg=None): | |
if cfg is None: | |
cfg = OmegaConf.create() | |
prompt = cfg.get("prompt", "") | |
max_words = cfg.get("max_words", 50) | |
return cls(prompt=prompt, max_words=max_words) | |
def pre_caption(self, caption): | |
caption = re.sub( | |
r"([.!\"()*#:;~])", | |
" ", | |
caption.lower(), | |
) | |
caption = re.sub( | |
r"\s{2,}", | |
" ", | |
caption, | |
) | |
caption = caption.rstrip("\n") | |
caption = caption.strip(" ") | |
# truncate caption | |
caption_words = caption.split(" ") | |
if len(caption_words) > self.max_words: | |
caption = " ".join(caption_words[: self.max_words]) | |
return caption | |
class Blip2ImageTrainProcessor(BlipImageBaseProcessor): | |
def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): | |
super().__init__(mean=mean, std=std) | |
self.transform = transforms.Compose( | |
[ | |
transforms.RandomResizedCrop( | |
image_size, | |
scale=(min_scale, max_scale), | |
interpolation=InterpolationMode.BICUBIC, | |
), | |
transforms.ToTensor(), | |
self.normalize, | |
] | |
) | |
def __call__(self, item): | |
return self.transform(item) | |
def from_config(cls, cfg=None): | |
if cfg is None: | |
cfg = OmegaConf.create() | |
image_size = cfg.get("image_size", 224) | |
mean = cfg.get("mean", None) | |
std = cfg.get("std", None) | |
min_scale = cfg.get("min_scale", 0.5) | |
max_scale = cfg.get("max_scale", 1.0) | |
return cls( | |
image_size=image_size, | |
mean=mean, | |
std=std, | |
min_scale=min_scale, | |
max_scale=max_scale, | |
) | |
class Blip2ImageEvalProcessor(BlipImageBaseProcessor): | |
def __init__(self, image_size=224, mean=None, std=None): | |
super().__init__(mean=mean, std=std) | |
self.transform = transforms.Compose( | |
[ | |
transforms.Resize( | |
(image_size, image_size), interpolation=InterpolationMode.BICUBIC | |
), | |
transforms.ToTensor(), | |
self.normalize, | |
] | |
) | |
def __call__(self, item): | |
return self.transform(item) | |
def from_config(cls, cfg=None): | |
if cfg is None: | |
cfg = OmegaConf.create() | |
image_size = cfg.get("image_size", 224) | |
mean = cfg.get("mean", None) | |
std = cfg.get("std", None) | |
return cls(image_size=image_size, mean=mean, std=std) |