# gzomer's picture
# Update app.py
# 34a1e3c
import json
import numpy as np
import streamlit as st
from st_clickable_images import clickable_images
from clip_multilingual.search import MultiLingualSearch
from clip_multilingual.models import Tokenizer
@st.cache(
    suppress_st_warning=True,
    hash_funcs={
        # Hash every Tokenizer as the constant None instead of letting
        # Streamlit inspect it (presumably it cannot be hashed by default --
        # TODO confirm).
        Tokenizer: lambda _: None
    }
)
def load_model():
    """Build the search object from the data files in the working directory.

    Reads ``embeddings.npy`` with ``np.load`` and ``urls.json`` with
    ``json.load``, then passes both to ``MultiLingualSearch``. The function is
    wrapped in ``st.cache``, so Streamlit may reuse a previous result instead
    of running the body again.

    Returns:
        MultiLingualSearch: instance constructed from the loaded embeddings
        and URLs.

    Raises:
        OSError: if either data file cannot be opened.
        json.JSONDecodeError: if ``urls.json`` is not valid JSON.
    """
    unsplash_base_folder = './'
    all_embeddings = np.load(f'{unsplash_base_folder}/embeddings.npy')
    # JSON text is UTF-8; do not depend on the platform's locale encoding,
    # which would break on non-ASCII content under e.g. cp1252.
    with open(f'{unsplash_base_folder}/urls.json', encoding='utf-8') as f:
        all_urls = json.load(f)
    return MultiLingualSearch(all_embeddings, all_urls)
# Search object used for every query further down; load_model is wrapped in
# st.cache, so Streamlit may serve a cached instance on reruns.
semantic_search = load_model()

# Markdown rendered in the sidebar. The text-encoder name is spelled
# "XLM-RoBERTa" to match the model it links to (xlm-roberta-base).
description = '''
# Multilingual Semantic Search
**Search images in 100 languages (list [here](https://github.com/pytorch/fairseq/blob/main/examples/xlmr/README.md#introduction)) powered by [MultiLingual CLIP](https://huggingface.co/gzomer/clip-multilingual).**
MultiLingual CLIP is a custom model built using OpenAI's [CLIP](https://openai.com/blog/clip/) and [XLM-RoBERTa](https://huggingface.co/xlm-roberta-base) models, trained using 16 [Habana](https://habana.ai/) accelerators with PyTorch Lightning, Distributed Data Parallel, Mixed precision and using [COCO](https://cocodataset.org/) and [Google Conceptual Captions](https://ai.google.com/research/ConceptualCaptions) as training datasets.
See [repo](https://github.com/gzomer/clip-multilingual) and [model](https://huggingface.co/gzomer/clip-multilingual) for more info.
'''
st.sidebar.markdown(description)
# Sample queries offered in the sidebar.
# Each entry is a 3-tuple: (language name, query text, English meaning).
# The English meaning is None when the query is already in English.
examples = [
    ('chinese', '家人在一起', 'family'),
    ('hindi', 'विद्यालय में', 'at school'),
    ('arabic', 'البنايات', 'buildings'),
    ('swahili', 'watu wanaofanya kazi', 'people working'),
    ('japanese', '美しい空', 'beautiful sky'),
    ('portuguese', 'praias bonitas', 'beautiful beaches'),
    ('greek', 'νόστιμα ζυμαρικά', 'delicious pasta'),
    ('armenian', 'գրասենյակներ', 'offices'),
    ('zulu', 'izilwane ezinhle', 'beautiful animals'),
    ('amharic', 'ደስተኛ ሰዎች', 'happy people'),
    ('urdu', 'کتب خانہ', 'library'),
    ('georgian', 'ლამაზი მანქანები', 'nice cars'),
    ('german', 'schöne Blumen', 'nice flowers'),
    ('french', 'les gens en vacances', 'people on vacations'),
    ('spanish', 'gatos y perros', 'cats and dogs'),
    ('english', 'people having fun', None),
]
# Render one caption and one button per sample query in the sidebar.
for language, phrase, meaning in examples:
    caption = language.title()
    if meaning:
        # Append the English meaning when one is provided.
        caption = f'{caption} (means {meaning})'
    st.sidebar.text(caption)
    # Clicking a sample stores its text as the current query in session state.
    if st.sidebar.button(phrase):
        st.session_state.query = phrase
# Three columns with relative widths 1:3:1; only the middle one is used.
_, c, _ = st.columns((1, 3, 1))

# Pre-fill the box with the query stored in session state if there is one,
# otherwise with the text of the first sample query.
initial_query = (
    st.session_state["query"]
    if "query" in st.session_state
    else examples[0][1]
)
query = c.text_input("", value=initial_query)
# Run the search and show the hits only when the query box is non-empty.
if query:
    results = semantic_search.search(query)

    # Each hit is read through its 'image' and 'prob' keys.
    image_sources = [hit['image'] for hit in results]
    hover_titles = [f'Prob: {hit["prob"]}' for hit in results]

    # CSS for the container: a centred, wrapping flex row.
    gallery_style = {
        "display": "flex",
        "justify-content": "center",
        "flex-wrap": "wrap",
    }
    # CSS applied to each image.
    thumbnail_style = {"margin": "2px", "height": "200px"}

    clicked = clickable_images(
        image_sources,
        titles=hover_titles,
        div_style=gallery_style,
        img_style=thumbnail_style,
    )