kaushalya commited on
Commit
e64dbd8
β€’
1 Parent(s): 62a622d

Add sidebar

Browse files
Files changed (1) hide show
  1. app.py +42 -25
app.py CHANGED
@@ -22,35 +22,52 @@ def load_image_embeddings():
22
  k = 5
23
  img_dir = './images'
24
 
25
- st.title("MedCLIP 🩺")
26
- st.image("./assets/logo.png", width=100)
27
- st.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
 
28
  [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
29
- st.markdown("""Example queries:
30
- * `ultrasound scans`
31
- * `pathology`
32
- * `pancreatic carcinoma`
33
- * `PET scan`""")
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  image_list, image_embeddings = load_image_embeddings()
36
  model, processor = load_model()
37
 
38
  query = st.text_input("Enter your query here:")
 
 
 
 
 
 
 
 
 
39
 
40
- if st.button("Search"):
41
- with st.spinner(f"Searching ROCO test set for {query}..."):
42
- inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
43
-
44
- query_embedding = model.get_text_features(**inputs)
45
- query_embedding = np.asarray(query_embedding)
46
- query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
47
- dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
48
- topk_images = dot_prod.argsort()[-k:]
49
- matching_images = image_list[topk_images]
50
- top_scores = 1. - dot_prod[topk_images]
51
- #show images
52
- for img_path, score in zip(matching_images, top_scores):
53
- img = plt.imread(os.path.join(img_dir, img_path))
54
- st.image(img, width=300)
55
- st.write(f"{img_path} ({score:.2f})", help="score")
56
-
 
22
  k = 5
23
  img_dir = './images'
24
 
25
+ st.sidebar.header("MedCLIP")
26
+ st.sidebar.image("./assets/logo.png", width=100)
27
+ st.sidebar.empty()
28
+ st.sidebar.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
29
  [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
30
+ st.sidebar.markdown("""Example queries:
31
+ * `ultrasound scans`πŸ”
32
+ * `pathology`πŸ”
33
+ * `pancreatic carcinoma`πŸ”
34
+ * `PET scan`πŸ”""")
35
+ k_slider = st.sidebar.slider("Number of images", min_value=1, max_value=10, value=5)
36
+ st.sidebar.markdown("Kaushalya Madhawa, 2021")
37
+
38
+ st.title("MedCLIP 🩺")
39
+ # st.image("./assets/logo.png", width=100)
40
+ # st.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
41
+ # [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
42
+ # st.markdown("""Example queries:
43
+ # * `ultrasound scans`πŸ”
44
+ # * `pathology`πŸ”
45
+ # * `pancreatic carcinoma`πŸ”
46
+ # * `PET scan`πŸ”""")
47
 
48
  image_list, image_embeddings = load_image_embeddings()
49
  model, processor = load_model()
50
 
51
  query = st.text_input("Enter your query here:")
52
+ dot_prod = None
53
+
54
+ if st.button("Search") or k_slider:
55
+ if len(query)==0:
56
+ st.write("Please enter a valid search query")
57
+ else:
58
+ with st.spinner(f"Searching ROCO test set for {query}..."):
59
+ k = k_slider
60
+ inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
61
 
62
+ query_embedding = model.get_text_features(**inputs)
63
+ query_embedding = np.asarray(query_embedding)
64
+ query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
65
+ dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
66
+ topk_images = dot_prod.argsort()[-k:]
67
+ matching_images = image_list[topk_images]
68
+ top_scores = 1. - dot_prod[topk_images]
69
+ #show images
70
+ for img_path, score in zip(matching_images, top_scores):
71
+ img = plt.imread(os.path.join(img_dir, img_path))
72
+ st.image(img, width=300)
73
+ st.write(f"{img_path} ({score:.2f})", help="score")