# webplip/text2image.py
from io import BytesIO

import numpy as np
import pandas as pd
import requests
import streamlit as st
import tokenizers
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPModel

def embed_texts(model, texts, processor):
    """Embed a list of strings with the CLIP text tower; returns a (len(texts), dim) tensor."""
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings

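# Example usage (a sketch; the query string is illustrative, not from the app):
#   model, processor = load_path_clip()
#   vec = embed_texts(model, ["an H&E slide of tumor tissue"], processor)[0]
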
@st.cache
def load_embeddings(embeddings_path):
    """Load precomputed image embeddings from disk (cached across reruns)."""
    print("loading embeddings")
    return np.load(embeddings_path)

# torch parameters and tokenizers objects are not hashable, so tell st.cache
# to skip them when computing the cache key. (Newer Streamlit versions would
# use st.cache_resource here instead.)
@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    }
)
def load_path_clip():
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor

def app():
    st.title('PLIP Image Search')

    plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    model, processor = load_path_clip()
    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.text_input('Search Query', '')

    if query:
        # Embed the query and L2-normalize it; the precomputed image
        # embeddings are assumed to be normalized already, so the dot
        # product below ranks images by cosine similarity.
        text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
        text_embedding = text_embedding / np.linalg.norm(text_embedding)

        # Fetch and display the best-matching image.
        best_id = np.argmax(text_embedding.dot(image_embedding.T))
        url = plip_dataset.iloc[best_id]["imageURL"]

        response = requests.get(url)
        img = Image.open(BytesIO(response.content))
        st.image(img)
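
# Allow running this file directly with `streamlit run text2image.py`.
# (Assumption: in the original Space this module may be launched from a
# separate entry script instead; this guard is a harmless convenience.)
if __name__ == "__main__":
    app()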