import os

import jax
import jax.numpy as jnp
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from model import FlaxHybridCLIP

# st.header('Under construction')

# Defined up front so the sidebar "show all" block below can use them.
col_count = 4
file_names = sorted(os.listdir("./imgs"))

st.sidebar.title("CLIP React Demo")
st.sidebar.write("Search Reaction GIFs with CLIP [Model Card](https://huggingface.co/flax-community/clip-reply)")
st.sidebar.image("./huggingface_explode3.png", width=150)
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)

show_val = st.sidebar.button("show all validation set images")
if show_val:
    cols = st.sidebar.beta_columns(col_count)
    for i, im in enumerate(file_names):
        j = i % col_count
        cols[j].image("./imgs/" + im)

st.write(" ")
st.write(" ")


@st.cache(allow_output_mutation=True)
def load_model():
    # Hybrid CLIP fine-tuned on reaction GIFs; the text side uses a Twitter RoBERTa tokenizer.
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    return model, processor


@st.cache(allow_output_mutation=True)
def load_image_index():
    # Pre-built HNSW index over the validation-set image embeddings.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index


image_index = load_image_index()
model, processor = load_model()


# TODO: unused/unfinished — add a newly uploaded image to the index
def add_image_emb(image):
    image = Image.open(image).convert("RGB")
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    image_index.addDataPoint(features)


def query_with_images(query_images, query_text):
    # Rank the uploaded images against the query text using the CLIP similarity logits.
    images = [Image.open(im).convert("RGB") for im in query_images]
    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image.reshape(-1)
    # st.write(logits_per_image)  # debug
    probs = jax.nn.softmax(logits_per_image)
    # st.write(probs)  # debug
    results = sorted(zip(images, probs), key=lambda x: x[1], reverse=True)
    # st.write(results)  # debug
    return zip(*results)


q_cols = st.beta_columns([5, 2, 5])

examples = [
    "I'm so scared right now",
    " I got the job 🎉",
    "OMG that is disgusting",
    "I'm awesome",
    "I love you ❤️",
]
example_input = q_cols[0].radio("Example Queries:", examples, index=4)

q_cols[2].markdown(
    """
Searches among the validation set images if not specified
(There may be non-exact duplicates)
"""
)

query_text = q_cols[0].text_input("Write text you want to get reaction for", value=example_input)
query_images = q_cols[2].file_uploader(
    "(optional) Upload images to rank them", type=["jpg", "jpeg"], accept_multiple_files=True
)

if query_images:
    st.write("Ranking your uploaded images with respect to input text:")
    ids, dists = query_with_images(query_images, query_text)
else:
    st.write("Found these images within validation set:")
    # Embed the query text and look up the nearest image embeddings in the index.
    proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
    vec = np.asarray(model.get_text_features(**proc))
    ids, dists = image_index.knnQuery(vec, k=top_k)

res_cols = st.beta_columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    j = i % col_count
    with res_cols[j]:
        if isinstance(id_, np.int32):
            # Nearest-neighbour hit: id_ indexes into file_names, dist is a cosine
            # distance, so 1 - dist is shown as the similarity score.
            st.image("./imgs/" + file_names[id_])
            # st.write(file_names[id_])
            st.write(1.0 - dist, help="score")
        else:
            # Uploaded image: id_ is the PIL image itself, dist is its softmax probability.
            st.image(id_)
            st.write(dist, help="score")