Adapt inputs_merger() Function for Scenarios Without Image Input

#5
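Summary: inputs_merger() previously read image_hidden_states.shape[1] and image_hidden_states.shape[2] unconditionally, so a purely textual batch that passes image_hidden_states=None failed before the embeddings could be merged. This change clones inputs_embeds first and wraps the entire image-merging loop in an "if image_hidden_states is not None:" guard, letting text-only inputs pass through unchanged.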
Files changed (1)
  1. modeling_vmistral.py +30 -28
modeling_vmistral.py CHANGED
@@ -1372,36 +1372,38 @@ class VMistralModel(VMistralPreTrainedModel):
         batch_size = input_ids.size(0)
 
         if inputs_embeds is not None:
-            vision_pipeline_output_seq_len = image_hidden_states.shape[1]
-            vision_hidden_size = image_hidden_states.shape[2]
             new_inputs_embeds = inputs_embeds.clone()
-            # Get the number of images for each example
-            num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
-            cum_num_images = num_images.cumsum(dim=-1)
-            for batch_idx in range(batch_size):
-                # Get the number of images for this particular example
-                example_num_images = num_images[batch_idx]
-                # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
-                start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
-                end = cum_num_images[batch_idx]
-                example_true_image_hidden_states = image_hidden_states[start:end]
-                if (
-                    new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
-                    != example_num_images * vision_pipeline_output_seq_len
-                ):
-                    raise ValueError(
-                        "new_inputs_embeds to replace has shape[0]:"
-                        f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
-                        " should have shape[0]:"
-                        f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
-                    )
-                # Insert the image_hidden_states
-                new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
-                    example_true_image_hidden_states.view(
-                        example_num_images * vision_pipeline_output_seq_len,
-                        vision_hidden_size,
+
+            if image_hidden_states is not None:
+                vision_pipeline_output_seq_len = image_hidden_states.shape[1]
+                vision_hidden_size = image_hidden_states.shape[2]
+                # Get the number of images for each example
+                num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
+                cum_num_images = num_images.cumsum(dim=-1)
+                for batch_idx in range(batch_size):
+                    # Get the number of images for this particular example
+                    example_num_images = num_images[batch_idx]
+                    # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
+                    start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
+                    end = cum_num_images[batch_idx]
+                    example_true_image_hidden_states = image_hidden_states[start:end]
+                    if (
+                        new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
+                        != example_num_images * vision_pipeline_output_seq_len
+                    ):
+                        raise ValueError(
+                            "new_inputs_embeds to replace has shape[0]:"
+                            f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
+                            " should have shape[0]:"
+                            f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
+                        )
+                    # Insert the image_hidden_states
+                    new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
+                        example_true_image_hidden_states.view(
+                            example_num_images * vision_pipeline_output_seq_len,
+                            vision_hidden_size,
+                        )
                     )
-                )
 
         return_dict = {}
         if inputs_embeds is not None:
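
For context, here is a minimal standalone sketch of the patched logic. It is not the actual VMistralModel method: the token id, image sequence length, and hidden size are made-up stand-ins for the model's real configuration, and the shape-check/ValueError branch is omitted for brevity.

import torch

# All constants are illustrative assumptions, not the model's real values.
IMAGE_TOKEN_ID = 32001   # assumed id of the <image> placeholder token
IMAGE_SEQ_LEN = 4        # assumed number of placeholder tokens per image
HIDDEN_SIZE = 8          # assumed embedding width


def merge_inputs(input_ids, inputs_embeds, image_hidden_states):
    # Sketch of the patched inputs_merger() logic: replace the <image>
    # placeholder embeddings with vision features. With the guard added in
    # this PR, image_hidden_states may be None (text-only batch), in which
    # case the embeddings are returned unchanged.
    new_inputs_embeds = inputs_embeds.clone()
    if image_hidden_states is not None:
        vision_seq_len = image_hidden_states.shape[1]
        vision_hidden_size = image_hidden_states.shape[2]
        # Images per example, inferred from the placeholder-token count
        num_images = (input_ids == IMAGE_TOKEN_ID).sum(dim=-1) // IMAGE_SEQ_LEN
        cum_num_images = num_images.cumsum(dim=-1)
        for batch_idx in range(input_ids.size(0)):
            # Slice out this example's real images (drops any padding images)
            start = 0 if batch_idx == 0 else int(cum_num_images[batch_idx - 1])
            end = int(cum_num_images[batch_idx])
            example_states = image_hidden_states[start:end]
            mask = input_ids[batch_idx] == IMAGE_TOKEN_ID
            new_inputs_embeds[batch_idx][mask] = example_states.view(
                int(num_images[batch_idx]) * vision_seq_len, vision_hidden_size
            )
    return new_inputs_embeds


# Text-only batch: this used to dereference None; now it is a no-op.
ids = torch.randint(0, 100, (2, 6))
embeds = torch.randn(2, 6, HIDDEN_SIZE)
assert torch.equal(merge_inputs(ids, embeds, None), embeds)

# Batch with one image: the placeholder positions receive the vision features.
ids_img = torch.tensor([[IMAGE_TOKEN_ID] * IMAGE_SEQ_LEN + [5, 6]])
embeds_img = torch.randn(1, 6, HIDDEN_SIZE)
vision = torch.randn(1, IMAGE_SEQ_LEN, HIDDEN_SIZE)
merged = merge_inputs(ids_img, embeds_img, vision)
assert torch.equal(merged[0, :IMAGE_SEQ_LEN], vision[0])

Gating on image_hidden_states being None rather than on the presence of image tokens keeps the text-only path down to a single clone(), while the image path, including the real method's shape check, is unchanged apart from one extra indentation level.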