Spaces:
Running
Running
File size: 2,154 Bytes
6985586 a01e989 6985586 bca94f8 6c67d85 bca94f8 a01e989 6c67d85 a01e989 6c67d85 a01e989 6c67d85 a01e989 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import streamlit as st
from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder, image_encoder
from PIL import Image
from jax import numpy as jnp
import pandas as pd
def app():
    """Render the "From Image to Text" Streamlit page.

    The user uploads an image and types up to ``MAX_CAP`` candidate labels.
    When the "Compute" button is pressed, every non-empty label and the
    uploaded image are passed through the ``text_encoder`` / ``image_encoder``
    helpers imported by this module, the image embedding is multiplied with
    the transposed label embeddings, and the resulting scores are drawn as a
    bar chart next to the image.

    If no image was uploaded or every label is empty, an error message is
    shown instead of running the model.

    Returns:
        None. All output is rendered through Streamlit calls.
    """
    # ``st.beta_columns`` was a beta alias that newer Streamlit releases
    # replaced with ``st.columns``; prefer the stable name and fall back to the
    # beta one so the page works on both old and new Streamlit versions.
    make_columns = getattr(st, "columns", None) or st.beta_columns

    st.title("From Image to Text")
    # The literal is kept flush-left so the text handed to st.markdown is
    # exactly the same as before.
    st.markdown(
        """
### 👋 Ciao!
Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
image classification task!
🤌 Italian mode on! 🤌
"""
    )

    filename = st.file_uploader(
        "Choose an image from your computer", type=["jpg", "jpeg", "png"]
    )

    MAX_CAP = 4  # upper bound on how many label inputs can be shown

    labels_col, controls_col = make_columns([3, 1])

    # The controls are evaluated first because the number of text inputs in
    # the left column depends on the selected count.
    with controls_col:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1)
        )
        compute = st.button("Compute")

    with labels_col:
        # The selectbox only offers 1..MAX_CAP, so no extra clamping is needed.
        captions = [
            st.text_input(f"Insert label {idx+1}") for idx in range(captions_count)
        ]

    if not compute:
        return

    # Ignore label boxes the user left blank.
    captions = [c for c in captions if c != ""]

    if not captions or not filename:
        st.error("Please choose one image and at least one label")
        return

    with st.spinner("Computing..."):
        model = get_model()
        tokenizer = get_tokenizer()

        # ``extend`` (not ``append``) adds the items yielded by each encoder
        # result, so ``text_embeds`` stays a flat list of embeddings.
        # NOTE(review): presumably text_encoder returns one row per caption,
        # which is what the ``index=captions`` below relies on; confirm.
        text_embeds = []
        for caption in captions:
            text_embeds.extend(text_encoder(caption, model, tokenizer))
        text_embeds = jnp.array(text_embeds)

        image = Image.open(filename).convert("RGB")
        transform = get_image_transform(model.config.vision_config.image_size)
        image_embed = image_encoder(transform(image), model)

        # we could have a softmax here
        # NOTE(review): these are cosine similarities only if both encoders
        # return L2-normalised embeddings; verify in utils.
        cos_similarities = jnp.matmul(image_embed, text_embeds.T)

        # Row 0 holds the scores of the single uploaded image.
        chart_data = pd.Series(cos_similarities[0], index=captions)

        chart_col, image_col = make_columns(2)
        with chart_col:
            st.bar_chart(chart_data)
        with image_col:
            st.image(image)
|