damerajee commited on
Commit
5595e84
1 Parent(s): d6792f7

Update vision_encoder.py

Browse files
Files changed (1) hide show
  1. vision_encoder.py +8 -7
vision_encoder.py CHANGED
@@ -1,9 +1,7 @@
1
- from transformers import ViTModel
2
  from torchvision import transforms
3
- import torch
4
- import torch.nn as nn
5
- import transformers
6
-
7
 
8
  transformers.logging.set_verbosity_error()
9
 
@@ -17,9 +15,12 @@ class VisionEncoder(nn.Module):
17
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
- def forward(self, images,device):
 
 
 
21
  processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
22
  with torch.no_grad():
23
  pixel_values = self.vision_model(processed_images)
24
  image_features = pixel_values.last_hidden_state
25
- return image_features
 
1
+ from transformers import ViTModel
2
  from torchvision import transforms
3
+ import torch
4
+ import torch.nn as nn
 
 
5
 
6
  transformers.logging.set_verbosity_error()
7
 
 
15
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
16
  ])
17
 
18
+ def forward(self, images, device):
19
+ if not isinstance(images, list):
20
+ images = [images]
21
+
22
  processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
23
  with torch.no_grad():
24
  pixel_values = self.vision_model(processed_images)
25
  image_features = pixel_values.last_hidden_state
26
+ return image_features