qnguyen3 commited on
Commit
85470c7
1 Parent(s): 07368a8

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +3 -0
modeling_llava_qwen2.py CHANGED
@@ -11,6 +11,7 @@ from functools import partial, reduce
11
  from PIL import Image
12
  import torch.utils.checkpoint
13
  from torch import nn
 
14
  from transformers.image_processing_utils import BatchFeature, get_size_dict
15
  from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
16
  from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
@@ -18,6 +19,8 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
18
  from transformers.modeling_utils import PreTrainedModel
19
  from transformers.utils import ModelOutput
20
 
 
 
21
 
22
  class SigLipImageProcessor:
23
  def __init__(self,
 
11
  from PIL import Image
12
  import torch.utils.checkpoint
13
  from torch import nn
14
+ import torch
15
  from transformers.image_processing_utils import BatchFeature, get_size_dict
16
  from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
17
  from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
 
19
  from transformers.modeling_utils import PreTrainedModel
20
  from transformers.utils import ModelOutput
21
 
22
+ torch.set_default_device('cuda')
23
+
24
 
25
  class SigLipImageProcessor:
26
  def __init__(self,