qnguyen3 committed
Commit 756d354
Parent: 1899e83

Update modeling_llava_qwen2.py

Files changed (1): modeling_llava_qwen2.py +4 -4
modeling_llava_qwen2.py CHANGED

@@ -535,13 +535,13 @@ class SigLipVisionTower(nn.Module):
         if type(images) is list:
             image_features = []
             for image in images:
-                image_forward_out = self.vision_tower(image.unsqueeze(0),
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                       output_hidden_states=True)
                 image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                 assert image_feature.shape[-2] == 729
                 image_features.append(image_feature)
         else:
-            image_forward_outs = self.vision_tower(images,
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                                    output_hidden_states=True)
             image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
             assert image_features.shape[-2] == 729
@@ -682,9 +682,9 @@ class LlavaMetaForCausalLM(ABC):
             image_features = self.encode_images(concat_images)
             split_sizes = [image.shape[0] for image in images]
             image_features = torch.split(image_features, split_sizes, dim=0)
-            image_features = [x.flatten(0, 1) for x in image_features]
+            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
         else:
-            image_features = self.encode_images(images)
+            image_features = self.encode_images(images).to(self.device)
 
         # Let's just add dummy tensors if they do not exist,
         # it is a headache to deal with None all the time.
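The first hunk moves each incoming image tensor onto the vision tower's device and casts it to the tower's dtype before the forward pass, so CPU float32 pixel batches no longer trip a device- or dtype-mismatch RuntimeError when the tower runs on GPU in half precision; the hidden states are then cast back to the input's dtype. A minimal sketch of that pattern, assuming a generic wrapper exposing device/dtype properties (TowerWrapper is an illustrative name, not part of this file):

import torch
import torch.nn as nn

class TowerWrapper(nn.Module):
    """Hypothetical stand-in for SigLipVisionTower, reduced to the
    device/dtype alignment this commit adds."""

    def __init__(self, tower: nn.Module):
        super().__init__()
        self.tower = tower

    @property
    def device(self) -> torch.device:
        # Derive the device from the wrapped module's parameters.
        return next(self.tower.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        # Derive the dtype the same way.
        return next(self.tower.parameters()).dtype

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Align the input with the tower before calling it, then cast the
        # output back to the caller's dtype (mirroring the .to(image.dtype)
        # applied to hidden_states in the real code).
        in_dtype = images.dtype
        out = self.tower(images.to(device=self.device, dtype=self.dtype))
        return out.to(in_dtype)

# Usage: a float32 input passes through a float64 tower without the
# "expected scalar type Double but found Float" error the raw call would raise.
tower = TowerWrapper(nn.Linear(16, 8).to(torch.float64))
features = tower(torch.randn(2, 16))
assert features.dtype == torch.float32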
 
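The second hunk handles the reverse direction: the encoded image features may sit on a different device than the language model (for example when the checkpoint is sharded across GPUs with device_map="auto"), so after splitting and flattening they are explicitly moved onto self.device before being merged with the text embeddings. A small sketch of that split/flatten/move step under assumed shapes (gather_image_features and the feature width of 1152 are illustrative names and values, not from the source; the 729-patch count comes from the asserts above):

import torch

def gather_image_features(batched: torch.Tensor,
                          split_sizes: list[int],
                          target_device: torch.device) -> list[torch.Tensor]:
    # Split the concatenated per-crop features back into per-image groups,
    # merge each group's crop and patch dimensions, and move the result to
    # the device where the language model will consume it.
    chunks = torch.split(batched, split_sizes, dim=0)
    return [x.flatten(0, 1).to(target_device) for x in chunks]

# Usage: two images contributing 3 and 2 crops, 729 patches each.
feats = torch.randn(5, 729, 1152)
out = gather_image_features(feats, [3, 2], torch.device("cpu"))
print([t.shape for t in out])  # [torch.Size([2187, 1152]), torch.Size([1458, 1152])]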