Spaces:
Runtime error
Runtime error
import nmslib | |
import numpy as np | |
import streamlit as st | |
from transformers import AutoTokenizer, CLIPProcessor | |
from model import FlaxHybridCLIP | |
from PIL import Image | |
import jax.numpy as jnp | |
import os | |
import jax | |
# Page chrome: a work-in-progress banner, the app title, and sidebar notes
# describing the validation set that text queries are searched against.
st.header('Under construction')
st.title("CLIP Reply Demo")
_sidebar_notes = """
Validation set: 351 images/273 deduped (There are still duplicates)
Example Queries :
"""
st.sidebar.markdown(_sidebar_notes)
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned CLIP model and its matching processor.

    Streamlit re-executes this whole script on every widget interaction, so
    the result is cached; otherwise the model and processors would be
    reloaded on each rerun. ``allow_output_mutation=True`` tells Streamlit
    not to hash the returned objects.

    NOTE(review): ``st.cache`` is the legacy API that matches the
    ``beta_columns`` era used elsewhere in this app; on newer Streamlit
    versions the equivalent is ``st.cache_resource``.

    Returns:
        ``(model, processor)``: the ``ceyda/clip-reply`` FlaxHybridCLIP model
        and a CLIPProcessor whose tokenizer is swapped for the
        ``cardiffnlp/twitter-roberta-base`` one.
    """
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Image preprocessing comes from the OpenAI CLIP checkpoint, while text is
    # tokenized with Twitter-RoBERTa (presumably the text encoder that
    # ceyda/clip-reply was trained with — confirm).
    processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    return model, processor
@st.cache(allow_output_mutation=True)
def load_image_index():
    """Load the prebuilt nmslib HNSW index of validation-image embeddings.

    Cached so the index is not re-read from disk on every Streamlit rerun.
    ``allow_output_mutation=True`` skips hashing the returned index object.

    Returns:
        An nmslib index using cosine-similarity space, loaded from
        ``./features/image_embeddings`` together with its stored vectors.
    """
    index = nmslib.init(method='hnsw', space='cosinesimil')
    # load_data=True also restores the raw vectors, not just the graph.
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index
# Image file names in sorted order. Search results are looked up with
# file_names[id_], so this order is assumed to match the order in which the
# index was built (NOTE(review): confirm against the indexing script).
file_names = sorted(os.listdir("./imgs"))

# Heavy resources, loaded up front.
image_index = load_image_index()
model, processor = load_model()

# Layout / retrieval settings.
col_count = 4  # number of grid columns used for image galleries
top_k = 10  # neighbours requested from the index per text query

# Optional sidebar gallery of every validation image, dealt round-robin
# across col_count columns.
show_val = st.sidebar.button("show all validation set images")
if show_val:
    sidebar_cols = st.sidebar.beta_columns(col_count)
    for idx, name in enumerate(file_names):
        sidebar_cols[idx % col_count].image(f"./imgs/{name}")
# TODO: not yet called from the UI.
def add_image_emb(image, data_id=None):
    """Embed one image with the CLIP model and add it to ``image_index``.

    Args:
        image: A path or file-like object accepted by ``PIL.Image.open``.
        data_id: Integer id to store the vector under. Defaults to
            ``len(image_index)``, i.e. the next position in the index.

    Returns:
        The id the vector was added under.
    """
    image = Image.open(image).convert("RGB")
    # A dummy empty caption is passed alongside the image (presumably because
    # the model's forward pass expects text inputs); only image_embeds is used.
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    # Move the channel axis last (presumably NCHW -> NHWC for the Flax vision
    # tower), the same transform query_with_images applies.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    if data_id is None:
        data_id = len(image_index)
    # nmslib's addDataPoint takes (id, data); the original call omitted the id.
    # Pass a flat float32 numpy vector rather than a batched JAX array.
    image_index.addDataPoint(data_id, np.asarray(features, dtype=np.float32).reshape(-1))
    # NOTE(review): nmslib generally only searches points present when the
    # index was built, so createIndex() likely has to run again before
    # knnQuery can return this point. Also, the result grid maps ids through
    # file_names, which has no entry for a newly added id — confirm before use.
    return data_id
def query_with_images(query_images, query_text, *, debug=False):
    """Rank uploaded images by how well they match ``query_text``.

    Args:
        query_images: Uploaded image files (anything ``PIL.Image.open``
            accepts, e.g. Streamlit ``UploadedFile`` objects).
        query_text: The text query every image is scored against.
        debug: If true, also write the intermediate logits, probabilities and
            ranked pairs to the page as diagnostic output.

    Returns:
        A pair ``(images, probs)`` of equal-length tuples sorted by
        descending probability. ``images`` are RGB PIL images; ``probs`` are
        Python floats from a softmax over the uploaded set. Both tuples are
        empty when ``query_images`` is empty.
    """
    if not query_images:
        # Nothing to rank. Return an unpackable empty pair instead of an empty
        # zip, which would break ``ids, dists = ...`` at the call site.
        return (), ()
    images = [Image.open(im).convert("RGB") for im in query_images]
    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    # Move the channel axis last (presumably NCHW -> NHWC for the Flax vision tower).
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # A single text query, so flattening yields one logit per image.
    logits_per_image = outputs.logits_per_image.reshape(-1)
    # Softmax across the uploaded images, converted to plain floats so the
    # caller displays readable numbers instead of JAX array reprs.
    probs = np.asarray(jax.nn.softmax(logits_per_image)).tolist()
    if debug:
        st.write(logits_per_image)
        st.write(probs)
    results = sorted(zip(images, probs), key=lambda pair: pair[1], reverse=True)
    if debug:
        st.write(results)
    return tuple(zip(*results))
# Query UI: free-text query on the left, optional image upload on the right.
q_cols = st.beta_columns(2)
query_text = q_cols[0].text_input("Input text", value="I love you")
query_images = q_cols[1].file_uploader("(optional) upload query image", type=['jpg', 'jpeg'], accept_multiple_files=True)

if query_images:
    # Rank only the uploaded images: ids are PIL images, dists are softmax
    # probabilities over the uploaded set.
    st.write("Ranking uploaded images with respect to input text")
    ids, dists = query_with_images(query_images, query_text)
else:
    # Nearest-neighbour search of the text embedding against the image index.
    st.write("Finding within validation set")
    proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
    # knnQuery takes a single 1-D query vector; flatten the batch of one
    # (presumably shape (1, dim)).
    vec = np.asarray(model.get_text_features(**proc)).reshape(-1)
    ids, dists = image_index.knnQuery(vec, k=top_k)

res_cols = st.beta_columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        # Index hits are integer positions into file_names (numpy integers of
        # whatever width the index returns); anything else is an uploaded PIL
        # image to display directly.
        if isinstance(id_, (int, np.integer)):
            st.image("./imgs/" + file_names[id_])
        else:
            st.image(id_)
        st.write(dist)