Visual Question Answering
Transformers
PyTorch
internvl_chat
feature-extraction
custom_code
czczup committed
Commit 126d8db
1 Parent(s): 81c2a52

Upload folder using huggingface_hub

configuration_internvl_chat.py CHANGED
@@ -12,7 +12,6 @@ from transformers.utils import logging
 
 from .configuration_intern_vit import InternVisionConfig
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -52,6 +51,8 @@ class InternVLChatConfig(PretrainedConfig):
         self.downsample_ratio = downsample_ratio
         self.template = template
 
+        logger.info(f'vision_select_layer: {self.select_layer}')
+
     def to_dict(self):
         """
         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
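
The `select_layer` value that is now logged controls which ViT layer `extract_feature` reads from (see the modeling diff below), while `downsample_ratio`, stored just above it, sets how aggressively the visual tokens are pixel-shuffled down before they reach the language model. A minimal sketch of the resulting token count; the image size, patch size, and ratio below are illustrative values, not read from this checkpoint's config:

# Sketch only: the three values below are illustrative, not taken from this commit.
image_size = 448        # vision input resolution
patch_size = 14         # ViT patch size
downsample_ratio = 0.5  # config.downsample_ratio

# Same formula as InternVLChatModel.__init__ in the diff below:
num_image_token = int((image_size // patch_size) ** 2 * (downsample_ratio ** 2))
print(num_image_token)  # (448 // 14) ** 2 = 1024 ViT patches -> 256 tokens after pixel shuffle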
conversation.py CHANGED
@@ -1211,3 +1211,46 @@ register_conv_template(
         sep2='</s>',
     )
 )
+
+
+if __name__ == '__main__':
+    from fastchat.conversation import get_conv_template
+
+    print('-- Vicuna template --')
+    conv = get_conv_template('vicuna_v1.1')
+    conv.append_message(conv.roles[0], 'Hello!')
+    conv.append_message(conv.roles[1], 'Hi!')
+    conv.append_message(conv.roles[0], 'How are you?')
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
+
+    print('\n')
+
+    print('-- Llama-2 template --')
+    conv = get_conv_template('llama-2')
+    conv.set_system_message('You are a helpful, respectful and honest assistant.')
+    conv.append_message(conv.roles[0], 'Hello!')
+    conv.append_message(conv.roles[1], 'Hi!')
+    conv.append_message(conv.roles[0], 'How are you?')
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
+
+    print('\n')
+
+    print('-- ChatGPT template --')
+    conv = get_conv_template('chatgpt')
+    conv.append_message(conv.roles[0], 'Hello!')
+    conv.append_message(conv.roles[1], 'Hi!')
+    conv.append_message(conv.roles[0], 'How are you?')
+    conv.append_message(conv.roles[1], None)
+    print(conv.to_openai_api_messages())
+
+    print('\n')
+
+    print('-- Claude template --')
+    conv = get_conv_template('claude')
+    conv.append_message(conv.roles[0], 'Hello!')
+    conv.append_message(conv.roles[1], 'Hi!')
+    conv.append_message(conv.roles[0], 'How are you?')
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
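
Running `python conversation.py` now prints each rendered prompt, which is a quick way to sanity-check role names and separators without loading any weights. The demo above relies on fastchat's `get_conv_template`, whereas the chat model imports the local one defined in this file (see `modeling_internvl_chat.py` below); a minimal sketch of that local usage, with an illustrative template name and question:

# Sketch only: 'vicuna_v1.1' and the question are illustrative; the model picks its
# template from InternVLChatConfig.template at runtime.
from conversation import get_conv_template  # this repo's conversation.py

template = get_conv_template('vicuna_v1.1')
template.append_message(template.roles[0], 'Describe this image in detail.')
template.append_message(template.roles[1], None)  # leave the assistant turn open for generation
print(template.get_prompt())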
modeling_internvl_chat.py CHANGED
@@ -3,16 +3,21 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import warnings
 from typing import Any, List, Optional, Tuple, Union
-
+import torch.distributed as dist
 import torch.utils.checkpoint
 from peft import LoraConfig, get_peft_model
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, logging
+from transformers.generation.utils import GreedySearchOutput, validate_stopping_criteria, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput
 
 from .configuration_internvl_chat import InternVLChatConfig
 from .modeling_intern_vit import InternVisionModel
@@ -20,10 +25,183 @@ from .modeling_intern_vit import InternVisionModel
 logger = logging.get_logger(__name__)
 
 
+# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
+# Fix bug when using device_map='auto' for distributed inference
+class MLlamaForCausalLM(LlamaForCausalLM):
+
+    def greedy_search(
+            self,
+            input_ids: torch.LongTensor,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            max_length: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[Union[int, List[int]]] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            output_scores: Optional[bool] = None,
+            return_dict_in_generate: Optional[bool] = None,
+            synced_gpus: bool = False,
+            streamer: Optional["BaseStreamer"] = None,
+            **model_kwargs,
+    ) -> Union[GreedySearchOutput, torch.LongTensor]:
+        # init values
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use"
+                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate
+            if return_dict_in_generate is not None
+            else self.generation_config.return_dict_in_generate
+        )
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+
+        this_peer_finished = False  # used by synced_gpus only
+        while True:
+            if synced_gpus:
+                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+                # The following logic allows an early break if all peers finished generating their sequence
+                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+                # send 0.0 if we finished, 1.0 otherwise
+                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                # did all peers finish? the reduced sum will be 0.0 then
+                if this_peer_finished_flag.item() == 0.0:
+                    break
+
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+
+            if synced_gpus and this_peer_finished:
+                continue  # don't waste resources running the code we don't need
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_tokens_scores,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # argmax
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1).to(device=input_ids.device)
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
+
+                # stop when each sentence is finished
+                if unfinished_sequences.max() == 0:
+                    this_peer_finished = True
+
+            # stop if we exceed the maximum length
+            if stopping_criteria(input_ids, scores):
+                this_peer_finished = True
+
+            if this_peer_finished and not synced_gpus:
+                break
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GreedySearchEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GreedySearchDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+
+
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
     main_input_name = 'pixel_values'
-    _no_split_modules = ['InternVisionEncoderLayer', 'LlamaDecoderLayer', 'LlamaForCausalLM']
+    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']
 
     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
         super().__init__(config)
@@ -33,6 +211,7 @@ class InternVLChatModel(PreTrainedModel):
         self.select_layer = config.select_layer
         self.template = config.template
         self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
+        self.downsample_ratio = config.downsample_ratio
        logger.info(f'num_image_token: {self.num_image_token}')
         if vision_model is not None:
             self.vision_model = vision_model
@@ -41,7 +220,8 @@ class InternVLChatModel(PreTrainedModel):
         if language_model is not None:
             self.language_model = language_model
         else:
-            self.language_model = LlamaForCausalLM(config.llm_config)
+            # self.language_model = LlamaForCausalLM(config.llm_config)
+            self.language_model = MLlamaForCausalLM(config.llm_config)
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.llm_config.hidden_size
 
@@ -52,7 +232,7 @@ class InternVLChatModel(PreTrainedModel):
             nn.Linear(llm_hidden_size, llm_hidden_size)
         )
 
-        if config.force_image_size:
+        if config.force_image_size != config.vision_config.image_size:
            self.vision_model.resize_pos_embeddings(
                 old_size=config.vision_config.image_size,
                 new_size=config.force_image_size,
@@ -173,16 +353,22 @@ class InternVLChatModel(PreTrainedModel):
         return x
 
     def extract_feature(self, pixel_values):
-        vit_embeds = self.vision_model(
-            pixel_values=pixel_values,
-            output_hidden_states=True,
-            return_dict=True).hidden_states[-4]
+        if self.select_layer == -1:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values,
+                output_hidden_states=False,
+                return_dict=True).last_hidden_state
+        else:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values,
+                output_hidden_states=True,
+                return_dict=True).hidden_states[self.select_layer]
         vit_embeds = vit_embeds[:, 1:, :]
         # if torch.distributed.get_rank() == 0:
         #     print("before pixel shuffle:", vit_embeds.shape)
         h = w = int(vit_embeds.shape[1] ** 0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=0.5)
+        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
         # if torch.distributed.get_rank() == 0:
         #     print("after pixel shuffle:", vit_embeds.shape)
@@ -194,6 +380,7 @@ class InternVLChatModel(PreTrainedModel):
 
         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
         self.img_context_token_id = img_context_token_id
+
         from .conversation import get_conv_template
 
         template = get_conv_template(self.template)
@@ -243,7 +430,7 @@ class InternVLChatModel(PreTrainedModel):
             input_ids = input_ids.reshape(B * N)
             selected = (input_ids == self.img_context_token_id)
             assert selected.sum() != 0
-            input_embeds[selected] = vit_embeds.reshape(-1, C)
+            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
 
             input_embeds = input_embeds.reshape(B, N, C)
         else:
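
Taken together, the changes in this file (the `_no_split_modules` update, routing generation through `MLlamaForCausalLM.greedy_search`, and copying `vit_embeds` onto `input_embeds.device`) all target multi-GPU inference with `device_map='auto'`. A minimal loading sketch under that setup; the checkpoint path, dtype, and the commented `chat(...)` call are illustrative, not taken from this commit:

import torch
from transformers import AutoModel, AutoTokenizer

path = 'path/to/internvl-chat-checkpoint'  # placeholder repo id or local path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# device_map='auto' shards the model across visible GPUs; _no_split_modules above keeps
# the whole InternVisionModel and each LlamaDecoderLayer on a single device, and the
# .to(input_embeds.device) fix moves ViT features to whichever device holds the embeddings.
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
).eval()

# Chat-style inference would then go through the model's chat() helper, roughly:
# response = model.chat(tokenizer, pixel_values, question, generation_config)
# (exact argument names and order follow this repo's chat() definition, which is not shown in this diff).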