import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import tokenizers
from io import BytesIO
import torch
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    CLIPModel,
    AutoProcessor
)
import streamlit.components.v1 as components


def embed_images(model, images, processor):
    # Preprocess the PIL images and compute CLIP image embeddings.
    inputs = processor(images=images)
    pixel_values = torch.tensor(np.array(inputs["pixel_values"]))

    with torch.no_grad():
        embeddings = model.get_image_features(pixel_values=pixel_values)
    return embeddings


@st.cache
def load_embeddings(embeddings_path):
    print("loading embeddings")
    return np.load(embeddings_path)


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None
    }
)
def load_path_clip():
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor


def app():
    st.title('PLIP Image Search')

    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")

    model, processor = load_path_clip()

    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.file_uploader("Choose a file")

    if query:
        # Embed the uploaded image and L2-normalize it.
        image = Image.open(query)
        single_image = embed_images(model, [image], processor)[0].detach().cpu().numpy()
        single_image = single_image / np.linalg.norm(single_image)

        # Sort IDs by cosine similarity from high to low
        # (the precomputed embeddings are assumed to be L2-normalized as well,
        # so the dot product equals cosine similarity).
        similarity_scores = single_image.dot(image_embedding.T)
        id_sorted = np.argsort(similarity_scores)[::-1]
        best_id = id_sorted[0]
        score = similarity_scores[best_id]

        target_weblink = plip_weblink.iloc[best_id]["weblink"]

        st.caption('Most relevant image (similarity = %.4f)' % score)
        # The tweet-embed HTML was omitted in the source; the template passed to
        # components.html should contain a %s placeholder for target_weblink.
        components.html('''
        ''' % target_weblink, height=600)
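

# A minimal entry-point sketch (an assumption, not in the source): the original
# module may instead be imported and invoked from a multipage Streamlit launcher.
if __name__ == "__main__":
    app()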