# visualglm-6b / visual.py
# Page metadata: author zxdu20, commit 0df4363 ("Init commit"), 3.85 kB.
import torch
import torch.nn as nn
from sat.model import ViTModel, BaseModel
from sat.model import BaseMixin
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class LNFinalyMixin(BaseMixin):
    """Mixin whose ``final_forward`` hook applies a LayerNorm to the output.

    The class name keeps its original spelling ("Finaly") because other code
    may refer to it. The attribute name ``ln_vision`` presumably appears in
    checkpoint state_dict keys — do not rename without confirming.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # LayerNorm over the last dimension, of width ``hidden_size``.
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        """Return ``logits`` normalized by ``ln_vision``; extra kwargs are ignored."""
        normalized = self.ln_vision(logits)
        return normalized
class EVAViT(ViTModel):
    """ViT encoder whose "cls" mixin is replaced by a LayerNorm-only head."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(
            args,
            transformer=transformer,
            parallel_output=parallel_output,
            **kwargs,
        )
        # Swap the existing "cls" mixin for LNFinalyMixin, whose final_forward
        # only applies a LayerNorm of width args.hidden_size.
        self.del_mixin("cls")
        self.add_mixin("cls", LNFinalyMixin(args.hidden_size))

    def forward(self, image):
        """Run the encoder on ``image`` and return the parent forward's result.

        A single placeholder token id (0) per sample and a 1x1 all-ones
        attention mask are supplied. The pixels themselves are passed through
        the ``image`` keyword — presumably consumed by a ViTModel mixin;
        TODO confirm.
        """
        device = image.device
        dummy_ids = torch.zeros(image.size(0), 1, dtype=torch.long, device=device)
        # (1, 1) mask of ones; presumably broadcast by the attention code.
        full_mask = torch.tensor([[1.]], dtype=image.dtype, device=device)
        return super().forward(
            input_ids=dummy_ids,
            position_ids=None,
            attention_mask=full_mask,
            image=image,
        )
class QFormer(BaseModel):
    """Transformer fed the fixed token ids 0..31 alongside ``encoder_outputs``.

    The ``cross_attention_mask`` keyword suggests the query tokens attend to
    ``encoder_outputs`` via cross-attention — confirm against BaseModel.
    GELU is used as the activation function.
    """

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(
            args,
            transformer=transformer,
            parallel_output=parallel_output,
            activation_func=nn.functional.gelu,
            **kwargs,
        )
        # No absolute position embedding table; position_embedding_forward
        # below also returns None.
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        """Return the transformer output unchanged (no output projection)."""
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        """Produce no position embedding (always returns None)."""
        return None

    def forward(self, encoder_outputs):
        """Run the query tokens against ``encoder_outputs``.

        Assumes ``encoder_outputs`` is batch-first (its ``size(0)`` is used as
        the batch size).
        """
        num_queries = 32  # fixed number of query token ids per sample
        ref_device = encoder_outputs.device
        query_ids = (
            torch.arange(num_queries, dtype=torch.long, device=ref_device)
            .unsqueeze(0)
            .expand(encoder_outputs.size(0), -1)
        )
        # Both masks are (1, 1) tensors of ones in the encoder's dtype/device.
        mask_opts = dict(dtype=encoder_outputs.dtype, device=ref_device)
        self_mask = torch.tensor([[1.]], **mask_opts)
        cross_mask = torch.tensor([[1.]], **mask_opts)
        return super().forward(
            input_ids=query_ids,
            position_ids=None,
            attention_mask=self_mask,
            encoder_outputs=encoder_outputs,
            cross_attention_mask=cross_mask,
        )
class BLIP2(torch.nn.Module):
    """Vision encoder chain: ViT -> Q-Former -> linear projection (768 -> 4096).

    Args:
        eva_args: Keyword arguments forwarded to ``EVAViT.get_args``; only used
            when ``vit`` is None.
        qformer_args: Keyword arguments forwarded to ``QFormer.get_args``; only
            used when ``qformer`` is None.
        vit: Optional pre-built vision module; built from ``eva_args`` if None.
        qformer: Optional pre-built Q-Former module; built from
            ``qformer_args`` if None.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        super().__init__()
        self.vit = vit if vit is not None else EVAViT(EVAViT.get_args(**eva_args))
        self.qformer = (
            qformer
            if qformer is not None
            else QFormer(QFormer.get_args(**qformer_args))
        )
        # 768 presumably matches the Q-Former hidden size and 4096 the language
        # model's hidden size — TODO confirm against the model configs.
        self.glm_proj = nn.Linear(768, 4096)
        # Put the projection on the Q-Former's device/dtype, taken from its
        # first parameter. Using next() with a default avoids leaking a bare
        # StopIteration when the Q-Former has no parameters; in that case the
        # projection keeps PyTorch's defaults. nn.Module.to works in place.
        ref_param = next(self.qformer.parameters(), None)
        if ref_param is not None:
            self.glm_proj.to(device=ref_param.device, dtype=ref_param.dtype)

    def forward(self, image, **kwargs):
        """Encode ``image`` and project the Q-Former output to 4096 features.

        Extra keyword arguments are accepted and ignored.
        """
        # Each sub-model's result is indexed with [0]; presumably they return a
        # tuple whose first element is the hidden states — confirm.
        enc = self.vit(image)[0]
        out = self.qformer(enc)[0]
        return self.glm_proj(out)
class BlipImageBaseProcessor:
    """Base image processor that owns a ``transforms.Normalize`` step.

    When ``mean``/``std`` are omitted, the per-channel constants below are
    used (they appear to be the OpenAI CLIP statistics — confirm if it
    matters).
    """

    def __init__(self, mean=None, std=None):
        mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(mean, std)
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation pipeline: bicubic resize to a square, to tensor, normalize.

    Args:
        image_size: Side length of the square output (default 384).
        mean, std: Forwarded to ``BlipImageBaseProcessor``.
    """

    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)
        square = (image_size, image_size)
        pipeline = [
            transforms.Resize(square, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Apply the pipeline to ``item`` (presumably a PIL image — confirm)."""
        return self.transform(item)