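"""PLIP Image Search.

A Streamlit demo that embeds a text query with the PLIP CLIP model
("vinid/plip") and retrieves the closest matching image from a set of
precomputed tweet image embeddings.
"""
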
import numpy as np
import pandas as pd
import requests
import streamlit as st
import tokenizers
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, CLIPModel

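# Embed a batch of text queries with the CLIP text encoder.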
def embed_texts(model, texts, processor):
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings

@st.cache
def load_embeddings(embeddings_path):
    print("loading embeddings")
    return np.load(embeddings_path)
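
# st.cache is Streamlit's legacy caching API. torch Parameters and
# tokenizers objects are not hashable by the cacher, so they are excluded
# from the cache key via no-op hash functions.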
@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    }
)
def load_path_clip():
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor

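# --- Streamlit UI ---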
st.title('PLIP Image Search')
plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
model, processor = load_path_clip()
image_embedding = load_embeddings("tweet_eval_embeddings.npy")
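
# Rows of the TSV and rows of the embedding matrix are assumed to be
# aligned (best_id indexes both). Ranking uses the dot product of the
# L2-normalized query embedding with the image embeddings, which is
# cosine similarity if the image embeddings are unit-normalized too.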
query = st.text_input('Search Query', '')

if query:
    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
    text_embedding = text_embedding / np.linalg.norm(text_embedding)

    best_id = np.argmax(text_embedding.dot(image_embedding.T))
    url = plip_dataset.iloc[best_id]["imageURL"]

    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    st.image(img)
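
# To run locally (assuming tweet_eval_retrieval.tsv and
# tweet_eval_embeddings.npy are present alongside this file):
#   streamlit run app.py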