clip-italian-demo / image2text.py
4rtemi5's picture
text improvements, CC is the default dataset, check if URLS can be loaded before displaying, small fixes
6101644
raw history blame
No virus
2.72 kB
import gc
import io

import jax
import pandas as pd
import requests
import streamlit as st
from jax import numpy as jnp
from PIL import Image

from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder, image_encoder
# Maximum number of label inputs offered by the "Number of labels" selector.
MAX_CAP = 4

_INTRO_MARKDOWN = """
### 👋 Ciao!

Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
image classification task!

🤌 Italian mode on! 🤌

For example, try to write "gatto" (cat) in the space for label 1 and "cane" (dog) in the space for label 2 and then run
"Classify"!
"""


def _columns(spec):
    """Create Streamlit columns using whichever API this Streamlit version provides.

    ``st.beta_columns`` was renamed to ``st.columns`` (and later removed), so the
    new name is preferred and the old one is only used as a fallback.
    """
    make_columns = getattr(st, "columns", None) or st.beta_columns
    return make_columns(spec)


def _load_image(url, timeout=10):
    """Download *url* and decode it as an RGB PIL image.

    Shows an error on the page and returns None when the URL cannot be fetched
    (network error, timeout, non-2xx status) or the payload is not an image.
    *timeout* (seconds) is an arbitrary bound so a dead host cannot hang the app.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Use the decoded body (``.content``) rather than ``.raw``, which skips
        # gzip/deflate content decoding.
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        # PIL raises UnidentifiedImageError (an OSError) for non-image payloads.
        st.error(f"Could not load an image from this URL: {err}")
        return None


def _classify(image, labels):
    """Return one softmax score per entry of *labels* for *image*."""
    model = get_model()
    tokenizer = get_tokenizer()

    text_embeds = []
    for label in labels:
        # text_encoder returns an iterable of embeddings (presumably a batch of
        # one per call); extend flattens them into one row per label.
        text_embeds.extend(text_encoder(label, model, tokenizer))
    text_embeds = jnp.array(text_embeds)

    transform = get_image_transform(model.config.vision_config.image_size)
    image_embed = image_encoder(transform(image), model)

    # Image-text dot products, turned into a distribution over the labels by a
    # softmax over the last axis. Presumably image_embed is (1, dim), hence [0].
    # TODO confirm whether the encoders L2-normalize (i.e. cosine similarity).
    probabilities = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))
    return probabilities[0]


def app():
    """Render the "From Image to Text" page.

    The user enters an image URL and up to ``MAX_CAP`` labels. On "Classify",
    the image is scored against the labels (zero-shot classification) and the
    scores are shown as a bar chart next to the image. Otherwise the image at
    the URL is simply previewed.
    """
    st.title("From Image to Text")
    st.markdown(_INTRO_MARKDOWN)

    image_url = st.text_input(
        "You can input the URL of an image",
        value="https://www.petdetective.it/wp-content/uploads/2016/04/gatto-toilette.jpg",
    ).strip()

    col1, col2 = _columns([3, 1])
    with col2:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1), index=1
        )
        compute = st.button("Classify")
    with col1:
        captions = []
        for idx in range(min(MAX_CAP, captions_count)):
            captions.append(st.text_input(f"Insert label {idx + 1}"))

    if compute:
        # Drop blank labels and duplicates (keeping first-seen order). A duplicate
        # would split the probability mass and repeat the chart's index.
        labels = list(dict.fromkeys(c.strip() for c in captions if c.strip()))
        if not labels or not image_url:
            st.error("Please choose one image and at least one label")
            return

        # Fetch the image before the expensive model work, so a bad URL fails fast.
        image = _load_image(image_url)
        if image is None:
            return

        with st.spinner("Computing..."):
            scores = _classify(image, labels)

        chart_data = pd.Series(scores, index=labels)
        col1, col2 = _columns(2)
        with col1:
            st.bar_chart(chart_data)
        with col2:
            st.image(image)
        gc.collect()
    elif image_url:
        image = _load_image(image_url)
        if image is not None:
            st.image(image)