qnguyen3 commited on
Commit
17d73ee
1 Parent(s): 40f0486

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +5 -1
modeling_llava_qwen2.py CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
12
  import torch.utils.checkpoint
13
  from torch import nn
14
  import torch
 
15
  from transformers.image_processing_utils import BatchFeature, get_size_dict
16
  from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
17
  from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
@@ -534,6 +535,7 @@ class SigLipVisionTower(nn.Module):
534
  self.is_loaded = True
535
 
536
  @torch.no_grad()
 
537
  def forward(self, images):
538
  if type(images) is list:
539
  image_features = []
@@ -659,11 +661,13 @@ class LlavaMetaForCausalLM(ABC):
659
  def get_vision_tower(self):
660
  return self.get_model().get_vision_tower()
661
 
 
662
  def encode_images(self, images):
663
  image_features = self.get_model().get_vision_tower()(images)
664
  image_features = self.get_model().mm_projector(image_features)
665
  return image_features
666
-
 
667
  def prepare_inputs_labels_for_multimodal(
668
  self, input_ids, position_ids, attention_mask, past_key_values, labels, images
669
  ):
 
12
  import torch.utils.checkpoint
13
  from torch import nn
14
  import torch
15
+ import spaces
16
  from transformers.image_processing_utils import BatchFeature, get_size_dict
17
  from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
18
  from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
 
535
  self.is_loaded = True
536
 
537
  @torch.no_grad()
538
+ @spaces.GPU
539
  def forward(self, images):
540
  if type(images) is list:
541
  image_features = []
 
661
  def get_vision_tower(self):
662
  return self.get_model().get_vision_tower()
663
 
664
+ @spaces.GPU
665
  def encode_images(self, images):
666
  image_features = self.get_model().get_vision_tower()(images)
667
  image_features = self.get_model().mm_projector(image_features)
668
  return image_features
669
+
670
+ @spaces.GPU
671
  def prepare_inputs_labels_for_multimodal(
672
  self, input_ids, position_ids, attention_mask, past_key_values, labels, images
673
  ):