Spaces:
Running
Running
Show matching images
Browse files- app.py +11 -2
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
|
|
|
|
4 |
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
|
5 |
from medclip.modeling_hybrid_clip import FlaxHybridCLIP
|
6 |
|
@@ -17,10 +19,10 @@ def load_image_embeddings():
|
|
17 |
image_files = np.asarray(embeddings_df['files'].tolist())
|
18 |
return image_files, image_embeds
|
19 |
|
20 |
-
# def app():
|
21 |
k = 5
|
22 |
image_list, image_embeddings = load_image_embeddings()
|
23 |
model, processor = load_model()
|
|
|
24 |
|
25 |
query = st.text_input("Search:")
|
26 |
|
@@ -34,5 +36,12 @@ if st.button("Search"):
|
|
34 |
query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
|
35 |
dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
|
36 |
matching_images = image_list[dot_prod.argsort()[-k:]]
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
+
import os
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
|
7 |
from medclip.modeling_hybrid_clip import FlaxHybridCLIP
|
8 |
|
|
|
19 |
image_files = np.asarray(embeddings_df['files'].tolist())
|
20 |
return image_files, image_embeds
|
21 |
|
|
|
22 |
k = 5
|
23 |
image_list, image_embeddings = load_image_embeddings()
|
24 |
model, processor = load_model()
|
25 |
+
img_dir = './images'
|
26 |
|
27 |
query = st.text_input("Search:")
|
28 |
|
|
|
36 |
query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
|
37 |
dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
|
38 |
matching_images = image_list[dot_prod.argsort()[-k:]]
|
39 |
+
|
40 |
+
# st.write(f"matching images: {matching_images}")
|
41 |
+
#show images
|
42 |
+
|
43 |
+
for img_path in matching_images:
|
44 |
+
img = plt.imread(os.path.join(img_dir, img_path))
|
45 |
+
st.write(img_path)
|
46 |
+
st.image(img)
|
47 |
|
requirements.txt
CHANGED
@@ -5,5 +5,6 @@ streamlit==0.84.1
|
|
5 |
torch==1.9.0
|
6 |
torchvision==0.10.0
|
7 |
pandas==1.3.0
|
|
|
8 |
transformers==4.8.2
|
9 |
watchdog==2.1.3
|
|
|
5 |
torch==1.9.0
|
6 |
torchvision==0.10.0
|
7 |
pandas==1.3.0
|
8 |
+
matplotlib>=3.4.0
|
9 |
transformers==4.8.2
|
10 |
watchdog==2.1.3
|