# visualglm-6b / visual.py
import torch
import torch.nn as nn
from sat.model import ViTModel, BaseModel
from sat.model import BaseMixin
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
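
# Mixin that replaces the ViT classification head: instead of producing class logits,
# it applies a final LayerNorm to the transformer output sequence.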
class LNFinalyMixin(BaseMixin):
def __init__(self, hidden_size):
super().__init__()
self.ln_vision = nn.LayerNorm(hidden_size)
def final_forward(self, logits, **kw_args):
return self.ln_vision(logits)
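
# EVA ViT image encoder. The default "cls" mixin is swapped for LNFinalyMixin, so the
# model returns LayerNorm-ed patch features rather than classification logits.
# forward() passes dummy input_ids and a trivial attention mask; the actual sequence
# comes from the image patch embeddings.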
class EVAViT(ViTModel):
def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
self.del_mixin("cls")
self.add_mixin("cls", LNFinalyMixin(args.hidden_size))
def forward(self, image):
batch_size = image.size(0)
input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)
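
# BLIP-2 style Q-Former: a transformer whose 32 query embeddings cross-attend to the
# ViT features (passed in as encoder_outputs). Position embeddings are disabled and
# final_forward is the identity, so the raw query hidden states are returned.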
class QFormer(BaseModel):
def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
super().__init__(args, transformer=transformer, parallel_output=parallel_output,
activation_func=nn.functional.gelu, **kwargs)
self.transformer.position_embeddings = None
def final_forward(self, logits, **kw_args):
return logits
def position_embedding_forward(self, position_ids, **kw_args):
return None
def forward(self, encoder_outputs):
batch_size = encoder_outputs.size(0)
        # 32 query-token ids (0..31); the Q-Former's word embeddings serve as the learned queries.
        input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1)
attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask,
encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask)
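
# Full visual front-end: EVA ViT -> Q-Former -> linear projection (glm_proj) into the
# language model's embedding space. Pre-built sub-models can be injected via the
# vit/qformer arguments; otherwise they are constructed from the given arg dicts.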
class BLIP2(torch.nn.Module):
def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
super().__init__()
if vit is not None:
self.vit = vit
else:
self.vit = EVAViT(EVAViT.get_args(**eva_args))
if qformer is not None:
self.qformer = qformer
else:
self.qformer = QFormer(QFormer.get_args(**qformer_args))
        # Project Q-Former outputs (hidden size 768) into the language model's embedding
        # space (hidden size 4096), matching the Q-Former's device and dtype.
        qformer_param = next(self.qformer.parameters())
        self.glm_proj = nn.Linear(768, 4096).to(device=qformer_param.device, dtype=qformer_param.dtype)
def forward(self, image, **kwargs):
enc = self.vit(image)[0]
out = self.qformer(enc)[0]
return self.glm_proj(out)
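
# Image preprocessing, following BLIP's processors. The default mean/std are the
# standard CLIP normalization statistics.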
class BlipImageBaseProcessor:
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
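
# Deterministic eval-time transform: bicubic resize to a square of side image_size,
# conversion to a [0, 1] tensor, then normalization.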
class BlipImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=384, mean=None, std=None):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
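

# Illustrative usage sketch (not part of the original file). The file name and image
# size below are assumptions; the surrounding VisualGLM code decides the actual
# preprocessing and how the resulting tensor is fed to BLIP2.forward(image).
if __name__ == "__main__":
    from PIL import Image

    processor = BlipImageEvalProcessor(image_size=224)
    img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    pixel_values = processor(img).unsqueeze(0)      # shape: (1, 3, 224, 224)
    print(pixel_values.shape)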