"""Streamlit demo: Persian (fa) text-to-image search with CLIP-fa.

The query is encoded with the CLIP-fa text encoder, scored against
precomputed image embeddings with cosine similarity, and the best-matching
image links are rendered as an HTML gallery.
"""
import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer


@st.experimental_singleton
def load_resources():
    """Load the tokenizer, text encoder, image embeddings and image links once.

    Streamlit re-executes this script on every interaction; caching the loader
    keeps the model and the embedding/link files from being re-read from disk
    on each rerun.

    Returns:
        Tuple ``(tokenizer, text_encoder, image_embeddings, links)``.
    """
    tok = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
    # map_location lets embeddings saved from a GPU session load on CPU-only hosts.
    embeddings = torch.load('embeddings.pt', map_location='cpu')
    # allow_pickle runs pickle deserialization: data.npy must be a trusted, bundled file.
    urls = np.load('data.npy', allow_pickle=True)
    return tok, encoder, embeddings, urls


tokenizer, text_encoder, image_embeddings, links = load_resources()


@st.experimental_memo
def image_search(query, top_k=10):
    """Return the links of the ``top_k`` images most similar to ``query``.

    Args:
        query: Search text (expected to be Persian).
        top_k: Maximum number of results; clamped to ``[0, number of images]``.

    Returns:
        List of entries from ``links``, best match first.
    """
    with torch.no_grad():
        # truncation guards against queries longer than the model's max length.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # Presumably image_embeddings is (num_images, dim) and text_embedding is
        # (1, dim), so this broadcasts to one score per image -- TODO confirm.
        scores = torch.cosine_similarity(image_embeddings, text_embedding)
    k = max(0, min(top_k, scores.shape[-1]))
    if k == 0:
        return []
    # topk avoids fully sorting every score when only k of them are needed;
    # tolist() yields plain ints for indexing the numpy array of links.
    indices = torch.topk(scores, k).indices.tolist()
    return [links[i] for i in indices]


def get_html(url_list):
    """Build an HTML gallery (a wrapping flex row of ``<img>`` tags) for ``url_list``.

    The result is rendered with ``unsafe_allow_html=True``, so every URL is
    HTML-escaped (quotes included) before being placed in the ``src`` attribute.
    """
    images = ''.join(
        f"<img style='height: 180px; margin: 2px' src='{escape(str(url))}'>"
        for url in url_list
    )
    return (
        "<div style='margin-top: 20px; display: flex; flex-wrap: wrap; "
        "justify-content: space-evenly'>" + images + "</div>"
    )


description = '''
# Persian (fa) image search
- Enter your query and hit enter
- Note: We used a small set of images to keep this app almost real-time, but it's obvious that the quality of image search depends heavily on the size of the image database.

Built with [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/)
'''


def main():
    """Render the sidebar, the centred search box and the result gallery."""
    # NOTE(review): the markup for this call (presumably a <style> block) is
    # missing from this copy of the file; as written it renders nothing.
    # Restore the intended markup or remove the call.
    st.markdown(''' ''', unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # The search box sits in the middle column of a 1:3:1 layout.
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('Search Box (type in fa)', value='قطره های باران روی شیشه')
    c.text("It'll take about 30s to load all new images")
    # Skip empty / whitespace-only queries instead of running a meaningless search.
    if query.strip():
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)


if __name__ == '__main__':
    main()