medclip-roco / app.py
import streamlit as st
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, CLIPProcessor
from medclip.modeling_hybrid_clip import FlaxHybridCLIP


@st.cache(allow_output_mutation=True)
def load_model():
    # Cache the model and processor across Streamlit reruns;
    # allow_output_mutation skips hashing the large, mutable model object.
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor
@st.cache(allow_output_mutation=True)
def load_image_embeddings():
    # The feature store is a pickled DataFrame with an 'image_embedding'
    # column (one vector per image) and a 'files' column of file names.
    embeddings_df = pd.read_pickle('feature_store/image_embeddings.pkl')
    image_embeds = np.stack(embeddings_df['image_embedding'])
    image_files = np.asarray(embeddings_df['files'].tolist())
    return image_files, image_embeds
k = 5  # number of top matches to display
image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()
img_dir = './images'
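# A minimal sketch (not part of the original app) of how a feature store like
# 'feature_store/image_embeddings.pkl' could be built with the same model.
# The function name and the PIL-based loading are assumptions; it only needs
# to produce the 'files' / 'image_embedding' columns that
# load_image_embeddings() reads above. Defined for documentation, never called.
def build_image_embeddings(image_dir, out_path='feature_store/image_embeddings.pkl'):
    from PIL import Image
    records = []
    for fname in sorted(os.listdir(image_dir)):
        image = Image.open(os.path.join(image_dir, fname)).convert('RGB')
        inputs = processor(images=image, return_tensors="jax")
        embed = np.asarray(model.get_image_features(**inputs))[0]
        embed = embed / np.linalg.norm(embed)  # unit-normalize, as the search assumes
        records.append({'files': fname, 'image_embedding': embed})
    pd.DataFrame(records).to_pickle(out_path)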
st.title("MedCLIP 🩺📎")
st.markdown("Search for medical images in natural language.")
st.markdown("""This demo uses a CLIP model finetuned on the
[Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
query = st.text_input("Enter your query here:")
if st.button("Search"):
    st.write(f"Searching our image database for '{query}'...")
    # Embed the query text and unit-normalize it so that the dot product
    # below is a cosine similarity against the image embeddings (assumed
    # to be pre-normalized in the feature store).
    inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
    query_embedding = model.get_text_features(**inputs)
    query_embedding = np.asarray(query_embedding)
    query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
    dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
    # argsort is ascending, so take the last k indices and reverse them
    # so that the best match is shown first.
    matching_images = image_list[dot_prod.argsort()[-k:][::-1]]
    # Show each matching image with its file name.
    for img_path in matching_images:
        img = plt.imread(os.path.join(img_dir, img_path))
        st.write(img_path)
        st.image(img)
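# To try the app locally (a rough sketch of the expected setup; exact
# dependency versions are not pinned in this file):
#   pip install streamlit transformers flax jax pandas numpy matplotlib
#   streamlit run app.py
# It assumes 'feature_store/image_embeddings.pkl' and the './images' directory
# exist, and that the bundled 'medclip' module is importable.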