"""Streamlit app: free-text semantic image search.

A text query is embedded with the ``SajjadAyoubi/clip-fa-text`` encoder,
compared by cosine similarity against the image embeddings stored in
``embeddings.pt``, and the links (from ``link.npy``) of the best matches
are rendered as an image gallery.
"""
from html import escape
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import RobertaModel, AutoTokenizer


# NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour of
# ``st.cache_resource`` / ``st.cache_data`` -- confirm against the pinned
# Streamlit version before switching.
@st.cache(show_spinner=False)
def load():
    """Load the text encoder, tokenizer, image links and image embeddings.

    Returns:
        A ``(text_encoder, tokenizer, links, image_embeddings)`` tuple.
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    # SECURITY: ``allow_pickle=True`` and ``torch.load`` both unpickle their
    # input, which can execute arbitrary code. Only ship trusted files.
    links = np.load('link.npy', allow_pickle=True)
    # The encoder is never moved off the CPU, so the embeddings must live on
    # the CPU as well, even if they were saved from a GPU tensor.
    image_embeddings = torch.load('embeddings.pt', map_location='cpu')
    return text_encoder, tokenizer, links, image_embeddings


text_encoder, tokenizer, links, image_embeddings = load()


def get_html(url_list, height=224):
    """Build the HTML for a wrapping gallery of images.

    The markup is a reconstruction: a flex container holding one ``<img>``
    per URL. The styling values are a reasonable default, not a recovered
    original.

    Args:
        url_list: Iterable of image URLs to display.
        height: Rendered height of each image, in pixels.

    Returns:
        An HTML string; every URL is escaped before being placed in ``src``.
    """
    # ``escape`` also encodes both quote characters, so a URL cannot break
    # out of the single-quoted ``src`` attribute.
    images = "".join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(str(url))}'>"
        for url in url_list
    )
    return (
        "<div style='margin-top: 20px; display: flex; flex-wrap: wrap; "
        "justify-content: space-evenly'>"
        f"{images}"
        "</div>"
    )


@st.cache(show_spinner=False)
def image_search(query, top_k=8):
    """Return the links of the ``top_k`` images most similar to ``query``.

    Args:
        query: Free-text search string.
        top_k: Maximum number of links to return.

    Returns:
        A list of entries from ``links``, best match first.
    """
    with torch.no_grad():
        text_embedding = text_encoder(
            **tokenizer(query, return_tensors='pt')
        ).pooler_output
    similarities = torch.cosine_similarity(text_embedding, image_embeddings)
    # Plain Python ints index ``links`` reliably; slicing (rather than topk)
    # tolerates ``top_k`` larger than the number of stored images.
    best = similarities.argsort(descending=True)[:top_k].tolist()
    return [links[i] for i in best]


description = ''' # Semantic image search :) '''


def main():
    """Render the page: sidebar description, query box and result gallery."""
    # NOTE(review): this markup string is effectively empty; if page-level
    # CSS was intended it has to be restored here.
    st.markdown(''' ''', unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Centre the query box by placing it in the wide middle column.
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('', value='clouds at sunset')
    if query:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)


if __name__ == '__main__':
    main()