File size: 3,622 Bytes
1366c30
a269b46
d27f40c
1366c30
d27f40c
 
 
a269b46
d27f40c
cffabcf
 
0f2db82
a269b46
cffabcf
a269b46
 
 
cffabcf
a269b46
cffabcf
0f2db82
cffabcf
 
 
 
 
1366c30
d5345e2
e64dbd8
 
 
 
aa31199
01973e8
 
 
 
 
 
 
 
 
 
e64dbd8
 
 
 
 
 
 
 
 
 
 
 
01973e8
 
 
 
 
 
 
 
 
 
e810c3c
 
a269b46
e810c3c
01973e8
e64dbd8
 
0f2db82
 
 
e64dbd8
 
 
 
 
 
a269b46
 
e64dbd8
 
 
 
 
 
 
 
 
 
 
a269b46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import token

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from transformers import CLIPProcessor, AutoTokenizer

from medclip.modeling_hybrid_clip import FlaxHybridCLIP


@st.cache_resource
def load_model():
    """Load the MedCLIP model and the text tokenizer used to encode queries.

    The result is cached with ``st.cache_resource`` so the heavy load happens
    once rather than on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``FlaxHybridCLIP`` checkpoint
        ``flax-community/medclip-roco`` and the
        ``allenai/scibert_scivocab_uncased`` tokenizer.
    """
    clip_checkpoint = "flax-community/medclip-roco"
    tokenizer_checkpoint = 'allenai/scibert_scivocab_uncased'

    clip_model = FlaxHybridCLIP.from_pretrained(clip_checkpoint, _do_init=True)
    text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
    return clip_model, text_tokenizer

@st.cache_resource
def load_image_embeddings():
    """Load the precomputed image embedding table from the local feature store.

    Reads key ``'emb'`` of ``feature_store/image_embeddings_large.hdf``. The
    result is cached with ``st.cache_resource`` so it is read only once.

    Returns:
        A ``(image_files, image_embeds)`` tuple: a NumPy array of the
        ``files`` column, and the ``image_embedding`` column stacked into one
        array (presumably ``(n_images, dim)`` — TODO confirm against the
        feature-store writer).
    """
    table = pd.read_hdf('feature_store/image_embeddings_large.hdf', key='emb')
    embedding_matrix = np.stack(table['image_embedding'])
    file_names = np.asarray(table['files'].tolist())
    return file_names, embedding_matrix

# Default result count; the search section overwrites it with the slider value.
k = 5
# Directory that image file names from the feature store are joined onto.
img_dir = './images'

# One-click example queries shown as sidebar buttons, in display order.
EXAMPLE_QUERIES = ('pathology', 'ultrasound scans', 'pancreatic carcinoma', 'PET scan')

st.sidebar.header("MedCLIP")
st.sidebar.image("./assets/logo.png", width=100)
st.sidebar.empty()
st.sidebar.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
 [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
st.sidebar.markdown("Example queries:")
# Label each button with the real magnifier emoji (U+1F50D); the old labels
# carried its mis-encoded form ("πŸ”"), which rendered as garbage.
ex1_button, ex2_button, ex3_button, ex4_button = [
    st.sidebar.button(f"🔍 {example}") for example in EXAMPLE_QUERIES
]

k_slider = st.sidebar.slider("Number of images", min_value=1, max_value=10, value=5)
st.sidebar.markdown("Kaushalya Madhawa, 2021")

st.title("MedCLIP 🩺")

# Pre-fill the query with the first example whose button reports a click
# (same precedence as the buttons' display order); '' when none does.
text_value = next(
    (
        example
        for clicked, example in zip(
            (ex1_button, ex2_button, ex3_button, ex4_button), EXAMPLE_QUERIES
        )
        if clicked
    ),
    '',
)


image_list, image_embeddings = load_image_embeddings()
model, tokenizer = load_model()

query = st.text_input("Enter your query here:", value=text_value)
dot_prod = None

# An empty text box falls back to the example chosen in the sidebar (if any).
if not query:
    query = text_value

# Run the search whenever there is a query, so example buttons and slider
# changes still show results without pressing "Search". Only warn about an
# empty query when "Search" was actually pressed. The former condition
# `st.button("Search") or k_slider` was always true (the slider's minimum is
# 1), so the warning appeared on every load.
search_clicked = st.button("Search")
if query:
    with st.spinner(f"Searching ROCO test set for {query}..."):
        k = k_slider
        inputs = tokenizer(text=[query], return_tensors="jax", padding=True)
        query_embedding = np.asarray(model.get_text_features(**inputs))
        # L2-normalise the text embedding. The dot product below is a cosine
        # similarity only if the stored image embeddings are unit-norm too
        # (NOTE(review): confirm the feature store saves normalised vectors).
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
        dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
        # argsort is ascending: take the k largest, then reverse so the
        # closest match is displayed first.
        topk_images = dot_prod.argsort()[-k:][::-1]
        matching_images = image_list[topk_images]
        # Displayed value is 1 - dot product, so smaller means a closer match.
        top_scores = 1. - dot_prod[topk_images]
        for img_path, score in zip(matching_images, top_scores):
            img = plt.imread(os.path.join(img_dir, img_path))
            st.image(img, width=300)
            st.write(f"{img_path} ({score:.2f})")
elif search_clicked:
    st.write("Please enter a valid search query")