File size: 8,269 Bytes
9909305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: wangchongyi <wangchongyi@zhihu.com>
# @date: 2023/9/1
#

# coding=utf-8
# Copyright 2024 RhapsodyAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn
import math
from dataclasses import dataclass
from typing import Optional, Tuple

from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel

from .configuration_siglip import SiglipVisionConfig
from .configuration_minicpm import MiniCPMConfig
from .configuration_minicpmv import MiniCPMVConfig

from .resampler import Resampler
from .modeling_minicpm import MiniCPMForCausalLM
from .modeling_siglip import SiglipVisionModel

from transformers import LlamaTokenizer # for text processing


@dataclass
class CausalVLMOutput(ModelOutput):
    """Output container returned by ``MiniCPMVForCausalLM.forward``.

    In this file, ``forward`` fills ``logits``, ``hidden_states`` and
    ``vision_hidden_states``; ``loss`` and ``attentions`` are left as ``None``.
    """

    # Training loss; not computed by MiniCPMVForCausalLM.forward in this file.
    loss: Optional[torch.FloatTensor] = None
    # Logits of the underlying language model.
    logits: Optional[torch.FloatTensor] = None
    # Hidden states of the underlying language model (presumably None unless the
    # LLM is asked to return them — verify against the LLM implementation).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Vision features. MiniCPMVForCausalLM.vpm_forward actually produces a
    # per-sample Python list here (a tensor, or [] for samples without images),
    # despite the Tuple annotation.
    vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention weights; not set by forward in this file.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class MiniCPMVForCausalLM(PreTrainedModel):
    """Vision-language model: SigLIP vision tower + Resampler + MiniCPM causal LM.

    Images are encoded by ``self.vpm`` (the SigLIP vision transformer),
    compressed by ``self.resampler`` into ``config.query_num`` query tokens per
    image, and written into the LLM's input token embeddings at the spans given
    by ``data['image_bound']``.
    """

    model_type = "minicpm"
    _supports_flash_attn_2 = True

    def __init__(self, config: MiniCPMVConfig, adaptive=False):
        """Build the vision tower, the language model and the resampler.

        Args:
            config: composite config; ``config.llm_config`` configures the
                MiniCPM language model and ``config.vpm_config`` the SigLIP
                vision model.
            adaptive: forwarded to ``Resampler`` and stored as ``self.adaptive``.
        """
        super().__init__(config)

        llm_config = config.llm_config
        vpm_config = config.vpm_config

        self.query_num = config.query_num
        self.patch_size = vpm_config.patch_size
        self.adaptive = adaptive
        self.slice_mode = config.slice_mode
        self.max_slice_nums = config.max_slice_nums
        self.mm_use_im_start_end = config.mm_use_im_start_end

        drop_vision_last_layer = config.drop_vision_last_layer

        # NOTE(review): vpm_config is expected to be a SiglipVisionConfig; not checked here.
        vpm = SiglipVisionModel(vpm_config).vision_model
        if drop_vision_last_layer:
            # Remove the final encoder layer of the vision transformer.
            vpm.encoder.layers = nn.ModuleList(vpm.encoder.layers[:-1])
        self.vpm = vpm

        # NOTE(review): llm_config is expected to be a MiniCPMConfig; not checked here.
        self.llm = MiniCPMForCausalLM(llm_config)

        embed_dim = llm_config.hidden_size
        self.resampler = Resampler(
            num_queries=config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,  # one head per 128 hidden units
            kv_dim=vpm_config.hidden_size,
            adaptive=adaptive,
        )

    def vpm_forward(self, data):
        """Encode images and merge their features into the LLM input embeddings.

        Args:
            data: dict with ``'input_ids'`` and ``'image_bound'``, plus either
                ``'vision_hidden_states'`` (precomputed per-sample features,
                used as-is) or ``'pixel_values'`` and ``'tgt_sizes'``.
                ``pixel_values`` is a per-sample list of per-image tensors;
                ``tgt_sizes`` is a list of tensors that are stacked into a
                ``(num_images, 2)`` tensor of patch-grid sizes.
                ``image_bound[i]`` is a sequence of ``(start, end)`` token spans
                for sample ``i``.

        Returns:
            ``(vllm_embedding, vision_hidden_states)``. ``vllm_embedding`` holds
            the token embeddings, multiplied by ``llm.config.scale_emb`` when
            that attribute exists, with the image spans overwritten by vision
            features. ``vision_hidden_states`` is the per-sample list of vision
            features: a tensor, or ``[]`` for a sample without images.
        """
        if 'vision_hidden_states' not in data:
            dtype = self.vpm.embeddings.position_embedding.weight.dtype
            device = self.vpm.embeddings.position_embedding.weight.device

            pixel_values_list = data['pixel_values']
            tgt_sizes = data['tgt_sizes']

            vision_hidden_states = []

            # Flatten every image of every sample into one list. Per the original
            # shape notes each image is assumed to be (3, patch, L); it becomes
            # (L, 3 * patch) so images can be padded along L.
            all_pixel_values = []
            for pixel_values in pixel_values_list:
                all_pixel_values.extend(
                    [img.flatten(end_dim=1).permute(1, 0) for img in pixel_values]
                )

            if all_pixel_values:  # at least one image in the batch
                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
                max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

                all_pixel_values = torch.nn.utils.rnn.pad_sequence(
                    all_pixel_values, batch_first=True, padding_value=0.0
                )
                # Move only the padded batch to the device, not each image separately.
                all_pixel_values = all_pixel_values.to(device)

                B, L, _ = all_pixel_values.shape
                all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)  # B, 3, patch, L

                # Mark the real (non-padded) patches of each image. The mask is
                # (B, 1, max_patches), so the patch axis is the LAST one; indexing
                # [i, :n] would slice the size-1 middle axis instead and mark every
                # patch, including padding, as valid.
                patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
                for i in range(B):
                    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

                vision_embedding = self.vpm(
                    all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask
                ).last_hidden_state
                vision_embedding = self.resampler(vision_embedding, tgt_sizes)

                # Split the flat per-image features back into per-sample groups.
                start = 0
                for pixel_values in pixel_values_list:
                    img_cnt = len(pixel_values)
                    if img_cnt > 0:
                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
                        start += img_cnt
                    else:
                        vision_hidden_states.append([])
            else:  # no image anywhere in the batch
                if self.training:
                    # Run a dummy all-zero 224x224 image through the vision tower
                    # and resampler; the result is tied into the graph below so
                    # these modules still take part in the training step.
                    dummy_image = torch.zeros(
                        (1, 3, 224, 224),
                        device=device, dtype=dtype
                    )
                    tgt_sizes = torch.Tensor(
                        [[(224 // self.patch_size), math.ceil(224 / self.patch_size)]]
                    ).type(torch.int32)
                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
                else:
                    dummy_feature = []
                for _ in range(len(pixel_values_list)):
                    vision_hidden_states.append(dummy_feature)
        else:
            vision_hidden_states = data['vision_hidden_states']

        vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
        if hasattr(self.llm.config, 'scale_emb'):
            vllm_embedding = vllm_embedding * self.llm.config.scale_emb

        vision_hidden_states = [
            vs.type(vllm_embedding.dtype) if isinstance(vs, torch.Tensor) else vs
            for vs in vision_hidden_states
        ]

        bs = len(data['input_ids'])
        for i in range(bs):
            cur_vs_hs = vision_hidden_states[i]
            if len(cur_vs_hs) == 0:
                continue  # this sample has no vision features

            # A view into vllm_embedding: in-place writes below update it directly.
            cur_vllm_emb = vllm_embedding[i]
            cur_image_bound = data['image_bound'][i]

            if len(cur_image_bound) > 0:
                image_indices = torch.stack(
                    [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                ).to(vllm_embedding.device)

                # Overwrite each image-span row with the matching vision feature row.
                # reshape (not view) so non-contiguous precomputed features also work.
                cur_vllm_emb.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                    cur_vs_hs.reshape(-1, cur_vs_hs.shape[-1])
                )
            elif self.training and isinstance(cur_vs_hs, torch.Tensor):
                # No image span to fill (e.g. the dummy feature above): add a
                # zero-valued term so the vision features remain connected to the
                # autograd graph. Numerically this leaves the embedding unchanged.
                cur_vllm_emb += cur_vs_hs[0].mean() * 0

        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        """Run the multimodal forward pass.

        Args:
            data: input dict as accepted by ``vpm_forward``, plus
                ``'attention_mask'`` for the language model.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            ``CausalVLMOutput`` with the LLM logits and hidden states and the
            per-sample vision features. No loss is computed here.
        """
        vllm_embedding, vision_hidden_states = self.vpm_forward(data)

        output = self.llm(
            inputs_embeds=vllm_embedding,
            attention_mask=data["attention_mask"],
            return_dict=True
        )

        return CausalVLMOutput(
            logits=output.logits,
            hidden_states=output.hidden_states,
            vision_hidden_states=vision_hidden_states
        )

    def generate(self, data, **kwargs):
        """Generate text from multimodal input.

        Builds the merged input embeddings with ``vpm_forward`` and hands them
        to ``self.llm.generate`` together with ``data['attention_mask']``.
        Position ids are not passed explicitly.

        Args:
            data: input dict as accepted by ``vpm_forward``, plus
                ``'attention_mask'``.
            **kwargs: forwarded unchanged to ``self.llm.generate``.

        Returns:
            Whatever ``self.llm.generate`` returns.
        """
        vllm_embedding, _ = self.vpm_forward(data)

        output = self.llm.generate(
            inputs_embeds=vllm_embedding,
            attention_mask=data["attention_mask"],
            **kwargs
        )

        return output