IRRA / app.py
grostaco's picture
initial commit
ba2ab36
raw
history blame
1.18 kB
import streamlit as st
from lib.utils.model import get_model, get_similarities
from PIL import Image
# Streamlit front end for IRRA text-to-image retrieval: the user types a
# description and uploads candidate images, and the app lists the uploaded
# images ordered by their similarity score against that description.
st.title('IRRA Text-To-Image-Retrival')

st.header('Inputs')
caption = st.text_input('Description Input')
images = st.file_uploader('Upload images', accept_multiple_files=True)
# With accept_multiple_files=True the uploader hands back a list that is empty
# (rather than None) while nothing is uploaded, so a truthiness test is what
# actually skips the preview; it also stays safe should the value be None.
if images:
    st.image(images)  # type: ignore

st.header('Options')
st.subheader('Ranks')
# Upper bound on how many of the best-scoring images are shown in the results.
ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed', value=5)

# Matching needs both a description and at least one image; keep the button
# disabled until both are present.
button = st.button('Match most similar', disabled=not images or caption == '')

if button:
    st.header('Results')

    # NOTE(review): get_model() is invoked on every click; whether it caches
    # internally is not visible from this file.
    with st.spinner('Loading model'):
        model = get_model()
    st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')

    with st.spinner('Computing and ranking similarities'):
        similarities = get_similarities(caption, images, model)

    # Order image indices from most to least similar and keep the first `ranks`.
    # assumes `similarities` carries a leading singleton dimension for the
    # single caption, which squeeze(0) drops — TODO confirm against get_similarities
    indices = similarities.argsort(descending=True).squeeze(0).cpu().tolist()[:ranks]

    # One row per result: rank label on the left, matching image on the right.
    for i, idx in enumerate(indices):
        c1, c2 = st.columns(2)
        with c1:
            st.text(f'Rank {i + 1}')
        with c2:
            st.image(images[idx])