File size: 3,553 Bytes
d5345e2
cffabcf
 
1366c30
 
aa31199
cffabcf
 
 
 
 
 
 
 
 
 
58a320e
cffabcf
 
 
 
 
1366c30
d5345e2
e64dbd8
 
 
 
aa31199
01973e8
 
 
 
 
 
 
 
 
 
e64dbd8
 
 
 
 
 
 
 
 
 
 
 
01973e8
 
 
 
 
 
 
 
 
 
e810c3c
 
 
 
01973e8
e64dbd8
 
 
 
 
 
 
 
 
d5345e2
e64dbd8
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, CLIPProcessor
from medclip.modeling_hybrid_clip import FlaxHybridCLIP

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned MedCLIP model together with the CLIP processor.

    Wrapped in ``st.cache`` so the pretrained weights are fetched once and
    reused across Streamlit reruns instead of being reloaded on every
    interaction.

    Returns:
        A ``(model, processor)`` tuple: the ``FlaxHybridCLIP`` model loaded
        from ``flax-community/medclip-roco`` and the ``CLIPProcessor`` from
        ``openai/clip-vit-base-patch32`` used to tokenize text queries.
    """
    medclip_model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return (medclip_model, clip_processor)

@st.cache(allow_output_mutation=True)
def load_image_embeddings():
    """Load the precomputed image embeddings and their file names.

    Cached with ``st.cache`` so the pickle is read only once per session.

    Returns:
        A ``(image_files, image_embeds)`` tuple. ``image_files`` is a NumPy
        array of the ``files`` column. ``image_embeds`` stacks the
        ``image_embedding`` column into one array, with one row per file in
        the same order.
    """
    # NOTE: read_pickle unpickles arbitrary objects; only point this at a
    # trusted, locally generated feature store.
    feature_store = pd.read_pickle('feature_store/image_embeddings_large.pkl')
    embedding_matrix = np.stack(feature_store['image_embedding'])
    file_names = np.asarray(feature_store['files'].tolist())
    return (file_names, embedding_matrix)

# Default number of results to show; also used as the slider's initial value.
k = 5
# Directory that the stored image file names are joined onto when displaying results.
img_dir = './images'

# --- Sidebar: branding, description, example queries and result count ---
st.sidebar.header("MedCLIP")
st.sidebar.image("./assets/logo.png", width=100)
st.sidebar.empty()
st.sidebar.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
 [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
st.sidebar.markdown("Example queries:")
# One button per example query. A click is read below to pre-fill the query box.
# The labels previously contained mis-encoded bytes ("πŸ”") instead of 🔍.
ex1_button = st.sidebar.button("🔍 pathology")
ex2_button = st.sidebar.button("🔍 ultrasound scans")
ex3_button = st.sidebar.button("🔍 pancreatic carcinoma")
ex4_button = st.sidebar.button("🔍 PET scan")

k_slider = st.sidebar.slider("Number of images", min_value=1, max_value=10, value=k)
st.sidebar.markdown("Kaushalya Madhawa, 2021")

st.title("MedCLIP 🩺")

# Pre-fill the query box with the text of whichever sidebar example button
# was clicked on this rerun. Buttons are checked in sidebar order and the
# first clicked one wins. With no click, the box starts empty.
text_value = next(
    (
        example_text
        for was_clicked, example_text in (
            (ex1_button, 'pathology'),
            (ex2_button, 'ultrasound scans'),
            (ex3_button, 'pancreatic carcinoma'),
            (ex4_button, 'PET scan'),
        )
        if was_clicked
    ),
    '',
)


image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()

query = st.text_input("Enter your query here:", value=text_value)
dot_prod = None

# Always render the Search button so it keeps its place in the layout.
# The previous condition `st.button("Search") or k_slider` was always true
# because the slider minimum is 1. That made the button a no-op and showed
# the "empty query" message on every page load. Now any non-blank query is
# searched on every rerun (e.g. when the slider moves). The message appears
# only when Search is clicked with a blank query.
search_clicked = st.button("Search")
search_text = query.strip()

if search_text:
    with st.spinner(f"Searching ROCO test set for {search_text}..."):
        k = k_slider
        inputs = processor(text=[search_text], images=None, return_tensors="jax", padding=True)

        # Single text in, so presumably a single feature row out. L2-normalise
        # it so the dot product below acts as a cosine similarity. This assumes
        # the stored image embeddings are already unit-norm (TODO confirm
        # against the feature store).
        query_embedding = np.asarray(model.get_text_features(**inputs))
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
        # Matrix-vector product: one similarity per stored image, with no
        # N x d temporary.
        dot_prod = image_embeddings @ query_embedding.ravel()

        # argsort is ascending, so take the last k and reverse them so the
        # best match is shown first.
        topk_images = dot_prod.argsort()[-k:][::-1]
        matching_images = image_list[topk_images]
        # Displayed value is 1 - similarity: lower means a closer match.
        top_scores = 1. - dot_prod[topk_images]
        for img_path, score in zip(matching_images, top_scores):
            img = plt.imread(os.path.join(img_dir, img_path))
            st.image(img, width=300)
            # NOTE(review): `help` does not appear to be an st.write parameter
            # in the Streamlit versions I know; confirm it is not silently ignored.
            st.write(f"{img_path} ({score:.2f})", help="score")
elif search_clicked:
    st.write("Please enter a valid search query")