# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast as BertTokenizer
from transformers import PreTrainedModel
from transformers.activations import ACT2FN

from .configuration_hformer import HformerConfig
from .configuration_projector import ProjectorConfig
from .fuse_modules import BiAttentionBlock
from .modeling_projector import ProjectorModel
from .qformer_src import BertConfig, BertLMHeadModel

# Fix the random seed so the query-token initialization is reproducible.
torch.manual_seed(1024)


class LayerNorm(nn.LayerNorm):
    """Subclass of torch's LayerNorm.

    The original fp16 handling (casting the input to fp32 before
    normalization and casting the result back) is kept below but disabled.
    """

    def forward(self, x: torch.Tensor):
        return super().forward(x)
        # Disabled fp16 handling:
        # orig_type = x.dtype
        # ret = super().forward(x.type(torch.float32))
        # return ret.type(orig_type)

class HformerModel(PreTrainedModel):
    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Build a BERT-based Q-Former whose cross-attention layers attend to
        # the visual features every `cross_attention_freq` layers.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(bert, config=encoder_config)
        # Optionally strip the Q-Former's text branch (disabled by default).
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learnable query tokens that the Q-Former uses to attend to the
        # visual features.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        # Project the Q-Former outputs into the LLM embedding space.
        self.llm_proj = nn.Linear(
            encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        self.ln_llava = LayerNorm(encoder_config.encoder_width)
        # Register the '[DEC]' bos token and resize the Q-Former's embedding
        # table to match the tokenizer.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            print(f'Loading Q-Former weights from {qformer_pth}')
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # MLP connector that maps the raw visual features directly to the LLM
        # embedding space.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # Bi-directional attention block that fuses the MLP-projected features
        # with the Q-Former features.
        d_model = config.llm_hidden_size
        dim_feedforward = 1024
        nhead = 8
        fusion_dropout = 0.0
        fusion_droppath = 0.1
        self.fuse = BiAttentionBlock(
            v_dim=d_model,
            l_dim=d_model,
            embed_dim=dim_feedforward,
            num_heads=nhead,
            dropout=fusion_dropout,
            drop_path=fusion_droppath,
        )

        # Bottleneck feed-forward network applied after fusion.
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        # Forward hooks that force the outputs of the trainable sub-modules to
        # require gradients, so gradients can flow even when the inputs come
        # from a frozen vision encoder.
        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)
        self.fuse.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing is not supported by this model
        # (see `supports_gradient_checkpointing = False` above).
        raise NotImplementedError(
            'HformerModel does not support gradient checkpointing.')

    def forward(self, x_):
        if self.gradient_checkpointing and self.training:
            print('Gradient checkpointing is not supported.')

        # Q-Former branch: normalize the visual features, let the learnable
        # query tokens cross-attend to them, then project to the LLM space.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch: project the raw visual features directly.
        mlp_feat = self.connector(x_)

        # Fuse the two branches, then apply the residual bottleneck FFN.
        mlp_feat = mlp_feat + self.fuse(mlp_feat, q_feat)
        out = mlp_feat + self.ffn(mlp_feat)

        return out
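

# Usage sketch (illustrative only): builds a small HformerModel and runs a
# forward pass on dummy visual features. The HformerConfig field values below
# (hidden sizes, query count, BERT checkpoint name, patch count) are
# assumptions chosen for illustration, not values prescribed by this repo;
# the real configuration comes from configuration_hformer.py, and building
# the model downloads the named BERT checkpoint.
if __name__ == '__main__':
    config = HformerConfig(
        visual_hidden_size=1024,    # assumed vision-encoder feature dim
        llm_hidden_size=4096,       # assumed LLM embedding dim
        num_query_token=32,         # assumed number of learnable queries
        bert='bert-base-uncased',   # assumed BERT checkpoint for the Q-Former
        cross_attention_freq=2,
        qformer_pth=None,           # no pretrained Q-Former weights
        bias=True,
    )
    model = HformerModel(config).eval()

    # Dummy batch of visual features: (batch, num_patches, visual_hidden_size).
    visual_feats = torch.randn(2, 576, config.visual_hidden_size)
    with torch.no_grad():
        out = model(visual_feats)
    # The output keeps the patch dimension and is projected to the LLM width.
    print(out.shape)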