import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import tokenizers
import os
from io import BytesIO
import pickle
import base64
import torch
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    CLIPModel,
    AutoProcessor
)
import streamlit.components.v1 as components
from st_clickable_images import clickable_images  # pip install st-clickable-images


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None
    }
)
def load_path_clip():
    # Load the PLIP model and its processor from the Hugging Face Hub.
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor


@st.cache
def init():
    # Load precomputed metadata and image/text embeddings from disk.
    with open('data/twitter.asset', 'rb') as f:
        data = pickle.load(f)
    meta = data['meta'].reset_index(drop=True)
    image_embedding = data['image_embedding']
    text_embedding = data['text_embedding']
    print(meta.shape, image_embedding.shape)
    # Boolean mask selecting the validation subset of tweets.
    validation_subset_index = meta['source'].values == 'Val_Tweets'
    return meta, image_embedding, text_embedding, validation_subset_index


def embed_images(model, images, processor):
    # Preprocess the images and compute CLIP image embeddings.
    inputs = processor(images=images)
    pixel_values = torch.tensor(np.array(inputs["pixel_values"]))

    with torch.no_grad():
        embeddings = model.get_image_features(pixel_values=pixel_values)
    return embeddings


def embed_texts(model, texts, processor):
    # Tokenize the texts (padded to the longest sequence) and compute CLIP text embeddings.
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
    return embeddings
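

# --- Usage sketch (not part of the original app flow above) ---
# A minimal, illustrative example of how these helpers could be combined for
# text-to-image retrieval: embed a free-text query with embed_texts() and rank
# the precomputed image embeddings from init() by cosine similarity. The
# function name, the `top_k` parameter, and the assumption that
# `image_embedding` is a NumPy array are my additions, not confirmed by the
# section above.
def retrieve_similar_images(query, top_k=5):
    model, processor = load_path_clip()
    meta, image_embedding, text_embedding, validation_subset_index = init()

    # Embed the query text and L2-normalize it.
    query_embedding = embed_texts(model, [query], processor)[0].numpy()
    query_embedding = query_embedding / np.linalg.norm(query_embedding)

    # Normalize the stored image embeddings (assumed to be a 2-D NumPy array)
    # and rank them by cosine similarity with the query.
    image_norm = image_embedding / np.linalg.norm(image_embedding, axis=1, keepdims=True)
    similarity = image_norm @ query_embedding
    best = np.argsort(-similarity)[:top_k]
    return meta.iloc[best], similarity[best]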