import io
from html import escape

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, BertModel, CLIPVisionModel

IMAGE_SIZE = 224
# CLIP normalization constants (per-channel mean/std).
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

# Fall back to CPU when CUDA is unavailable; note the models below are not
# currently moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained(
    'TheLitttleThings/clip-archdaily-vision').eval()

# Precomputed embeddings and metadata for the image corpus.
image_embeddings = torch.load('image_embeddings.pt')
text_embeddings = torch.load('text_embeddings.pt')
links = np.load('links_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)

if 'tab' not in st.session_state:
    st.session_state['tab'] = 0


@st.experimental_memo
def image_search(query, top_k=24):
    """Return the top_k image links whose embeddings best match a text query."""
    with torch.no_grad():
        text_embedding = text_encoder(
            **tokenizer(query, return_tensors='pt')).pooler_output
    _, indices = torch.cosine_similarity(
        image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]


def text_query_embedding(query: str = 'architecture'):
    """Embed a text query with the BERT text encoder."""
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding


preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])


def image_query_embedding(image):
    """Embed a PIL image with the CLIP vision encoder."""
    image = preprocessImage(image).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(image).pooler_output
    return image_embedding


def most_similars(embeddings_1, embeddings_2):
    """Rank rows of embeddings_1 by cosine similarity to embeddings_2."""
    values, indices = torch.cosine_similarity(
        embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()


def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
    """Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the original image.
        top_k (int): Number of results to return.
        additional_text (str): Text whose embedding can be added to the image embedding.
        input_include (bool): Whether to keep the input image among the results.
    """
    base_image = Image.open(input_image_path)
    image_embedding = image_query_embedding(base_image)
    additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]


def image_comparison(base_image, top_k=24):
    """Return the top_k image links most similar to an uploaded image."""
    image_embedding = image_query_embedding(base_image)
    # additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]


def get_html(url_list, classOther=""):
    html = f"