import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
from lavis.common.registry import registry
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import PositionalEmbedding
# "MutliwayEmbedding" is the class name as spelled in torchscale itself.
from torchscale.component.multiway_network import MutliwayEmbedding
from transformers import AutoModel, PreTrainedModel
from transformers.utils.generic import ModelOutput

from .configuration_vivqa import ViVQAConfig


class BartPhoExtractor(nn.Module):
    """Extracts contextual token features from the pretrained BARTpho-word model."""

    def __init__(self):
        super().__init__()
        self.bartpho_word = AutoModel.from_pretrained("vinai/bartpho-word")

    def forward(self, input_ids, attention_mask):
        outputs = self.bartpho_word(input_ids, attention_mask=attention_mask)
        # The first element of the model output is the last hidden state.
        features = outputs[0]
        return features


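# Usage sketch (illustrative, not part of the training pipeline; assumes the
# matching "vinai/bartpho-word" tokenizer):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-word")
#   batch = tokenizer(["Màu của chiếc xe là gì?"], return_tensors="pt")
#   feats = BartPhoExtractor()(batch["input_ids"], batch["attention_mask"])
#   # feats: (1, seq_len, 1024) -- projected to 768 by ViVQABEiT3.linear below

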
class Blip2EfficientExtractor(nn.Module):
    """Fuses global BLIP-2 Q-Former features with local EfficientNet-B7 features."""

    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # BLIP-2 feature extractor from LAVIS, used in eval mode only.
        self.model_blip2 = (
            registry.get_model_class(name="blip2_feature_extractor")
            .from_pretrained(model_type="pretrain")
            .to(self.device)
        )
        if self.device == "cpu":
            # Half-precision weights are not supported on CPU.
            self.model_blip2 = self.model_blip2.float()
        self.model_blip2.eval()

        # EfficientNet-B7 pretrained with AdvProp, also in eval mode.
        self.model_efficientnet = EfficientNet.from_pretrained(
            "efficientnet-b7", advprop=True
        ).to(self.device)
        self.model_efficientnet.eval()
        self.pooling1 = nn.AdaptiveAvgPool2d((1, 32))
        self.pooling2 = nn.AdaptiveAvgPool2d((1, 768))

    def forward(self, images):
        # Global features from the BLIP-2 Q-Former queries: (batch, 32, 768).
        global_features = self.model_blip2.extract_features(
            samples={"image": images}, mode="image"
        ).image_embeds

        # Local convolutional features: (batch, C, H, W).
        local_features = self.model_efficientnet.extract_features(images)
        # Pool the spatial grid to 32 tokens, then the channel axis to 768.
        local_features = self.pooling1(local_features)        # (batch, C, 1, 32)
        local_features = local_features.permute(0, 3, 2, 1)   # (batch, 32, 1, C)
        local_features = self.pooling2(local_features)        # (batch, 32, 1, 768)
        batch_size = images.shape[0]
        local_features = local_features.reshape(batch_size, local_features.shape[1], -1)

        # Concatenate along the sequence axis.
        v = torch.cat([global_features, local_features], dim=1)
        return v


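# Shape note (derived from the pooling sizes above): for preprocessed images
# of shape (B, 3, H, W), the output is (B, 64, 768) -- 32 BLIP-2 query tokens
# followed by 32 pooled EfficientNet tokens, matching the 64 visual position
# embeddings allocated in ViVQABEiT3 below.

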
@dataclass
class ViVQAOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


def trunc_normal_(tensor, mean=0., std=1.):
    # Truncated normal init with bounds tied to std (timm defaults to (-2, 2)),
    # following the upstream BEiT-3 initialization code.
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)


class Pooler(nn.Module):
    """Pools the first token of the sequence: LayerNorm -> Linear -> Tanh."""

    def __init__(self, input_features, output_features, norm_layer):
        super().__init__()
        self.norm = norm_layer(input_features)
        self.dense = nn.Linear(input_features, output_features)
        self.activation = nn.Tanh()

    def forward(self, x):
        cls_rep = x[:, 0, :]
        cls_rep = self.norm(cls_rep)
        pooled_output = self.dense(cls_rep)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ViVQABEiT3(PreTrainedModel):
    def __init__(self, args):
        super().__init__(args)
        assert args.multiway
        assert not args.share_encoder_input_output_embed

        self.text_embed = BartPhoExtractor()

        # The visual pathway is frozen; only its outputs feed the encoder.
        self.vision_embed = Blip2EfficientExtractor()
        for param in self.vision_embed.parameters():
            param.requires_grad = False

        # Project BARTpho features (1024-d) down to the encoder width.
        self.linear = nn.Linear(1024, 768)

        # One positional table per modality; the +2 mirrors the original BEiT-3
        # code, which keeps two extra slots for Fairseq-style position numbering.
        num_position_embeddings = 64
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(num_position_embeddings + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(self, textual_tokens, visual_tokens, text_padding_position):
        x1 = self.vision_embed(visual_tokens)
        multiway_split_position = x1.size(1)

        # BARTpho expects an attention mask (1 = attend), which is the inverse
        # of the padding-position mask (1 = pad).
        x2 = self.text_embed(textual_tokens, 1 - text_padding_position)
        x2 = self.linear(x2)

        # Visual tokens first, text tokens second, matching the multiway split.
        x = torch.cat([x1, x2], dim=1)

        # Visual positions are never padded; text padding comes from the input.
        encoder_padding_mask = torch.cat(
            [
                torch.zeros(x1.shape[:-1], dtype=torch.bool, device=x1.device),
                text_padding_position,
            ],
            dim=1,
        )

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
        )
        encoder_out["multiway_split_position"] = multiway_split_position
        return encoder_out


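# Forward-pass sketch (illustrative; `config` stands for a ViVQAConfig
# instance): given question ids (B, L), images (B, 3, H, W) and a 0/1 padding
# mask (B, L),
#
#   out = ViVQABEiT3(config)(question_ids, images, pad_mask)
#   seq = out["encoder_out"]                # batch-first: (B, 64 + L, embed_dim)
#   split = out["multiway_split_position"]  # 64: the vision/text boundary

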
class BEiT3Wrapper(PreTrainedModel):
    def __init__(self, args, **kwargs):
        super().__init__(args)
        self.beit3 = ViVQABEiT3(args)

    def fix_init_weight(self):
        # Carried over from the upstream BEiT-3 wrapper; it assumes a
        # `self.blocks` attribute that this class never defines, so it is
        # effectively unused here.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'pos_embed',
            'cls_token',
            'beit3.encoder.embed_positions.A.weight',
            'beit3.vision_embed.cls_token',
            'logit_scale',
        }

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


class BEiT3ForVietnameseVisualQuestionAnswering(BEiT3Wrapper):
    config_class = ViVQAConfig

    def __init__(self, args, num_classes=353, **kwargs):
        super().__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.pooler = Pooler(
            input_features=embed_dim,
            output_features=embed_dim,
            norm_layer=nn.LayerNorm,
        )
        self.pooler.apply(self._init_weights)
        # Classification head over the closed answer vocabulary.
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.LayerNorm(embed_dim * 2),
            nn.GELU(),
            nn.Linear(embed_dim * 2, num_classes),
        )
        self.head.apply(self._init_weights)

    def forward(self, image, question, padding_mask, labels=None, **kwargs):
        outputs = self.beit3(
            textual_tokens=question,
            visual_tokens=image,
            text_padding_position=padding_mask,
        )
        x = outputs["encoder_out"]
        cls_rep = self.pooler(x)
        logits = self.head(cls_rep)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return ViVQAOutput(
            loss=loss,
            logits=logits,
        )
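

# End-to-end sketch (illustrative; assumes ViVQAConfig exposes at least the
# fields read above: multiway=True, share_encoder_input_output_embed=False,
# encoder_embed_dim and max_source_positions):
#
#   config = ViVQAConfig()
#   model = BEiT3ForVietnameseVisualQuestionAnswering(config, num_classes=353)
#   out = model(image=images, question=input_ids, padding_mask=pad_mask,
#               labels=answer_ids)
#   out.loss    # cross-entropy over the answer classes (None without labels)
#   out.logits  # (batch, num_classes)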