# clip-italian-demo / image2text.py
import streamlit as st
from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder, image_encoder
from PIL import Image
from jax import numpy as jnp
from io import BytesIO
import pandas as pd
import requests
import jax
import gc
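
# A browser-like User-Agent so image hosts (e.g. Wikimedia) do not reject the request.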
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582"
}
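

# Streamlit page: zero-shot classification of an image against user-provided labels
# using the Italian CLIP model.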
def app():
    st.title("From Image to Text")
    st.markdown(
        """
        ### 👋 Ciao!

        Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
        image classification task!

        🤌 Italian mode on! 🤌

        For example, try typing "gatto" (cat) in the space for label1 and "cane" (dog) in the space for label2 and click
        "classify"!
        """
    )
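
    # The image to classify is fetched from a user-supplied URL (a cat photo is pre-filled).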
    image_url = st.text_input(
        "You can input the URL of an image",
        value="https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Ragdoll%2C_blue_mitted.JPG/1280px-Ragdoll%2C_blue_mitted.JPG",
    )
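
    # Up to MAX_CAP labels; the narrow right column holds the label-count selector
    # and the CLASSIFY button, the wide left column holds the label inputs.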
    MAX_CAP = 4

    col1, col2 = st.columns([0.75, 0.25])

    with col2:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1), index=1
        )
        compute = st.button("CLASSIFY")
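
    # One text input per requested label (left column).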
    with col1:
        captions = list()
        for idx in range(min(MAX_CAP, captions_count)):
            captions.append(st.text_input(f"Insert label {idx+1}"))
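
    # Run the classification only when the CLASSIFY button has been pressed.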
    if compute:
        captions = [c for c in captions if c != ""]

        if not captions or not image_url:
            st.error("Please choose one image and at least one label")
        else:
            with st.spinner("Computing..."):
                model = get_model()
                tokenizer = get_tokenizer()
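
                # Embed each label with the text encoder.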
                text_embeds = list()
                for i, c in enumerate(captions):
                    text_embeds.extend(text_encoder(c, model, tokenizer)[0])
                text_embeds = jnp.array(text_embeds)
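
                # Download the image and embed it with the vision encoder.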
                response = requests.get(image_url, headers=headers, stream=True)
                image = Image.open(BytesIO(response.content)).convert("RGB")
                transform = get_image_transform(model.config.vision_config.image_size)
                image_embed, _ = image_encoder(transform(image), model)
                # Softmax over the image-text similarity scores turns them into per-label probabilities.
                cos_similarities = jax.nn.softmax(
                    jnp.matmul(image_embed, text_embeds.T)
                )
                chart_data = pd.Series(cos_similarities[0], index=captions)
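
                # Show the per-label probabilities next to the image.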
                col1, col2 = st.columns(2)
                with col1:
                    st.bar_chart(chart_data)
                with col2:
                    st.image(image, use_column_width=True)

        gc.collect()
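
    # No classification requested yet: just preview the image at the given URL.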
    elif image_url:
        response = requests.get(image_url, headers=headers, stream=True)
        image = Image.open(BytesIO(response.content)).convert("RGB")
        st.image(image)