# clip-reply-demo / app.py
# Ceyda Cinarel
# add config
# 4474721
# raw history blame
# No virus
# 3.25 kB
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor
from model import FlaxHybridCLIP
from PIL import Image
import jax.numpy as jnp
import os
import jax
# Static page chrome: a header, the demo title, and a short sidebar note about
# the validation set.
_SIDEBAR_NOTES = """
Validation set: 351 images/273 deduped (There are still duplicates)
Example Queries :
"""

st.header('Under construction')
st.title("CLIP Reply Demo")
st.sidebar.markdown(_SIDEBAR_NOTES)
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the hybrid CLIP model and its matching processor.

    The model comes from the "ceyda/clip-reply" checkpoint. The processor is
    the stock "openai/clip-vit-base-patch32" CLIP processor with its tokenizer
    swapped for the "cardiffnlp/twitter-roberta-base" tokenizer
    (presumably because the hybrid model's text side is RoBERTa-based —
    confirm against the model definition).

    The result is memoised through ``st.cache`` with
    ``allow_output_mutation=True``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    hybrid_clip = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")

    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Replace the default CLIP tokenizer with the Twitter RoBERTa one.
    clip_processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")

    return hybrid_clip, clip_processor
@st.cache(allow_output_mutation=True)
def load_image_index():
    """Load the prebuilt image-embedding index from disk.

    Creates an nmslib HNSW index over the cosine-similarity space and fills it
    from "./features/image_embeddings" (data included via ``load_data=True``).
    The result is memoised through ``st.cache`` with
    ``allow_output_mutation=True``.

    Returns:
        The loaded nmslib index.
    """
    hnsw_index = nmslib.init(space='cosinesimil', method='hnsw')
    hnsw_index.loadIndex("./features/image_embeddings", load_data=True)
    return hnsw_index
# Validation image file names, sorted. Index ids returned by the search are
# later used as positions into this list, so the ordering presumably has to
# match the one used when the index was built — confirm.
file_names = sorted(os.listdir("./imgs"))

# Shared (cached) resources used by the rest of the script.
image_index = load_image_index()
model, processor = load_model()

col_count = 4  # number of columns in the image grids
top_k = 10  # number of neighbours requested from the index

# Optional sidebar gallery: every validation image, laid out round-robin
# across `col_count` columns.
show_val = st.sidebar.button("show all validation set images")
if show_val:
    cols = st.sidebar.beta_columns(col_count)
    for i, im in enumerate(file_names):
        cols[i % col_count].image("./imgs/" + im)
# TODO
def add_image_emb(image):
    """Embed one image and hand the embedding to the global image index.

    Unfinished helper (see the TODO above); nothing in the visible script
    calls it.

    Args:
        image: Anything ``PIL.Image.open`` accepts (path or file-like object).

    NOTE(review): ``addDataPoint`` is called with the embedding only. The
    nmslib Python binding appears to expect ``(id, data)`` — confirm before
    wiring this function in.
    """
    rgb_image = Image.open(image).convert("RGB")
    # The processor is given an empty text alongside the image.
    batch = processor(text=[""], images=rgb_image, return_tensors="jax", padding=True)
    # Move axis 1 to the end (presumably channels-first -> channels-last for
    # the Flax model — confirm).
    batch["pixel_values"] = jnp.transpose(batch["pixel_values"], axes=[0, 2, 3, 1])
    embedding = model(**batch).image_embeds
    image_index.addDataPoint(embedding)
def query_with_images(query_images, query_text):
    """Rank uploaded images against a single text query.

    Each upload is opened as an RGB image, the model scores all of them
    against ``query_text``, and the flattened per-image logits are turned into
    probabilities with a softmax. Intermediate values (logits, probabilities,
    unsorted and sorted pairs) are written to the page via ``st.write``.

    Args:
        query_images: Iterable of items ``PIL.Image.open`` accepts.
        query_text: The text to compare the images with.

    Returns:
        ``zip(*ranked)`` over ``(image, probability)`` pairs sorted by
        descending probability; the caller unpacks it into an image tuple and
        a probability tuple.
    """
    pil_images = [Image.open(upload).convert("RGB") for upload in query_images]

    batch = processor(text=[query_text], images=pil_images, return_tensors="jax", padding=True)
    # Move axis 1 to the end (presumably channels-first -> channels-last for
    # the Flax model — confirm).
    batch["pixel_values"] = jnp.transpose(batch["pixel_values"], axes=[0, 2, 3, 1])

    logits_per_image = model(**batch).logits_per_image.reshape(-1)
    st.write(logits_per_image)

    probs = jax.nn.softmax(logits_per_image)
    st.write(probs)

    scored = list(zip(pil_images, probs))
    st.write(scored)

    # Highest probability first.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    st.write(ranked)

    return zip(*ranked)
# Query inputs: text on the left, optional image uploads on the right.
q_cols = st.beta_columns(2)
query_text = q_cols[0].text_input("Input text", value="I love you")
query_images = q_cols[1].file_uploader(
    "(optional) upload query image",
    type=['jpg', 'jpeg'],
    accept_multiple_files=True,
)

if not query_images:
    # No uploads: embed the text and look up its nearest validation images.
    st.write("Finding within validation set")
    proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
    vec = np.asarray(model.get_text_features(**proc))
    ids, dists = image_index.knnQuery(vec, k=top_k)
else:
    # Uploads present: rank just those images against the text.
    st.write("Ranking uploaded images with respect to input text")
    ids, dists = query_with_images(query_images, query_text)

# Results grid, filled round-robin across `col_count` columns.
res_cols = st.beta_columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        if not isinstance(id_, np.int32):
            # Uploaded-image path: `id_` is the image itself and `dist` is its
            # softmax probability.
            st.image(id_)
            st.write(dist)
            continue
        # Index path: `id_` is a position into `file_names`.
        # NOTE(review): relies on knnQuery returning np.int32 ids — confirm.
        st.image("./imgs/" + file_names[id_])
        # The index uses the 'cosinesimil' space, so 1 - dist is presumably
        # the cosine similarity — confirm.
        st.write(1.0 - dist)