import torch

torch.manual_seed(1024)

import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers import BertTokenizerFast as BertTokenizer
from transformers.activations import ACT2FN

from .configuration_hformer import HformerConfig
from .qformer_src import BertConfig, BertLMHeadModel
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel


class LayerNorm(nn.LayerNorm):
    """Thin wrapper around ``nn.LayerNorm``; behaviour is currently identical."""

    def forward(self, x: torch.Tensor):
        ret = super().forward(x)
        return ret


class HformerModel(PreTrainedModel):
    """Hybrid projector that fuses Q-Former query features with an MLP
    connector's per-patch features before they are fed to the LLM."""

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Build a BERT encoder with cross-attention to the vision features
        # (BLIP-2 style Q-Former).
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(bert, config=encoder_config)

        # Optionally strip the text-specific parts of the Q-Former
        # (disabled by default).
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learnable query tokens attended to the vision features.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(
            encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # The Q-Former vocabulary gains a dedicated BOS token.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        # Optionally initialise from a pretrained Q-Former checkpoint.
        if qformer_pth is not None:
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # MLP connector that projects each vision token to the LLM width.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # Bottleneck FFN applied residually to the fused features.
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False)
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        # Mark the outputs of the projector sub-modules as requiring grad so
        # that gradients flow when upstream modules are frozen/checkpointed.
        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def forward(self, x_):
        if self.gradient_checkpointing and self.training:
            print('Gradient checkpointing is not supported')

        # Q-Former branch: learnable queries cross-attend to the
        # layer-normalised vision features, then project to the LLM width.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch: project the raw vision features token by token.
        mlp_outputs = self.connector(x_)
        mlp_feat = mlp_outputs

        # Fuse: broadcast the mean query feature over the MLP tokens, then
        # refine with a residual bottleneck FFN.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
        out = int_feat + self.ffn(int_feat)
        return out
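

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): shows how the
# projector might be instantiated and run on dummy vision-encoder features.
# The concrete values below (hidden sizes, patch count, BERT checkpoint) are
# assumptions, not values taken from this repository; HformerConfig is assumed
# to accept the same fields that __init__ reads from it.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = HformerConfig(
        visual_hidden_size=1024,   # assumed vision-tower feature width
        llm_hidden_size=4096,      # assumed LLM hidden size
        num_query_token=32,
        bert='bert-base-uncased',
        cross_attention_freq=2,
        qformer_pth=None,          # skip loading a pretrained Q-Former
        bias=True,
    )
    model = HformerModel(config).eval()

    # Dummy batch of patch features: (batch, num_patches, visual_hidden_size).
    dummy = torch.randn(2, 576, config.visual_hidden_size)
    with torch.no_grad():
        out = model(dummy)

    # The connector keeps the per-patch layout, so the fused output should be
    # (batch, num_patches, llm_hidden_size), e.g. torch.Size([2, 576, 4096]).
    print(out.shape)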