Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
from plip_support import embed_text | |
import numpy as np | |
from PIL import Image | |
import requests | |
import tokenizers | |
from io import BytesIO | |
import torch | |
from transformers import ( | |
VisionTextDualEncoderModel, | |
AutoFeatureExtractor, | |
AutoTokenizer, | |
CLIPModel, | |
AutoProcessor | |
) | |
import streamlit.components.v1 as components | |
def embed_texts(model, texts, processor):
    """Encode a batch of strings into text-feature vectors with a CLIP-style model.

    Args:
        model: Object exposing ``get_text_features(input_ids=..., attention_mask=...)``
            (in this app, the PLIP ``CLIPModel``).
        texts: List of strings to encode.
        processor: Callable that, given ``text=...``, returns a mapping with
            ``"input_ids"`` and ``"attention_mask"`` lists (in this app, the
            PLIP ``AutoProcessor``).

    Returns:
        Whatever ``model.get_text_features`` returns for the batch, with one
        row per input text; computed without gradient tracking.
    """
    # truncation=True caps each sequence at the tokenizer's maximum length.
    # CLIP's text encoder has a fixed-size position-embedding table (presumably
    # 77 tokens for this checkpoint), so an over-long query would otherwise
    # raise inside the model instead of being encoded.
    inputs = processor(text=texts, padding="longest", truncation=True)
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings
def load_embeddings(embeddings_path):
    """Read a saved NumPy embedding array from disk.

    Args:
        embeddings_path: Path handed straight to ``np.load`` (in this app,
            ``"tweet_eval_embeddings.npy"``).

    Returns:
        The object ``np.load`` produces for that path; for a ``.npy`` file,
        the stored array.
    """
    status_message = "loading embeddings"
    print(status_message)
    loaded = np.load(embeddings_path)
    return loaded
def load_path_clip():
    """Load the PLIP model together with its matching processor.

    Both objects come from ``from_pretrained`` on the ``"vinid/plip"``
    checkpoint identifier.

    Returns:
        A ``(model, processor)`` tuple: the ``CLIPModel`` and the
        ``AutoProcessor`` for that checkpoint.
    """
    checkpoint = "vinid/plip"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = AutoProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor
def app():
    """Render the PLIP image-search page.

    Loads the two tweet metadata tables, the PLIP model/processor and the
    precomputed image embeddings. Once the user enters a query, it embeds the
    query, scores it against every image embedding by cosine similarity, and
    embeds the best-matching tweet using Twitter's widget script.

    NOTE(review): Streamlit re-executes the script on each widget interaction,
    so (assuming this function is called from that script) the model and all
    three data files are reloaded every time. Consider Streamlit's caching
    decorators for the loaders.
    """
    st.title('PLIP Image Search')
    # Presumably row i of each table and row i of the .npy file describe the
    # same image. TODO: confirm all three files were produced in the same order.
    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")
    model, processor = load_path_clip()
    image_embedding = load_embeddings("tweet_eval_embeddings.npy")
    # The query vector is L2-normalised below. The image vectors must be
    # normalised too for the dot product to be a cosine similarity.
    # Re-normalising is effectively a no-op if the stored vectors are already
    # unit length. The floor avoids dividing an all-zero row by zero.
    image_norms = np.linalg.norm(image_embedding, axis=-1, keepdims=True)
    image_embedding = image_embedding / np.maximum(image_norms, 1e-12)

    query = st.text_input('Search Query', '')
    # Treat whitespace-only input like an empty box rather than embedding it.
    if not query.strip():
        return

    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
    text_embedding = text_embedding / np.linalg.norm(text_embedding)
    # Cosine similarity of the query against every image embedding.
    similarity_scores = text_embedding.dot(image_embedding.T)
    # Only the top hit is displayed, so an O(n) argmax replaces the full sort.
    best_id = int(np.argmax(similarity_scores))
    score = similarity_scores[best_id]
    # target_url is only consumed by the disabled image-rendering path below.
    target_url = plip_imgURL.iloc[best_id]["imageURL"]
    target_weblink = plip_weblink.iloc[best_id]["weblink"]
    st.caption('Most relevant image (similarity = %.4f)' % score)
    #response = requests.get(target_url)
    #img = Image.open(BytesIO(response.content))
    #st.image(img)
    components.html('''
<blockquote class="twitter-tweet">
<a href="%s"></a>
</blockquote>
<script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
</script>
''' % target_weblink,
                    height=600)