Annas Dev commited on
Commit
8f93744
1 Parent(s): 4cb8afb
src/similarity/model_implements/vit_base.py CHANGED
@@ -1,16 +1,21 @@
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from PIL import Image
 
3
 
4
  class VitBase():
5
 
6
  def __init__(self):
7
- self.model = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
 
8
 
9
  def extract_feature(self, imgs):
10
  features = []
11
  for img in imgs:
12
- feature = self.model(images=img, return_tensors="np")
13
- print('type::', type(**feature))
14
- features.append(feature)
15
- print(features[0].shape)
 
 
 
16
  return features
 
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from PIL import Image
3
+ import numpy as np
4
 
5
  class VitBase():
6
 
7
  def __init__(self):
8
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
9
+ self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
10
 
11
  def extract_feature(self, imgs):
12
  features = []
13
  for img in imgs:
14
+ feature = self.feature_extractor(images=img, return_tensors="tf")
15
+ print('keys: ', feature.keys())
16
+ f = self.model(feature)
17
+ print('--> f', type(f))
18
+ # print('type::', (feature['pixel_values'].shape))
19
+ features.append(np.squeeze(feature['pixel_values']))
20
+ print('shape:::',features[0].shape)
21
  return features
src/similarity/similarity.py CHANGED
@@ -23,7 +23,7 @@ class Similarity:
23
  for i, v in enumerate(features):
24
  if i == 0: continue
25
  dist = matrix.cosine(features[0], v)
26
- print(f'distance: {dist}')
27
 
28
  return 'oke'
29
 
 
23
  for i, v in enumerate(features):
24
  if i == 0: continue
25
  dist = matrix.cosine(features[0], v)
26
+ # print(f'distance: {dist}')
27
 
28
  return 'oke'
29