"""ViVQA modeling: a BEiT-3-style multiway fusion encoder over BartPho text features
and BLIP-2 + EfficientNet visual features, with a classification head for Vietnamese
visual question answering."""

from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from torchscale.component.multiway_network import MutliwayEmbedding
from torchscale.component.embedding import PositionalEmbedding
from torchscale.architecture.encoder import Encoder
from transformers import PreTrainedModel
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
from transformers import AutoModel
from transformers.utils.generic import ModelOutput
from dataclasses import dataclass
from typing import Optional
from efficientnet_pytorch import EfficientNet
from lavis.common.registry import registry

from .configuration_vivqa import ViVQAConfig


class BartPhoExtractor(nn.Module):
    """Text branch: contextual token features from the pretrained BartPho (Vietnamese BART) model."""

    def __init__(self):
        super(BartPhoExtractor, self).__init__()
        self.bartpho_word = AutoModel.from_pretrained("vinai/bartpho-word")

    def forward(self, input_ids, attention_mask):
        last_hidden_states = self.bartpho_word(input_ids, attention_mask)
        features = last_hidden_states[0]
        return features


class Blip2EfficientExtractor(nn.Module):
    """Vision branch: global BLIP-2 query features concatenated with pooled EfficientNet-B7 grid features."""

    def __init__(self):
        super(Blip2EfficientExtractor, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # BLIP-2 feature extractor (used in eval mode only)
        self.model_blip2 = (
            registry.get_model_class(name="blip2_feature_extractor")
            .from_pretrained(model_type="pretrain")
            .to(self.device)
        )
        if self.device == "cpu" or self.device == torch.device("cpu"):
            self.model_blip2 = self.model_blip2.float()
        self.model_blip2.eval()

        # EfficientNet-B7 (used in eval mode only)
        self.model_efficientnet = EfficientNet.from_pretrained('efficientnet-b7', advprop=True).to(self.device)
        self.model_efficientnet.eval()

        self.pooling1 = nn.AdaptiveAvgPool2d((1, 32))
        self.pooling2 = nn.AdaptiveAvgPool2d((1, 768))

    def forward(self, images):
        # Global features: BLIP-2 image embeddings (query-token features, 768-d).
        global_features = self.model_blip2.extract_features(samples={"image": images}, mode="image").image_embeds
        # Local features: EfficientNet-B7 grid features pooled down to 32 tokens of 768-d.
        local_features = self.model_efficientnet.extract_features(images)
        local_features = self.pooling1(local_features)
        local_features = local_features.permute(0, 3, 2, 1)
        local_features = self.pooling2(local_features)
        batch_size = images.shape[0]
        local_features = local_features.reshape(batch_size, local_features.shape[1], -1)
        # Concatenate global and local features along the sequence dimension.
        v = torch.cat([global_features, local_features], dim=1)
        return v


@dataclass
class ViVQAOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


def trunc_normal_(tensor, mean=0., std=1.):
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)


class Pooler(nn.Module):
    def __init__(self, input_features, output_features, norm_layer):
        super().__init__()
        self.norm = norm_layer(input_features)
        self.dense = nn.Linear(input_features, output_features)
        self.activation = nn.Tanh()

    def forward(self, x):
        # Pool by taking the first (CLS-like) token, then normalize, project, and squash.
        cls_rep = x[:, 0, :]
        cls_rep = self.norm(cls_rep)
        pooled_output = self.dense(cls_rep)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ViVQABEiT3(PreTrainedModel):
    def __init__(self, args):
        super().__init__(args)
        assert args.multiway
        assert not args.share_encoder_input_output_embed

        self.text_embed = BartPhoExtractor()
        self.vision_embed = Blip2EfficientExtractor()
        # The visual feature extractors are kept frozen.
        for param in self.vision_embed.parameters():
            param.requires_grad = False
        # Project BartPho hidden states (1024-d) down to the encoder embedding dim (768-d).
        self.linear = nn.Linear(1024, 768)

        # being consistent with Fairseq, which starts from 2 for position embedding
        num_position_embeddings = 64
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(num_position_embeddings + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(self, textual_tokens, visual_tokens, text_padding_position):
        x1 = self.vision_embed(visual_tokens)
        multiway_split_position = x1.size(1)

        x2 = self.text_embed(textual_tokens, 1 - text_padding_position)
        x2 = self.linear(x2)

        x = torch.cat([x1, x2], dim=1)

        # Visual tokens are never padded; text padding positions are masked out.
        encoder_padding_mask = torch.cat(
            [
                torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
                text_padding_position,
            ],
            dim=1,
        )
        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
        )
        encoder_out["multiway_split_position"] = multiway_split_position
        return encoder_out


class BEiT3Wrapper(PreTrainedModel):
    def __init__(self, args, **kwargs):
        super().__init__(args)
        self.beit3 = ViVQABEiT3(args)
        # self.apply(self._init_weights)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'pos_embed',
            'cls_token',
            'beit3.encoder.embed_positions.A.weight',
            'beit3.vision_embed.cls_token',
            'logit_scale',
        }

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


class BEiT3ForVietnameseVisualQuestionAnswering(BEiT3Wrapper):
    config_class = ViVQAConfig

    def __init__(
        self,
        args,
        num_classes=353,
        **kwargs
    ):
        super(BEiT3ForVietnameseVisualQuestionAnswering, self).__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.pooler = Pooler(
            input_features=embed_dim,
            output_features=embed_dim,
            norm_layer=nn.LayerNorm,
        )
        self.pooler.apply(self._init_weights)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.LayerNorm(embed_dim * 2),
            nn.GELU(),
            nn.Linear(embed_dim * 2, num_classes),
        )
        self.head.apply(self._init_weights)

    def forward(self, image, question, padding_mask, labels=None, **kwargs):
        outputs = self.beit3(
            textual_tokens=question,
            visual_tokens=image,
            text_padding_position=padding_mask,
        )
        x = outputs["encoder_out"]
        cls_rep = self.pooler(x)
        logits = self.head(cls_rep)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return ViVQAOutput(
            loss=loss,
            logits=logits,
        )
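

# Minimal usage sketch, not part of the model definition. Assumptions: ViVQAConfig's
# defaults provide the torchscale encoder fields used above (multiway=True,
# share_encoder_input_output_embed=False, encoder_embed_dim=768, max_source_positions, ...),
# and the pretrained BLIP-2 / EfficientNet / BartPho weights can be downloaded.
# Because of the relative import above, run it as a module (python -m <package>.<this_module>).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"

    config = ViVQAConfig()
    model = BEiT3ForVietnameseVisualQuestionAnswering(config).to(device).eval()

    # Tokenize a Vietnamese question ("What is this?"); padding_mask is 1 at padding
    # positions, matching the convention in ViVQABEiT3.forward.
    tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-word")
    encoded = tokenizer(["Đây là cái gì?"], return_tensors="pt")
    padding_mask = (1 - encoded["attention_mask"]).to(device)

    # Dummy image batch; real inputs should use the BLIP-2 "pretrain" image
    # preprocessing (224x224, normalized).
    images = torch.randn(1, 3, 224, 224).to(device)

    with torch.no_grad():
        out = model(image=images, question=encoded["input_ids"].to(device), padding_mask=padding_mask)
    print(out.logits.shape)  # (1, num_classes)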