Spaces:
Running
Running
File size: 3,553 Bytes
d5345e2 cffabcf 1366c30 aa31199 cffabcf 58a320e cffabcf 1366c30 d5345e2 e64dbd8 aa31199 01973e8 e64dbd8 01973e8 e810c3c 01973e8 e64dbd8 d5345e2 e64dbd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import streamlit as st
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, CLIPProcessor
from medclip.modeling_hybrid_clip import FlaxHybridCLIP
@st.cache(allow_output_mutation=True)
def load_model():
    """Return the fine-tuned MedCLIP model together with a CLIP processor.

    The hybrid model comes from the ``flax-community/medclip-roco`` checkpoint
    and the processor from ``openai/clip-vit-base-patch32``, both via
    ``from_pretrained``. ``st.cache`` memoizes the result so Streamlit reruns
    reuse the loaded objects; ``allow_output_mutation=True`` tells Streamlit
    not to hash the returned objects to detect mutation.

    Returns:
        tuple: ``(model, processor)``.
    """
    hybrid_clip = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    loaded = (hybrid_clip, clip_processor)
    return loaded
@st.cache(allow_output_mutation=True)
def load_image_embeddings():
    """Load the precomputed image embeddings and the matching file names.

    Reads the pickled DataFrame at ``feature_store/image_embeddings_large.pkl``.
    Its ``image_embedding`` column is stacked into one NumPy array and its
    ``files`` column is turned into a NumPy array of file names. Both come
    from the same rows in the same order, so position ``i`` in one matches
    position ``i`` in the other. ``st.cache`` memoizes the result across reruns.

    NOTE(review): ``pd.read_pickle`` can execute arbitrary code while loading;
    only point it at a trusted, locally produced file.

    Returns:
        tuple: ``(image_files, image_embeds)``.
    """
    frame = pd.read_pickle('feature_store/image_embeddings_large.pkl')
    embeds = np.stack(frame['image_embedding'])
    files = np.asarray(frame['files'].tolist())
    return files, embeds
# Default result count (reassigned from the slider before searching) and the
# directory the matched image files are read from.
k, img_dir = 5, './images'

# ---- Sidebar: branding, description, example queries and result count ----
sidebar = st.sidebar
sidebar.header("MedCLIP")
sidebar.image("./assets/logo.png", width=100)
sidebar.empty()
sidebar.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
[Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
sidebar.markdown("Example queries:")
# One button per example query; st.button returns True only on the rerun
# triggered by its click.
ex1_button, ex2_button, ex3_button, ex4_button = (
    sidebar.button(label)
    for label in (
        "π pathology",
        "π ultrasound scans",
        "π pancreatic carcinoma",
        "π PET scan",
    )
)
k_slider = sidebar.slider("Number of images", min_value=1, max_value=10, value=5)
sidebar.markdown("Kaushalya Madhawa, 2021")
# Main-page heading; the logo, description and example queries are rendered
# in the sidebar. The emoji is U+1FA7A (stethoscope); the literal previously
# held its mis-decoded byte sequence.
st.title("MedCLIP 🩺")
# Pre-fill the query box from whichever example button was pressed on this
# rerun. Pairs are checked in order, so the first pressed button wins; with
# no button pressed the box starts empty.
text_value = next(
    (
        example
        for pressed, example in (
            (ex1_button, 'pathology'),
            (ex2_button, 'ultrasound scans'),
            (ex3_button, 'pancreatic carcinoma'),
            (ex4_button, 'PET scan'),
        )
        if pressed
    ),
    '',
)
image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()
query = st.text_input("Enter your query here:", value=text_value)
dot_prod = None
# Render the Search button unconditionally so it always appears. The previous
# condition `st.button("Search") or k_slider` was always true (the slider value
# is >= 1), so a fresh page load with an empty query always showed the warning.
search_clicked = st.button("Search")
if not query.strip():
    # Only complain when the user explicitly asked to search without a query.
    if search_clicked:
        st.write("Please enter a valid search query")
else:
    # A non-blank query is searched on every rerun (Search click, example
    # button or slider change), keeping the app's auto-search behaviour.
    with st.spinner(f"Searching ROCO test set for {query}..."):
        k = k_slider
        inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
        query_embedding = np.asarray(model.get_text_features(**inputs))
        # L2-normalise the text embedding along its last axis. This makes the
        # dot product a cosine similarity provided the stored image embeddings
        # are normalised too — TODO confirm.
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
        dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
        # argsort is ascending: reverse before slicing so the highest-scoring
        # image comes first. The old `argsort()[-k:]` picked the same k images
        # but displayed the best match last.
        topk_images = dot_prod.argsort()[::-1][:k]
        matching_images = image_list[topk_images]
        # Displayed value is 1 - dot product, so lower means a closer match.
        top_scores = 1. - dot_prod[topk_images]
        # Show images, best match first.
        for img_path, score in zip(matching_images, top_scores):
            img = plt.imread(os.path.join(img_dir, img_path))
            st.image(img, width=300)
            # st.write has no tooltip parameter; the former `help=` kwarg was
            # ignored, so it is not passed.
            st.write(f"{img_path} ({score:.2f})")
|