""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import re import torch from lavis.processors import transforms_video from lavis.common.registry import registry from lavis.processors.base_processor import BaseProcessor from lavis.datasets.data_utils import load_video from lavis.processors.randaugment import RandomAugment from omegaconf import OmegaConf from torchvision import transforms from torchvision.transforms.functional import InterpolationMode MAX_INT = registry.get("MAX_INT") class ToUint8(object): def __init__(self): pass def __call__(self, tensor): return tensor.to(torch.uint8) def __repr__(self): return self.__class__.__name__ class ToTHWC(object): """ Args: clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) Return: clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) """ def __init__(self): pass def __call__(self, tensor): return tensor.permute(1, 2, 3, 0) def __repr__(self): return self.__class__.__name__ class BlipImageBaseProcessor(BaseProcessor): def __init__(self, mean=None, std=None): if mean is None: mean = (0.48145466, 0.4578275, 0.40821073) if std is None: std = (0.26862954, 0.26130258, 0.27577711) self.normalize = transforms.Normalize(mean, std) class BlipVideoBaseProcessor(BaseProcessor): def __init__(self, mean=None, std=None, n_frms=MAX_INT): if mean is None: mean = (0.48145466, 0.4578275, 0.40821073) if std is None: std = (0.26862954, 0.26130258, 0.27577711) self.normalize = transforms_video.NormalizeVideo(mean, std) self.n_frms = n_frms @registry.register_processor("blip_caption") class BlipCaptionProcessor(BaseProcessor): def __init__(self, prompt="", max_words=50): self.prompt = prompt self.max_words = max_words def __call__(self, caption): caption = self.prompt + self.pre_caption(caption) return caption @classmethod def from_config(cls, cfg=None): if cfg is None: cfg = OmegaConf.create() prompt = cfg.get("prompt", "") max_words = cfg.get("max_words", 50) return cls(prompt=prompt, max_words=max_words) def pre_caption(self, caption): caption = re.sub( r"([.!\"()*#:;~])", " ", caption.lower(), ) caption = re.sub( r"\s{2,}", " ", caption, ) caption = caption.rstrip("\n") caption = caption.strip(" ") # truncate caption caption_words = caption.split(" ") if len(caption_words) > self.max_words: caption = " ".join(caption_words[: self.max_words]) return caption @registry.register_processor("blip_question") class BlipQuestionProcessor(BaseProcessor): def __init__(self, max_words=50): self.max_words = max_words def __call__(self, question): return self.pre_question(question) @classmethod def from_config(cls, cfg=None): if cfg is None: cfg = OmegaConf.create() max_words = cfg.get("max_words", 50) return cls(max_words=max_words) def pre_question(self, question): question = re.sub( r"([.!\"()*#:;~])", "", question.lower(), ) question = question.rstrip(" ") # truncate question question_words = question.split(" ") if len(question_words) > self.max_words: question = " ".join(question_words[: self.max_words]) return question @registry.register_processor("blip_image_train") class BlipImageTrainProcessor(BlipImageBaseProcessor): def __init__( self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0 ): super().__init__(mean=mean, std=std) self.transform = transforms.Compose( [ transforms.RandomResizedCrop( image_size, scale=(min_scale, max_scale), 
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                # apply 2 randomly chosen augmentation ops at magnitude 5
                RandomAugment(
                    2,
                    5,
                    isPIL=True,
                    augs=[
                        "Identity",
                        "AutoContrast",
                        "Brightness",
                        "Sharpness",
                        "Equalize",
                        "ShearX",
                        "ShearY",
                        "TranslateX",
                        "TranslateY",
                        "Rotate",
                    ],
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 384)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
        )


@registry.register_processor("blip_image_eval")
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 384)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(image_size=image_size, mean=mean, std=std)


@registry.register_processor("blip2_image_train")
class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    def __init__(
        self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 364)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
        )


@registry.register_processor("blip2_video_train")
class Blip2VideoTrainProcessor(BlipVideoBaseProcessor):
    def __init__(
        self,
        image_size=384,
        mean=None,
        std=None,
        min_scale=0.5,
        max_scale=1.0,
        n_frms=MAX_INT,
    ):
        super().__init__(mean=mean, std=std, n_frms=n_frms)

        self.image_size = image_size

        self.transform = transforms.Compose(
            [
                # video clip arrives as (C, T, H, W)
                transforms_video.RandomResizedCropVideo(
                    image_size,
                    scale=(min_scale, max_scale),
                    interpolation_mode="bicubic",
                ),
                ToTHWC(),  # (C, T, H, W) -> (T, H, W, C)
                ToUint8(),
                transforms_video.ToTensorVideo(),  # (T, H, W, C) -> (C, T, H, W), float in [0, 1]
                self.normalize,
            ]
        )

    def __call__(self, vpath, clip_proposal=None):
        clip, indices, fps = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="random",
            clip_proposal=clip_proposal,
        )

        return self.transform(clip), indices, fps

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 364)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        min_scale = cfg.get("min_scale", 0.5)
        max_scale = cfg.get("max_scale", 1.0)

        n_frms = cfg.get("n_frms", MAX_INT)

        return cls(
            image_size=image_size,
            mean=mean,
            std=std,
            min_scale=min_scale,
            max_scale=max_scale,
            n_frms=n_frms,
        )


@registry.register_processor("blip_video_eval")
class BlipVideoEvalProcessor(BlipVideoBaseProcessor):
    def __init__(self, image_size=384, mean=None, std=None, n_frms=MAX_INT):
        super().__init__(mean=mean, std=std, n_frms=n_frms)

        self.image_size = image_size

        self.transform = transforms.Compose(
            [
                ToUint8(),  # (C, T, H, W)
                ToTHWC(),  # (T, H, W, C)
                transforms_video.ToTensorVideo(),  # (C, T, H, W), float in [0, 1]
                self.normalize,  # (C, T, H, W)
            ]
        )

    def __call__(self, vpath, clip_proposal=None):
        clip, indices, fps = load_video(
            video_path=vpath,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="uniform",
            clip_proposal=clip_proposal,
        )

        return self.transform(clip), indices, fps

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        image_size = cfg.get("image_size", 256)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        n_frms = cfg.get("n_frms", MAX_INT)

        return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms)
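

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): builds the
# caption and image-eval processors from OmegaConf configs and runs them
# on toy inputs. The config keys mirror the from_config defaults above;
# the sample caption and the dummy image are made up for demonstration.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    # caption preprocessing: lowercases, strips punctuation, prepends prompt
    cap_proc = BlipCaptionProcessor.from_config(
        OmegaConf.create({"prompt": "a photo of ", "max_words": 30})
    )
    print(cap_proc("A dog!! Running; fast."))
    # -> "a photo of a dog running fast"

    # image eval preprocessing: bicubic resize, ToTensor, CLIP-style normalize
    img_proc = BlipImageEvalProcessor.from_config(
        OmegaConf.create({"image_size": 224})
    )
    dummy = Image.new("RGB", (640, 480))
    print(img_proc(dummy).shape)
    # -> torch.Size([3, 224, 224])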