Spaces: Runtime error
Upload folder using huggingface_hub
- .gitattributes +0 -1
- .gitignore +2 -0
- README.md +4 -5
- app.py +32 -0
- requirements.txt +103 -0
- src/model/similarity_interface.py +3 -0
- src/model/simlarity_model.py +9 -0
- src/similarity/model_implements/bit.py +13 -0
- src/similarity/model_implements/mobilenet_v3.py +14 -0
- src/similarity/model_implements/vit_base.py +20 -0
- src/similarity/similarity.py +35 -0
- src/util/image.py +13 -0
- src/util/matrix.py +5 -0
.gitattributes
CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
+__pycache__/
+venv/
README.md
CHANGED
@@ -1,13 +1,12 @@
 ---
 title: Image Similarity
-emoji:
-colorFrom:
-colorTo:
+emoji: ๐จ
+colorFrom: red
+colorTo: green
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.16.0
 app_file: app.py
 pinned: false
-license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,32 @@
+import gradio as gr
+import os
+import random
+from src.model import simlarity_model as model
+from src.similarity.similarity import Similarity
+
+similarity = Similarity()
+models = similarity.get_models()
+
+def check(img_main, img_1, img_2, model_idx):
+    result = similarity.check_similarity([img_main, img_1, img_2], models[model_idx])
+    return result
+
+with gr.Blocks() as demo:
+    gr.Markdown('Checking Image Similarity')
+    img_main = gr.Text(label='Main Image', placeholder='https://myimage.jpg')
+
+    gr.Markdown('Images to check')
+    img_1 = gr.Text(label='1st Image', placeholder='https://myimage_1.jpg')
+    img_2 = gr.Text(label='2nd Image', placeholder='https://myimage_2.jpg')
+
+    gr.Markdown('Choose the model')
+    model = gr.Dropdown([m.name for m in models], label='Model', type='index')
+
+    gallery = gr.Gallery(
+        label="Generated images", show_label=False, elem_id="gallery"
+    ).style(grid=[2], height="auto")
+
+    submit_btn = gr.Button('Check Similarity')
+    submit_btn.click(fn=check, inputs=[img_main, img_1, img_2, model], outputs=gallery)
+
+demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,103 @@
+absl-py==1.3.0
+aiohttp==3.8.3
+aiosignal==1.3.1
+altair==4.2.0
+anyio==3.6.2
+astunparse==1.6.3
+async-timeout==4.0.2
+attrs==22.2.0
+cachetools==5.2.0
+certifi==2022.12.7
+charset-normalizer==2.1.1
+click==8.1.3
+contourpy==1.0.6
+cycler==0.11.0
+entrypoints==0.4
+fastapi==0.88.0
+ffmpy==0.3.0
+filelock==3.9.0
+flatbuffers==23.1.4
+fonttools==4.38.0
+frozenlist==1.3.3
+fsspec==2022.11.0
+gast==0.4.0
+google-auth==2.15.0
+google-auth-oauthlib==0.4.6
+google-pasta==0.2.0
+gradio==3.16.0
+grpcio==1.51.1
+h11==0.14.0
+h5py==3.7.0
+httpcore==0.16.3
+httpx==0.23.3
+huggingface-hub==0.11.1
+idna==3.4
+importlib-metadata==6.0.0
+importlib-resources==5.10.2
+Jinja2==3.1.2
+jsonschema==4.17.3
+keras==2.11.0
+kiwisolver==1.4.4
+libclang==14.0.6
+linkify-it-py==1.0.3
+Markdown==3.4.1
+markdown-it-py==2.1.0
+MarkupSafe==2.1.1
+matplotlib==3.6.2
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+multidict==6.0.4
+numpy==1.24.1
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+oauthlib==3.2.2
+opt-einsum==3.3.0
+orjson==3.8.4
+packaging==22.0
+pandas==1.5.2
+Pillow==9.4.0
+pkgutil_resolve_name==1.3.10
+protobuf==3.19.6
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycryptodome==3.16.0
+pydantic==1.10.4
+pydub==0.25.1
+pyparsing==3.0.9
+pyrsistent==0.19.3
+python-dateutil==2.8.2
+python-multipart==0.0.5
+pytz==2022.7
+PyYAML==6.0
+regex==2022.10.31
+requests==2.28.1
+requests-oauthlib==1.3.1
+rfc3986==1.5.0
+rsa==4.9
+six==1.16.0
+sniffio==1.3.0
+starlette==0.22.0
+tensorboard==2.11.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+tensorflow==2.11.0
+tensorflow-estimator==2.11.0
+tensorflow-hub==0.12.0
+tensorflow-io-gcs-filesystem==0.29.0
+termcolor==2.2.0
+tokenizers==0.13.2
+toolz==0.12.0
+torch==1.13.1
+tqdm==4.64.1
+transformers==4.25.1
+typing_extensions==4.4.0
+uc-micro-py==1.0.1
+urllib3==1.26.13
+uvicorn==0.20.0
+websockets==10.4
+Werkzeug==2.2.2
+wrapt==1.14.1
+yarl==1.8.2
+zipp==3.11.0
src/model/similarity_interface.py
ADDED
@@ -0,0 +1,3 @@
+class SimilarityInterface:
+    def extract_feature(img):
+        return []
src/model/simlarity_model.py
ADDED
@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+from .similarity_interface import SimilarityInterface
+
+@dataclass
+class SimilarityModel:
+    name: str
+    image_size: int
+    model_cls: SimilarityInterface
+    image_input_type: str = 'array'
src/similarity/model_implements/bit.py
ADDED
@@ -0,0 +1,13 @@
+import tensorflow_hub as hub
+import numpy as np
+
+class BigTransfer:
+
+    def __init__(self):
+        self.module = hub.KerasLayer("https://tfhub.dev/google/bit/m-r50x1/1")
+
+    def extract_feature(self, imgs):
+        features = []
+        for img in imgs:
+            features.append(np.squeeze(self.module(img)))
+        return features
src/similarity/model_implements/mobilenet_v3.py
ADDED
@@ -0,0 +1,14 @@
+import tensorflow_hub as hub
+import numpy as np
+
+class ModelnetV3():
+    def __init__(self):
+        module_handle = "https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/feature_vector/5"
+        self.module = hub.load(module_handle)
+
+    def extract_feature(self, imgs):
+        print('getting with ModelnetV3...')
+        features = []
+        for img in imgs:
+            features.append(np.squeeze(self.module(img)))
+        return features
src/similarity/model_implements/vit_base.py
ADDED
@@ -0,0 +1,20 @@
+from transformers import ViTFeatureExtractor, ViTModel
+from PIL import Image
+import numpy as np
+import torch
+
+class VitBase():
+
+    def __init__(self):
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
+        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+
+    def extract_feature(self, imgs):
+        features = []
+        for img in imgs:
+            inputs = self.feature_extractor(images=img, return_tensors="pt")
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+            last_hidden_states = outputs.last_hidden_state
+            features.append(np.squeeze(last_hidden_states.numpy()).flatten())
+        return features
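Note on the feature size (a sketch, assuming the pretrained weights can be downloaded): for a 224x224 input, google/vit-base-patch16-224-in21k returns a last_hidden_state of shape (1, 197, 768) — 196 patch tokens plus the [CLS] token, hidden size 768 — so the flattened vector produced above should have 197 * 768 = 151,296 elements. A quick check:

import torch
import numpy as np
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

img = Image.new('RGB', (224, 224))          # dummy image, for illustration only
inputs = extractor(images=img, return_tensors='pt')
with torch.no_grad():
    out = vit(**inputs)

print(out.last_hidden_state.shape)          # torch.Size([1, 197, 768])
print(np.squeeze(out.last_hidden_state.numpy()).flatten().shape)  # (151296,)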
src/similarity/similarity.py
ADDED
@@ -0,0 +1,35 @@
+from src.model import simlarity_model as model
+from src.util import image as image_util
+from src.util import matrix
+from .model_implements.mobilenet_v3 import ModelnetV3
+from .model_implements.vit_base import VitBase
+from .model_implements.bit import BigTransfer
+
+
+class Similarity:
+    def get_models(self):
+        return [
+            model.SimilarityModel(name= 'Mobilenet V3', image_size= 224, model_cls = ModelnetV3()),
+            model.SimilarityModel(name= 'Big Transfer (BiT)', image_size= 224, model_cls = BigTransfer()),
+            model.SimilarityModel(name= 'Vision Transformer', image_size= 224, model_cls = VitBase(), image_input_type='pil'),
+        ]
+
+    def check_similarity(self, img_urls, model):
+        imgs = []
+        for url in img_urls:
+            if url == "": continue
+            imgs.append(image_util.load_image_url(url, required_size=(model.image_size, model.image_size), image_type=model.image_input_type))
+
+        features = model.model_cls.extract_feature(imgs)
+        results = []
+        for i, v in enumerate(features):
+            if i == 0: continue
+            dist = matrix.cosine(features[0], v)
+            print(f'{i} -- distance: {dist}')
+            # results.append((imgs[i], f'similarity: {int(dist*100)}%'))
+            original_img = image_util.load_image_url(img_urls[i], required_size=None, image_type='pil')
+            results.append((original_img, f'similarity: {int(dist*100)}%'))
+
+        return results
+
+
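A minimal sketch of how check_similarity can be driven outside the Gradio UI, assuming the image URLs below (hypothetical placeholders) resolve and the TF Hub / Hugging Face models can be downloaded:

from src.similarity.similarity import Similarity

similarity = Similarity()
models = similarity.get_models()            # Mobilenet V3, Big Transfer (BiT), Vision Transformer

# Hypothetical URLs, for illustration only.
urls = [
    'https://example.com/main.jpg',         # reference image
    'https://example.com/candidate_1.jpg',
    'https://example.com/candidate_2.jpg',
]

# Compare the reference against the candidates with Mobilenet V3 (index 0).
results = similarity.check_similarity(urls, models[0])
for pil_img, caption in results:
    print(caption)                          # e.g. 'similarity: 87%'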
src/util/image.py
ADDED
@@ -0,0 +1,13 @@
+from PIL import Image
+import numpy as np
+import requests
+
+def load_image_url(url, required_size = (224,224), image_type = 'array'):
+    print(f'downloading.. {url}, type: {image_type}')
+    img = Image.open(requests.get(url, stream=True).raw)
+    img = Image.fromarray(np.array(img))
+    if required_size is not None:
+        img = img.resize(required_size)
+    if image_type == 'array':
+        img = (np.expand_dims(np.array(img), 0)/255).astype(np.float32)
+    return img
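For reference, a small sketch of what load_image_url returns for each image_type (the URL is a hypothetical placeholder and requires network access):

from src.util.image import load_image_url

# image_type='array': a normalized batch for the TF Hub models.
arr = load_image_url('https://example.com/cat.jpg', required_size=(224, 224), image_type='array')
print(arr.shape, arr.dtype)     # (1, 224, 224, 3) float32 for an RGB source, values scaled to [0, 1]

# image_type='pil': the PIL image itself, as used for the ViT model and the result gallery.
pil_img = load_image_url('https://example.com/cat.jpg', required_size=None, image_type='pil')
print(pil_img.size)             # original (width, height)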
src/util/matrix.py
ADDED
@@ -0,0 +1,5 @@
+from numpy.linalg import norm
+import numpy as np
+
+def cosine(x, y):
+    return np.dot(x,y)/(norm(x)*norm(y))
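cosine here is the standard cosine similarity, cos(x, y) = x·y / (‖x‖ ‖y‖): 1 for vectors pointing in the same direction, falling toward 0 (or -1) as they diverge; check_similarity reports it as a percentage. A small worked example:

import numpy as np
from src.util.matrix import cosine

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])   # same direction as a
c = np.array([3.0, 0.0, 0.0])

print(cosine(a, b))             # ≈ 1.0 (parallel vectors)
print(cosine(a, c))             # 3 / (sqrt(14) * 3) ≈ 0.267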