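"""Streamlit demo: free-text search over a medical image collection with MedCLIP.

The query is embedded with the MedCLIP (FlaxHybridCLIP) text encoder and compared,
via dot product, against precomputed image embeddings; the top-k matching images
are displayed.
"""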
import streamlit as st
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from transformers import CLIPProcessor
from medclip.modeling_hybrid_clip import FlaxHybridCLIP

@st.cache(allow_output_mutation=True)
def load_model():
    # Load the MedCLIP hybrid text/vision model and the CLIP processor once;
    # Streamlit's cache keeps them in memory across reruns.
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor

@st.cache(allow_output_mutation=True)
def load_image_embeddings():
    # Read the precomputed image embeddings and their corresponding file names
    # from the pickled DataFrame.
    embeddings_df = pd.read_pickle('image_embeddings.pkl')
    image_embeds = np.stack(embeddings_df['image_embedding'])
    image_files = np.asarray(embeddings_df['files'].tolist())
    return image_files, image_embeds

k = 5  # number of results to show
image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()
img_dir = './images'  # directory containing the images referenced in the embeddings file

query = st.text_input("Search:")

if st.button("Search"):
    st.write(f"Searching our image database for {query}...")

    # Embed the text query and L2-normalise it.
    inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
    query_embedding = model.get_text_features(**inputs)
    query_embedding = np.asarray(query_embedding)
    query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)

    # Score every image by its dot product with the query and keep the k best,
    # highest-scoring first.
    dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
    matching_images = image_list[dot_prod.argsort()[-k:][::-1]]

    # Display the matching images with their file names.
    for img_path in matching_images:
        img = plt.imread(os.path.join(img_dir, img_path))
        st.write(img_path)
        st.image(img)
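
# To run locally (assumes this file is saved as app.py, with image_embeddings.pkl
# and an ./images directory alongside it): streamlit run app.py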