numbmelon committed on
Commit
aab9828
1 Parent(s): 3a24881

Upload modeling_internvl_chat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_internvl_chat.py +323 -0
modeling_internvl_chat.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import warnings
7
+ from typing import Any, List, Optional, Tuple, Union
8
+
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
13
+ LlamaTokenizer)
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import ModelOutput, logging
17
+
18
+ from .configuration_internvl_chat import InternVLChatConfig
19
+ from .conversation import get_conv_template
20
+ from .modeling_intern_vit import InternVisionModel
21
+ from .modeling_phi3 import Phi3ForCausalLM
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class InternVLChatModel(PreTrainedModel):
    """Chat model that couples a vision encoder, an MLP projector and a causal LM.

    Image features produced by ``vision_model`` are projected by ``mlp1`` to the
    language model's hidden size and substituted for the embeddings of the image
    context token inside the text sequence.
    """

    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    # Module classes listed here are kept whole when the model is sharded.
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Phi3DecoderLayer']

    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
        """Build the sub-modules, or adopt the ones supplied by the caller.

        Args:
            config: Model configuration; its ``vision_config`` and ``llm_config``
                describe the two backbones.
            vision_model: Optional ready-made vision encoder. When ``None`` an
                ``InternVisionModel`` is created from ``config.vision_config``.
            language_model: Optional ready-made causal LM. When ``None`` one is
                created according to ``config.llm_config.architectures[0]``.

        Raises:
            NotImplementedError: If no language model is supplied and the
                configured architecture is neither ``LlamaForCausalLM`` nor
                ``Phi3ForCausalLM``.
        """
        super().__init__(config)

        vision_cfg = config.vision_config
        llm_cfg = config.llm_config

        image_size = config.force_image_size or vision_cfg.image_size
        self.patch_size = vision_cfg.patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        # Number of placeholder tokens one image tile occupies in the prompt:
        # patches per tile, reduced by the squared downsample ratio.
        patches_per_side = image_size // self.patch_size
        self.num_image_token = int(patches_per_side ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        self.vision_model = vision_model if vision_model is not None else InternVisionModel(vision_cfg)

        if language_model is None:
            supported_llms = {
                'LlamaForCausalLM': LlamaForCausalLM,
                'Phi3ForCausalLM': Phi3ForCausalLM,
            }
            arch = llm_cfg.architectures[0]
            if arch not in supported_llms:
                raise NotImplementedError(f'{arch} is not implemented.')
            language_model = supported_llms[arch](llm_cfg)
        self.language_model = language_model

        # Pixel shuffle multiplies the channel width by (1 / ratio) ** 2, so the
        # projector's input width is the ViT hidden size scaled by that factor.
        projector_in = vision_cfg.hidden_size * int(1 / self.downsample_ratio) ** 2
        llm_hidden_size = llm_cfg.hidden_size
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

        # Token id of the image placeholder; assigned by chat() / batch_chat().
        self.img_context_token_id = None
70
+
71
+ def forward(
72
+ self,
73
+ pixel_values: torch.FloatTensor,
74
+ input_ids: torch.LongTensor = None,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ position_ids: Optional[torch.LongTensor] = None,
77
+ image_flags: Optional[torch.LongTensor] = None,
78
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
79
+ labels: Optional[torch.LongTensor] = None,
80
+ use_cache: Optional[bool] = None,
81
+ output_attentions: Optional[bool] = None,
82
+ output_hidden_states: Optional[bool] = None,
83
+ return_dict: Optional[bool] = None,
84
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
85
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
86
+
87
+ image_flags = image_flags.squeeze(-1)
88
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
89
+
90
+ vit_embeds = self.extract_feature(pixel_values)
91
+ vit_embeds = vit_embeds[image_flags == 1]
92
+ vit_batch_size = pixel_values.shape[0]
93
+
94
+ B, N, C = input_embeds.shape
95
+ input_embeds = input_embeds.reshape(B * N, C)
96
+
97
+ if torch.distributed.get_rank() == 0:
98
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
99
+
100
+ input_ids = input_ids.reshape(B * N)
101
+ selected = (input_ids == self.img_context_token_id)
102
+ try:
103
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
104
+ except Exception as e:
105
+ vit_embeds = vit_embeds.reshape(-1, C)
106
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
107
+ f'vit_embeds.shape={vit_embeds.shape}')
108
+ n_token = selected.sum()
109
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
110
+
111
+ input_embeds = input_embeds.reshape(B, N, C)
112
+
113
+ outputs = self.language_model(
114
+ inputs_embeds=input_embeds,
115
+ attention_mask=attention_mask,
116
+ position_ids=position_ids,
117
+ past_key_values=past_key_values,
118
+ use_cache=use_cache,
119
+ output_attentions=output_attentions,
120
+ output_hidden_states=output_hidden_states,
121
+ return_dict=return_dict,
122
+ )
123
+ logits = outputs.logits
124
+
125
+ loss = None
126
+ if labels is not None:
127
+ # Shift so that tokens < n predict n
128
+ shift_logits = logits[..., :-1, :].contiguous()
129
+ shift_labels = labels[..., 1:].contiguous()
130
+ # Flatten the tokens
131
+ loss_fct = CrossEntropyLoss()
132
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
133
+ shift_labels = shift_labels.view(-1)
134
+ # Enable model parallelism
135
+ shift_labels = shift_labels.to(shift_logits.device)
136
+ loss = loss_fct(shift_logits, shift_labels)
137
+
138
+ if not return_dict:
139
+ output = (logits,) + outputs[1:]
140
+ return (loss,) + output if loss is not None else output
141
+
142
+ return CausalLMOutputWithPast(
143
+ loss=loss,
144
+ logits=logits,
145
+ past_key_values=outputs.past_key_values,
146
+ hidden_states=outputs.hidden_states,
147
+ attentions=outputs.attentions,
148
+ )
149
+
150
+ def pixel_shuffle(self, x, scale_factor=0.5):
151
+ n, w, h, c = x.size()
152
+ # N, W, H, C --> N, W, H * scale, C // scale
153
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
154
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
155
+ x = x.permute(0, 2, 1, 3).contiguous()
156
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
157
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
158
+ int(c / (scale_factor * scale_factor)))
159
+ if self.ps_version == 'v1':
160
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
161
+ 'which results in a transposed image.')
162
+ else:
163
+ x = x.permute(0, 2, 1, 3).contiguous()
164
+ return x
165
+
166
+ def extract_feature(self, pixel_values):
167
+ if self.select_layer == -1:
168
+ vit_embeds = self.vision_model(
169
+ pixel_values=pixel_values,
170
+ output_hidden_states=False,
171
+ return_dict=True).last_hidden_state
172
+ else:
173
+ vit_embeds = self.vision_model(
174
+ pixel_values=pixel_values,
175
+ output_hidden_states=True,
176
+ return_dict=True).hidden_states[self.select_layer]
177
+ vit_embeds = vit_embeds[:, 1:, :]
178
+
179
+ h = w = int(vit_embeds.shape[1] ** 0.5)
180
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
181
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
182
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
183
+ vit_embeds = self.mlp1(vit_embeds)
184
+ return vit_embeds
185
+
186
+ def batch_chat(self, tokenizer, pixel_values, num_patches_list, questions, generation_config, history=None,
187
+ return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
188
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
189
+ if history is not None or return_history:
190
+ print('Now multi-turn chat is not supported in batch_chat.')
191
+ raise NotImplementedError
192
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
193
+ self.img_context_token_id = img_context_token_id
194
+
195
+ from .conversation import get_conv_template
196
+
197
+ queries = []
198
+ if verbose:
199
+ image_bs = pixel_values.shape[0]
200
+ print(f'dynamic ViT batch size: {image_bs}, num_patches_list: {num_patches_list}')
201
+ for idx, num_patches in enumerate(num_patches_list):
202
+ image_token = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
203
+ question = image_token + '\n' + questions[idx]
204
+ template = get_conv_template(self.template)
205
+ template.append_message(template.roles[0], question)
206
+ template.append_message(template.roles[1], None)
207
+ query = template.get_prompt()
208
+ queries.append(query)
209
+ tokenizer.padding_side = 'left'
210
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
211
+ input_ids = model_inputs['input_ids'].cuda()
212
+ attention_mask = model_inputs['attention_mask'].cuda()
213
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
214
+ generation_config['eos_token_id'] = eos_token_id
215
+
216
+ generation_output = self.generate(
217
+ pixel_values=pixel_values,
218
+ input_ids=input_ids,
219
+ attention_mask=attention_mask,
220
+ **generation_config
221
+ )
222
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
223
+ responses = [response.split(template.sep)[0].strip() for response in responses]
224
+ return responses
225
+
226
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
227
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
228
+ verbose=False):
229
+
230
+ if history is None and pixel_values is not None and '<image>' not in question:
231
+ question = '<image>\n' + question
232
+
233
+ if num_patches_list is None:
234
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
235
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
236
+
237
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
238
+ self.img_context_token_id = img_context_token_id
239
+
240
+ template = get_conv_template(self.template)
241
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
242
+
243
+ history = [] if history is None else history
244
+ for (old_question, old_answer) in history:
245
+ template.append_message(template.roles[0], old_question)
246
+ template.append_message(template.roles[1], old_answer)
247
+ template.append_message(template.roles[0], question)
248
+ template.append_message(template.roles[1], None)
249
+ query = template.get_prompt()
250
+
251
+ if verbose and pixel_values is not None:
252
+ image_bs = pixel_values.shape[0]
253
+ print(f'dynamic ViT batch size: {image_bs}')
254
+
255
+ for num_patches in num_patches_list:
256
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
257
+ query = query.replace('<image>', image_tokens, 1)
258
+
259
+ model_inputs = tokenizer(query, return_tensors='pt')
260
+ input_ids = model_inputs['input_ids'].cuda()
261
+ attention_mask = model_inputs['attention_mask'].cuda()
262
+ generation_config['eos_token_id'] = eos_token_id
263
+ generation_output = self.generate(
264
+ pixel_values=pixel_values,
265
+ input_ids=input_ids,
266
+ attention_mask=attention_mask,
267
+ **generation_config
268
+ )
269
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
270
+ response = response.split(template.sep)[0].strip()
271
+ history.append((question, response))
272
+ if return_history:
273
+ return response, history
274
+ else:
275
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
276
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
277
+ if verbose:
278
+ print(query_to_print, response)
279
+ return response
280
+
281
+ @torch.no_grad()
282
+ def generate(
283
+ self,
284
+ pixel_values: Optional[torch.FloatTensor] = None,
285
+ input_ids: Optional[torch.FloatTensor] = None,
286
+ attention_mask: Optional[torch.LongTensor] = None,
287
+ visual_features: Optional[torch.FloatTensor] = None,
288
+ generation_config: Optional[GenerationConfig] = None,
289
+ output_hidden_states: Optional[bool] = None,
290
+ return_dict: Optional[bool] = None,
291
+ **generate_kwargs,
292
+ ) -> torch.LongTensor:
293
+
294
+ assert self.img_context_token_id is not None
295
+ if pixel_values is not None:
296
+ if visual_features is not None:
297
+ vit_embeds = visual_features
298
+ else:
299
+ vit_embeds = self.extract_feature(pixel_values)
300
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
301
+ B, N, C = input_embeds.shape
302
+ input_embeds = input_embeds.reshape(B * N, C)
303
+
304
+ input_ids = input_ids.reshape(B * N)
305
+ selected = (input_ids == self.img_context_token_id)
306
+ assert selected.sum() != 0
307
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
308
+
309
+ input_embeds = input_embeds.reshape(B, N, C)
310
+ else:
311
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
312
+
313
+ outputs = self.language_model.generate(
314
+ inputs_embeds=input_embeds,
315
+ attention_mask=attention_mask,
316
+ generation_config=generation_config,
317
+ output_hidden_states=output_hidden_states,
318
+ return_dict=return_dict,
319
+ use_cache=True,
320
+ **generate_kwargs,
321
+ )
322
+
323
+ return outputs