# medclip-roco / app.py
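# Streamlit demo: given a free-text query, retrieve the most similar images from
# a set of precomputed medical image embeddings using a Flax HybridCLIP model.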
import streamlit as st
import pandas as pd
import numpy as np
from transformers import CLIPProcessor

from medclip.modeling_hybrid_clip import FlaxHybridCLIP

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the Flax HybridCLIP checkpoint and the CLIP processor (cached across reruns)."""
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor

@st.cache(allow_output_mutation=True)
def load_image_embeddings():
    """Load the precomputed image embeddings and their corresponding file names."""
    embeddings_df = pd.read_pickle('image_embeddings.pkl')
    image_embeds = np.stack(embeddings_df['image_embedding'])
    image_files = np.asarray(embeddings_df['files'].tolist())
    return image_files, image_embeds
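
# Note: image_embeddings.pkl is assumed to be a DataFrame with a 'files' column
# of image paths and an 'image_embedding' column of L2-normalized vectors, so
# the dot product below amounts to cosine similarity.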
k = 5  # number of top matches to display

image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()

query = st.text_input("Search:")
if st.button("Search"):
    st.write(f"Searching our image database for {query}...")
    # Embed the text query with the model's text encoder.
    inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
    query_embedding = model.get_text_features(**inputs)
    query_embedding = np.asarray(query_embedding)
    query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
    # Cosine similarity between the (normalized) query and each image embedding.
    dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
    # Take the k most similar images, best match first.
    matching_images = image_list[dot_prod.argsort()[-k:][::-1]]
    st.write(f"Matching images: {matching_images}")