import torch
import torch.nn as nn

from sat.model import ViTModel, BaseModel, BaseMixin
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class LNFinalyMixin(BaseMixin):
    """Replaces the default ViT classification head with a final LayerNorm over the output hidden states."""

    def __init__(self, hidden_size):
        super().__init__()
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        return self.ln_vision(logits)


class EVAViT(ViTModel):
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        # Swap the default "cls" head for a plain LayerNorm, so the model returns
        # normalized patch features rather than classification logits.
        self.del_mixin("cls")
        self.add_mixin("cls", LNFinalyMixin(args.hidden_size))

    def forward(self, image):
        batch_size = image.size(0)
        # The ViT only consumes the image; dummy token ids and a trivial attention
        # mask are passed to satisfy the transformer interface.
        input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
        attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)


class QFormer(BaseModel):
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output,
                         activation_func=nn.functional.gelu, **kwargs)
        # The Q-Former reads image features through cross-attention only,
        # so token position embeddings are disabled.
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        return None

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
        # 32 learned query tokens (ids 0..31), shared across the batch.
        input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1)
        attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask,
                               encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask)


class BLIP2(torch.nn.Module):
    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        super().__init__()
        if vit is not None:
            self.vit = vit
        else:
            self.vit = EVAViT(EVAViT.get_args(**eva_args))
        if qformer is not None:
            self.qformer = qformer
        else:
            self.qformer = QFormer(QFormer.get_args(**qformer_args))

        # Project Q-Former outputs (hidden size 768) to the language model's
        # hidden size (4096), matching the Q-Former's device and dtype.
        self.glm_proj = nn.Linear(768, 4096).to(
            next(self.qformer.parameters()).device).to(next(self.qformer.parameters()).dtype)

    def forward(self, image, **kwargs):
        enc = self.vit(image)[0]
        out = self.qformer(enc)[0]
        return self.glm_proj(out)


class BlipImageBaseProcessor:
    def __init__(self, mean=None, std=None):
        # Default to the CLIP normalization statistics used by BLIP-2.
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


class BlipImageEvalProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)
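

# Minimal usage sketch, not part of the original module: it only exercises the
# preprocessing path, since constructing BLIP2 itself needs repo-specific
# `eva_args` / `qformer_args` config dicts that are not defined here. Given such
# configs, `BLIP2(eva_args, qformer_args)(pixel_values)` would return the
# projected query features, i.e. a tensor of shape [batch, 32, 4096].
if __name__ == "__main__":
    from PIL import Image
    import numpy as np

    processor = BlipImageEvalProcessor(image_size=224)
    dummy_image = Image.fromarray(np.zeros((300, 400, 3), dtype=np.uint8))  # placeholder image
    pixel_values = processor(dummy_image).unsqueeze(0)
    print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])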