xwwu's picture
Upload folder using huggingface_hub
18131bb verified
raw
history blame
5.76 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
torch.manual_seed(1024)
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_hformer import HformerConfig
from .qformer_src import BertConfig, BertLMHeadModel
from transformers import BertTokenizerFast as BertTokenizer
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel
from .fuse_modules import BiAttentionBlock
import torch.nn.functional as F
from transformers.activations import ACT2FN
class LayerNorm(nn.LayerNorm):
    """Thin ``nn.LayerNorm`` subclass used for the model's normalisation layers.

    ``forward`` delegates straight to ``nn.LayerNorm.forward``; the input is
    normalised in whatever dtype it arrives in. (An fp32 round-trip for
    half-precision inputs is intentionally not performed here.)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the layer-normalised ``x`` using the parent implementation."""
        return super().forward(x)
class HformerModel(PreTrainedModel):
    """Connector that maps visual features into the LLM embedding space.

    Two branches process the same visual features:

    * a Q-Former branch: ``ln_vision`` -> ``Qformer.bert`` (learned query
      tokens cross-attending to the visual features) -> ``llm_proj``;
    * an MLP branch: ``connector`` (a ``ProjectorModel``) applied to the raw
      visual features.

    The MLP features are then updated residually with the output of
    ``fuse`` (a ``BiAttentionBlock`` fed with both branches) and finally
    with a bottleneck feed-forward network (``ffn``).
    """

    _auto_class = 'AutoModel'
    # ``HformerConfig`` is the configuration class imported by this module;
    # the previously referenced ``QformerConfig`` is not defined anywhere
    # in this file and raised ``NameError`` when the class was created.
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build the Q-Former, projector, fusion block and FFN.

        Args:
            config: Model configuration. The attributes read here are
                ``visual_hidden_size``, ``num_query_token``, ``bert``,
                ``llm_hidden_size``, ``cross_attention_freq``,
                ``qformer_pth`` and ``bias``.
        """
        super().__init__(config)
        self.gradient_checkpointing = False

        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Configure the BERT backbone so that it cross-attends to the
        # visual features.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )

        # Hard-coded switch: when enabled, strips the text-only parts of
        # the Q-Former. Currently disabled.
        remove_text = False
        if remove_text:
            # remove the Q-former's text component
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learned query tokens, shape (1, num_query_token, hidden_size).
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(
            mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(
            encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # NOTE: ``ln_llava`` is not used by ``forward`` in this class, but
        # it is kept because it owns parameters that are part of the
        # module's state dict.
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # The tokenizer is only needed to size the embedding table after
        # the extra "[DEC]" BOS token has been added.
        tokenizer = BertTokenizer.from_pretrained(
            bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # NOTE(security): ``torch.load`` unpickles the file, which can
            # execute arbitrary code. Only point ``qformer_pth`` at
            # checkpoints from a trusted source.
            pretrained_state_dict = torch.load(
                qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            # strict=False: keys that do not match this module are ignored.
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # MLP branch: projects visual features to the LLM hidden size.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # Fusion block operating on the two branches, both of which live
        # in the LLM hidden size at this point.
        d_model = config.llm_hidden_size
        dim_feedforward = 1024
        nhead = 8
        fusion_dropout = 0.0
        fusion_droppath = 0.1
        self.fuse = BiAttentionBlock(
            v_dim=d_model,
            l_dim=d_model,
            embed_dim=dim_feedforward,
            num_heads=nhead,
            dropout=fusion_dropout,
            drop_path=fusion_droppath,
        )

        # Bottleneck FFN: hidden -> hidden // 4 -> hidden, without biases.
        modules = [
            nn.Linear(config.llm_hidden_size,
                      config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4,
                      config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register forward hooks that mark sub-module outputs as requiring grad.

        Despite the hook's name, it acts on each module's *output*: for a
        tuple output the first two elements are flagged, otherwise the
        output itself is flagged.
        """

        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)
        self.fuse.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Reject attempts to toggle gradient checkpointing.

        The class declares ``supports_gradient_checkpointing = False``.
        Raising an explicit error replaces a bare ``exit()`` call, which
        terminated the whole process silently with a success status.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'HformerModel does not support gradient checkpointing.')

    def forward(self, x_):
        """Project visual features into the LLM embedding space.

        Args:
            x_: Visual features. The first dimension is used as the batch
                size and the last dimension must match
                ``config.visual_hidden_size`` (presumably
                ``(batch, num_tokens, visual_hidden_size)`` - confirm
                against the caller).

        Returns:
            The fused features: MLP-projected features plus the fusion
            residual, followed by a residual FFN.
        """
        if self.gradient_checkpointing and self.training:
            print('Not supprted gradient checkpointing')

        # Q-Former branch: queries cross-attend to normalised features.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch works on the raw (un-normalised) input features.
        mlp_feat = self.connector(x_)

        # Residual fusion with the query features, then residual FFN.
        mlp_feat = mlp_feat + self.fuse(mlp_feat, q_feat)
        out = mlp_feat + self.ffn(mlp_feat)
        return out