import io
from html import escape

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, BertModel, CLIPVisionModel

# CLIP preprocessing constants
IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

# Fall back to CPU when CUDA is unavailable
device = 'cuda' if torch.cuda.is_available() else 'cpu'

modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval().to(device)
vision_encoder = CLIPVisionModel.from_pretrained('TheLitttleThings/clip-archdaily-vision').eval().to(device)

# Precomputed image embeddings and the metadata/links they index into
image_embeddings = torch.load('image_embeddings.pt', map_location=device)
links = np.load('data_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)

if 'tab' not in st.session_state:
    st.session_state['tab'] = 0


@st.experimental_memo
def image_search(query, top_k=24):
    """Return the links of the top_k images most similar to a text query."""
    text_embedding = text_query_embedding(query)
    _, indices = most_similars(image_embeddings, text_embedding)
    return [links[i] for i in indices[:top_k]]


def text_query_embedding(query: str = 'architecture'):
    """Embed a text query with the BERT text encoder."""
    tokens = tokenizer(query, return_tensors='pt').to(device)
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding


preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])


def image_query_embedding(image):
    """Embed a PIL image with the CLIP vision encoder."""
    pixel_values = preprocessImage(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embedding = vision_encoder(pixel_values).pooler_output
    return image_embedding


def most_similars(embeddings_1, embeddings_2):
    """Sort cosine similarities between two embedding sets, most similar first."""
    values, indices = torch.cosine_similarity(embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()


def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
    """Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the query image.
        top_k (int): Number of results to return.
        additional_text (str): Text whose embedding can be added to the
            image embedding (the addition is currently disabled below).
        input_include (bool): Currently unused.
    """
    base_image = Image.open(input_image_path).convert('RGB')
    image_embedding = image_query_embedding(base_image)
    additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding (arithmetic disabled)
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]


def image_comparison(base_image, top_k=24):
    """Return the links of the top_k images most similar to a query image."""
    image_embedding = image_query_embedding(base_image)
    _, indices = most_similars(image_embeddings, image_embedding)
    return [links[i] for i in indices[:top_k]]


def get_html(url_list):
    html = "