# CLIPfa-Demo / app.py
# Author: SajjadAyoubi - commit 4f58d6c ("Update app.py"), 2.57 kB.
import streamlit as st
import pandas as pd
import numpy as np
from html import escape
import os
import torch
from transformers import RobertaModel, AutoTokenizer
@st.cache(show_spinner=False)
def load():
    """Load the CLIPfa text encoder, its tokenizer and the precomputed image index.

    Returns:
        A ``(text_encoder, tokenizer, links, image_embeddings)`` tuple where
        ``links`` is the array stored in ``data.npy`` (the image URLs shown in
        the gallery) and ``image_embeddings`` is the tensor stored in
        ``embeddings.pt``, always loaded onto the CPU.
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    # allow_pickle is required when the saved array has object dtype
    # (presumably Python str URLs - TODO confirm). Unpickling can execute
    # code, so only ship a data.npy from a trusted source.
    links = np.load('data.npy', allow_pickle=True)
    # The text encoder is never moved off the CPU, so pin the image embeddings
    # to the CPU as well. Without map_location, a tensor saved from a GPU would
    # fail to load on a CPU-only host, or mismatch devices in cosine_similarity.
    # torch.load also unpickles, so the same trust caveat applies.
    image_embeddings = torch.load('embeddings.pt', map_location='cpu')
    return text_encoder, tokenizer, links, image_embeddings


# Load once at import time; st.cache keeps the results across Streamlit reruns.
text_encoder, tokenizer, links, image_embeddings = load()
def get_html(url_list, height=224):
    """Build an HTML gallery that shows every image URL in a wrapping flex row.

    Args:
        url_list: Iterable of image URLs. Each URL is HTML-escaped (quotes
            included) before it is placed in the single-quoted ``src``
            attribute, so a URL cannot break out of the tag.
        height: Rendered height of every image, in pixels.

    Returns:
        One ``<div>`` string containing an ``<img>`` tag per URL, in input order.
    """
    container_style = (
        'margin-top: 20px; max-width: 1200px; display: flex; '
        'flex-wrap: wrap; justify-content: space-evenly'
    )
    # join builds the string in one pass instead of re-copying it per image.
    images = ''.join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        for url in url_list
    )
    return f"<div style='{container_style}'>{images}</div>"
@st.cache(show_spinner=False)
def image_search(query, top_k=8):
    """Return the ``top_k`` image links most similar to a text query.

    The query is tokenized and embedded with the module-level ``tokenizer``
    and ``text_encoder`` (pooler output). It is then scored by cosine
    similarity against the module-level ``image_embeddings``, and the matching
    entries of ``links`` are returned, best match first.

    Args:
        query: Free-text search string.
        top_k: Maximum number of results. Clamped to the number of indexed
            images; values <= 0 yield an empty list.

    Returns:
        A list of at most ``top_k`` entries taken from ``links``.
    """
    with torch.no_grad():
        inputs = tokenizer(query, return_tensors='pt')
        text_embedding = text_encoder(**inputs).pooler_output
        # (1, D) vs (N, D) broadcasts to one score per image.
        scores = torch.cosine_similarity(text_embedding, image_embeddings)
    # topk selects only the best k instead of sorting every score. k is
    # clamped because topk raises when k exceeds the number of candidates.
    k = max(0, min(top_k, scores.numel()))
    best = torch.topk(scores, k).indices.tolist()
    return [links[i] for i in best]
# Markdown blurb rendered in the sidebar by main() via st.sidebar.markdown.
description = "\n# Semantic image search :)\n"
# Page-level CSS injected once per run: widens the content area, lays out
# radio buttons horizontally, trims top padding, narrows the sidebar, and
# hides Streamlit's main menu and footer.
_PAGE_CSS = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''


def main():
    """Render the search page: custom CSS, sidebar blurb, query box and results.

    The query text is stripped of surrounding whitespace. If anything remains,
    it is sent to ``image_search`` and the hits are rendered as an HTML gallery
    via ``get_html``.
    """
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Centre the search box: it sits in the wide middle column of a 1:3:1 split.
    _, center, _ = st.columns((1, 3, 1))
    query = center.text_input('', value='clouds at sunset').strip()
    # Empty or whitespace-only input has nothing to search for; skip the model.
    if not query:
        return
    results = image_search(query)
    st.markdown(get_html(results), unsafe_allow_html=True)


if __name__ == '__main__':
    main()