VictorSanh committed on
Commit
7515eca
1 Parent(s): c8028be

big update

Browse files
Files changed (1) hide show
  1. modeling_img2html.py +10 -12
modeling_img2html.py CHANGED
@@ -162,7 +162,7 @@ def expand_inputs_for_generation(
162
  input_ids = input_ids.index_select(0, expanded_return_idx)
163
  model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
164
  model_kwargs["image_hidden_states"] = model_kwargs.get("image_hidden_states", None)
165
- model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
166
 
167
  if "token_type_ids" in model_kwargs:
168
  token_type_ids = model_kwargs["token_type_ids"]
@@ -180,9 +180,7 @@ def expand_inputs_for_generation(
180
  model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
181
 
182
  elif model_kwargs["image_hidden_states"] is not None:
183
- model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(
184
- 0, expanded_return_idx
185
- )
186
 
187
  return input_ids, model_kwargs
188
 
@@ -205,10 +203,10 @@ def update_model_kwargs_for_generation(outputs, model_kwargs):
205
  model_kwargs["attention_mask"] = torch.cat(
206
  [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
207
  )
208
- if "image_attention_mask" in model_kwargs:
209
- image_attention_mask = model_kwargs["image_attention_mask"]
210
- last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
211
- model_kwargs["image_attention_mask"] = last_mask
212
 
213
  # Get the precomputed image_hidden_states
214
  model_kwargs["image_hidden_states"] = outputs.image_hidden_states
@@ -236,7 +234,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
236
 
237
  pixel_values = kwargs.get("pixel_values", None)
238
  image_hidden_states = kwargs.get("image_hidden_states", None)
239
- image_attention_mask = kwargs.get("image_attention_mask", None)
240
 
241
  return {
242
  "input_ids": input_ids,
@@ -247,7 +245,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
247
  "token_type_ids": token_type_ids,
248
  "pixel_values": pixel_values,
249
  "image_hidden_states": image_hidden_states,
250
- "image_attention_mask": image_attention_mask,
251
  }
252
 
253
 
@@ -1373,7 +1371,6 @@ class VMistralModel(VMistralPreTrainedModel):
1373
  input_ids: torch.LongTensor = None,
1374
  inputs_embeds: Optional[torch.Tensor] = None,
1375
  image_hidden_states: Optional[torch.Tensor] = None,
1376
- num_images: Optional[int] = None,
1377
  ):
1378
  """
1379
  This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
@@ -1496,6 +1493,8 @@ class VMistralModel(VMistralPreTrainedModel):
1496
 
1497
  if self.config.use_resampler:
1498
  image_hidden_states = self.perceiver_resampler(image_hidden_states)
 
 
1499
 
1500
  if past_key_values is None:
1501
  # When we generate, we don't want to replace the potential image_token_id that we generated by images
@@ -1504,7 +1503,6 @@ class VMistralModel(VMistralPreTrainedModel):
1504
  input_ids=input_ids,
1505
  inputs_embeds=inputs_embeds,
1506
  image_hidden_states=image_hidden_states,
1507
- num_images=num_images,
1508
  )
1509
  inputs_embeds = new_inp["inputs_embeds"]
1510
 
 
162
  input_ids = input_ids.index_select(0, expanded_return_idx)
163
  model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
164
  model_kwargs["image_hidden_states"] = model_kwargs.get("image_hidden_states", None)
165
+ # model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
166
 
167
  if "token_type_ids" in model_kwargs:
168
  token_type_ids = model_kwargs["token_type_ids"]
 
180
  model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
181
 
182
  elif model_kwargs["image_hidden_states"] is not None:
183
+ model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(0, expanded_return_idx)
 
 
184
 
185
  return input_ids, model_kwargs
186
 
 
203
  model_kwargs["attention_mask"] = torch.cat(
204
  [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
205
  )
206
+ # if "image_attention_mask" in model_kwargs:
207
+ # image_attention_mask = model_kwargs["image_attention_mask"]
208
+ # last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
209
+ # model_kwargs["image_attention_mask"] = last_mask
210
 
211
  # Get the precomputed image_hidden_states
212
  model_kwargs["image_hidden_states"] = outputs.image_hidden_states
 
234
 
235
  pixel_values = kwargs.get("pixel_values", None)
236
  image_hidden_states = kwargs.get("image_hidden_states", None)
237
+ # image_attention_mask = kwargs.get("image_attention_mask", None)
238
 
239
  return {
240
  "input_ids": input_ids,
 
245
  "token_type_ids": token_type_ids,
246
  "pixel_values": pixel_values,
247
  "image_hidden_states": image_hidden_states,
248
+ # "image_attention_mask": image_attention_mask,
249
  }
250
 
251
 
 
1371
  input_ids: torch.LongTensor = None,
1372
  inputs_embeds: Optional[torch.Tensor] = None,
1373
  image_hidden_states: Optional[torch.Tensor] = None,
 
1374
  ):
1375
  """
1376
  This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
 
1493
 
1494
  if self.config.use_resampler:
1495
  image_hidden_states = self.perceiver_resampler(image_hidden_states)
1496
+ elif image_hidden_states is not None:
1497
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
1498
 
1499
  if past_key_values is None:
1500
  # When we generate, we don't want to replace the potential image_token_id that we generated by images
 
1503
  input_ids=input_ids,
1504
  inputs_embeds=inputs_embeds,
1505
  image_hidden_states=image_hidden_states,
 
1506
  )
1507
  inputs_embeds = new_inp["inputs_embeds"]
1508