"""Streamlit app: natural-language search over a medical image collection.

The user's text query is embedded with the MedCLIP model, scored against
precomputed image embeddings, and the ``k`` best-scoring images are shown,
best match first.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor

from medclip.modeling_hybrid_clip import FlaxHybridCLIP

# Number of best-matching images to display per query.
k = 5
# Directory that the file names stored in the embeddings table are joined onto.
img_dir = './images'


def _resource_cache(func):
    """Cache ``func``'s return value across Streamlit reruns.

    ``st.cache`` is deprecated. Use its replacement for shared objects such as
    models, ``st.cache_resource``, when this Streamlit version provides it.
    Otherwise fall back to the original ``st.cache(allow_output_mutation=True)``.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_resource_cache
def load_model():
    """Load the MedCLIP model and the CLIP processor used to tokenize queries.

    Returns:
        ``(model, processor)``.
    """
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor


@_resource_cache
def load_image_embeddings():
    """Load the precomputed image embeddings from the local feature store.

    Returns:
        ``(image_files, image_embeds)``. ``image_files`` is an array of image
        file names. ``image_embeds`` holds the per-image embeddings stacked
        row-wise, in the same row order as ``image_files``.
    """
    # NOTE: unpickling can execute arbitrary code; only load a trusted,
    # locally produced file here.
    embeddings_df = pd.read_pickle('feature_store/image_embeddings.pkl')
    image_embeds = np.stack(embeddings_df['image_embedding'])
    image_files = np.asarray(embeddings_df['files'].tolist())
    return image_files, image_embeds


def _top_k_indices(scores, top_k):
    """Return the indices of the ``top_k`` highest ``scores``, best first."""
    # argsort is ascending; reverse it so the best match comes first.
    return np.argsort(scores)[::-1][:top_k]


st.title("MedCLIP 🩺📎")
st.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
st.markdown("""Example queries:
* `ultrasound scans`
* `PET scan`""")

image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()

query = st.text_input("Enter your query here:")

if st.button("Search"):
    if not query.strip():
        # Don't embed and rank against an empty query.
        st.warning("Please enter a query to search for.")
    else:
        st.write(f"Searching our image database for {query}...")
        inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
        query_embedding = np.asarray(model.get_text_features(**inputs))
        # L2-normalise the query so the dot product below is a cosine
        # similarity. This assumes the stored image embeddings are unit-norm
        # too — TODO confirm against the feature-store builder.
        query_embedding = query_embedding / np.linalg.norm(
            query_embedding, axis=-1, keepdims=True
        )
        # One similarity score per stored image (row-wise dot product).
        dot_prod = np.sum(query_embedding * image_embeddings, axis=1)
        matching_images = image_list[_top_k_indices(dot_prod, k)]

        # Show the images, best match first.
        for img_path in matching_images:
            img = plt.imread(os.path.join(img_dir, img_path))
            st.write(img_path)
            st.image(img)