Spaces:
Running
Running
Annas Dev
commited on
Commit
•
868f784
1
Parent(s):
92c1964
add bit model
Browse files
src/similarity/model_implements/bit.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow_hub as hub
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class BigTransfer:
|
5 |
+
|
6 |
+
def __init__(self):
|
7 |
+
self.module = hub.KerasLayer("https://tfhub.dev/google/bit/m-r50x1/1")
|
8 |
+
|
9 |
+
def extract_feature(self, imgs):
|
10 |
+
features = []
|
11 |
+
for img in imgs:
|
12 |
+
features.append(np.squeeze(self.module(img)))
|
13 |
+
return features
|
src/similarity/similarity.py
CHANGED
@@ -3,17 +3,18 @@ from src.util import image as image_util
|
|
3 |
from src.util import matrix
|
4 |
from .model_implements.mobilenet_v3 import ModelnetV3
|
5 |
from .model_implements.vit_base import VitBase
|
|
|
6 |
|
7 |
|
8 |
class Similarity:
|
9 |
def get_models(self):
|
10 |
return [
|
11 |
model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
|
|
|
12 |
model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase(), image_input_type='pil'),
|
13 |
]
|
14 |
|
15 |
def check_similarity(self, img_urls, model):
|
16 |
-
# model = self.get_models()[model_idx]
|
17 |
imgs = []
|
18 |
for url in img_urls:
|
19 |
if url == "": continue
|
|
|
3 |
from src.util import matrix
|
4 |
from .model_implements.mobilenet_v3 import ModelnetV3
|
5 |
from .model_implements.vit_base import VitBase
|
6 |
+
from .model_implements.bit import BigTransfer
|
7 |
|
8 |
|
9 |
class Similarity:
|
10 |
def get_models(self):
|
11 |
return [
|
12 |
model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
|
13 |
+
model.SimilarityModel(name= 'Big Transfer (BiT)', image_size= 224, model_cls = BigTransfer()),
|
14 |
model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase(), image_input_type='pil'),
|
15 |
]
|
16 |
|
17 |
def check_similarity(self, img_urls, model):
|
|
|
18 |
imgs = []
|
19 |
for url in img_urls:
|
20 |
if url == "": continue
|