import streamlit as st
import numpy as np
from html import escape
import torch
from transformers import RobertaModel, AutoTokenizer
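# Load the CLIP-fa text tower: a Persian RoBERTa encoder trained to map text
# into the same embedding space as the matching image encoder.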
tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').eval()
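# Precomputed image embeddings and their corresponding Unsplash URLs
# (presumably produced offline with the matching CLIP-fa image encoder);
# row i of image_embeddings must correspond to links[i].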
image_embeddings = torch.load('embeddings.pt')
links = np.load('data.npy', allow_pickle=True)
@st.experimental_memo
def image_search(query, top_k=10):
    # Embed the query text and rank all images by cosine similarity to it.
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
        _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]
def get_html(url_list):
    # Render the result URLs as a simple flex grid of <img> tags.
    html = "<div style='display: flex; flex-wrap: wrap; justify-content: space-evenly;'>"
    for url in url_list:
        html += f"<img style='height: 180px; margin: 2px;' src='{escape(url)}'>"
    html += "</div>"
    return html
description = '''
# Persian (fa) image search
- Enter your query and hit enter
- Note: we use a small set of images to keep this app close to real-time; search quality depends heavily on the size of the image database.

Built with the [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) model and 25k images from [Unsplash](https://unsplash.com/).
'''
def main():
    # Page-level CSS/HTML can be injected here (the string is currently empty).
    st.markdown('''
                ''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('Search Box (type in fa)', value='قطره های باران روی شیشه')
    c.text("It'll take about 30s to load all new images")
    if len(query) > 0:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)
if __name__ == '__main__':
main()