import torch
import torch.nn as nn
from sat.model import ViTModel, BaseModel
from sat.model import BaseMixin
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class LNFinalyMixin(BaseMixin):
    """Replaces the ViT classification head with a final LayerNorm over the visual features."""

    def __init__(self, hidden_size):
        super().__init__()
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        return self.ln_vision(logits)


class EVAViT(ViTModel):
    """EVA ViT image encoder: swaps the default "cls" mixin for the LayerNorm mixin above."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        self.del_mixin("cls")
        self.add_mixin("cls", LNFinalyMixin(args.hidden_size))

    def forward(self, image):
        batch_size = image.size(0)
        # Only the image is used; dummy token ids and a trivial attention mask
        # satisfy the base forward signature.
        input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
        attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
        return super().forward(input_ids=input_ids, position_ids=None,
                               attention_mask=attention_mask, image=image)


class QFormer(BaseModel):
    """Query Transformer: 32 learned query tokens cross-attend to the ViT features."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output,
                         activation_func=nn.functional.gelu, **kwargs)
        # The query tokens carry no positional information, so the embedding table is dropped.
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        return None

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
        # Token ids 0..31 index the 32 learned query embeddings.
        input_ids = torch.arange(32, dtype=torch.long,
                                 device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1)
        attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        return super().forward(input_ids=input_ids, position_ids=None,
                               attention_mask=attention_mask,
                               encoder_outputs=encoder_outputs,
                               cross_attention_mask=cross_attention_mask)


class BLIP2(torch.nn.Module):
    """BLIP-2 visual branch: EVA ViT -> Q-Former -> linear projection into the LM hidden space."""

    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        super().__init__()
        if vit is not None:
            self.vit = vit
        else:
            self.vit = EVAViT(EVAViT.get_args(**eva_args))
        if qformer is not None:
            self.qformer = qformer
        else:
            self.qformer = QFormer(QFormer.get_args(**qformer_args))
        # Project 768-dim Q-Former outputs to the 4096-dim GLM hidden size,
        # matching the Q-Former's device and dtype.
        qformer_param = next(self.qformer.parameters())
        self.glm_proj = nn.Linear(768, 4096).to(qformer_param.device).to(qformer_param.dtype)

    def forward(self, image, **kwargs):
        enc = self.vit(image)[0]       # visual features from the ViT
        out = self.qformer(enc)[0]     # 32 query tokens, hidden size 768
        return self.glm_proj(out)      # projected to the language model's hidden size


class BlipImageBaseProcessor():
    def __init__(self, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = transforms.Normalize(mean, std)


class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation-time preprocessing: bicubic resize, tensor conversion, normalization."""

    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)
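
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how the
# pieces above might be wired together. `eva_args` and `qformer_args` are
# placeholder config dicts here; in practice they come from the pretrained
# checkpoint's configuration, and `image_size` must match the resolution the
# ViT expects.
#
#   from PIL import Image
#
#   processor = BlipImageEvalProcessor(image_size=224)
#   image = processor(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)
#
#   model = BLIP2(eva_args, qformer_args)   # or pass preloaded `vit=` / `qformer=` modules
#   with torch.no_grad():
#       prefix = model(image)               # (1, 32, 4096): query tokens projected for the LM
# ---------------------------------------------------------------------------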