NarendraMall commited on
Commit
099bc26
1 Parent(s): 9e86261

updated vit

Browse files
src/similarity/model_implements/vit_base.py CHANGED
@@ -12,7 +12,12 @@ class VitBase():
12
  def extract_feature(self, imgs):
13
  features = []
14
  for img in imgs:
 
 
 
 
15
  inputs = self.feature_extractor(images=img, return_tensors="pt")
 
16
  with torch.no_grad():
17
  outputs = self.model(**inputs)
18
  last_hidden_states = outputs.last_hidden_state
 
12
  def extract_feature(self, imgs):
13
  features = []
14
  for img in imgs:
15
+ # Convert the image to RGB if it has 4 channels
16
+ if img.mode == 'RGBA':
17
+ img = img.convert('RGB')
18
+
19
  inputs = self.feature_extractor(images=img, return_tensors="pt")
20
+ # print("input shape: ", inputs.shape)
21
  with torch.no_grad():
22
  outputs = self.model(**inputs)
23
  last_hidden_states = outputs.last_hidden_state