koclip / image2text.py
jaketae's picture
feature: add intro page, cleanup descriptions
a811816
from io import BytesIO

import jax
import jax.numpy as jnp
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model
def _load_image(url, upload, timeout=10):
    """Open the user's image as an RGB ``PIL.Image``.

    Args:
        url: Image URL typed by the user; only used when ``upload`` is None.
        upload: File object returned by ``st.file_uploader``, or None.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        requests.RequestException: if the URL cannot be fetched or the
            server answers with an HTTP error status.
        OSError: if the bytes cannot be decoded as an image (PIL's
            ``UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    if upload is not None:
        image = Image.open(upload)
    else:
        response = requests.get(url.strip(), timeout=timeout)
        response.raise_for_status()
        # ``response.content`` is decoded from gzip/deflate transfer
        # encodings, whereas ``response.raw`` hands back the undecoded body.
        image = Image.open(BytesIO(response.content))
    # Palette / grayscale / RGBA images (common for PNG uploads) would
    # otherwise reach the processor with a channel count other than 3;
    # the processor presumably expects RGB input.
    return image.convert("RGB")


def _build_prompts(captions):
    """Wrap each non-blank caption in the Korean prompt template.

    Blank entries (e.g. the empty text boxes shown when more labels than
    defaults are requested) are dropped so they do not turn into an empty
    candidate prompt.
    """
    return [
        f"이것은 {caption.strip()}이다."
        for caption in captions
        if caption.strip()
    ]


def app(model_name):
    """Render the zero-shot image classification demo page.

    The user supplies an image (URL or upload) and one to five candidate
    captions; the model scores every caption against the image and the
    scores are drawn as a bar chart next to the image.

    Args:
        model_name: Checkpoint name under the ``koclip/`` namespace, passed
            to ``utils.load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")
    # ``st.beta_columns`` was promoted to ``st.columns`` and the beta alias
    # was later removed; only fall back to it on old Streamlit releases.
    columns = getattr(st, "columns", None) or st.beta_columns

    st.title("Zero-shot Image Classification")
    # The blank line before ``---`` matters: without it Markdown treats the
    # paragraph as a setext heading instead of drawing a horizontal rule.
    st.markdown(
        """
This demo explores KoCLIP's zero-shot prediction capabilities. The model takes an image and a list of candidate captions from the user and predicts the most likely caption that best describes the given image.

---
"""
    )
    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    col1, col2 = columns([3, 1])
    with col2:
        captions_count = st.selectbox("Number of labels", options=range(1, 6), index=2)
        normalize = st.checkbox("Apply Softmax")
        compute = st.button("Classify")
    with col1:
        captions = []
        defaults = ["귀여운 고양이", "멋있는 강아지", "포동포동한 햄스터"]
        for idx in range(captions_count):
            value = defaults[idx] if idx < len(defaults) else ""
            captions.append(st.text_input(f"Insert caption {idx+1}", value=value))

    if not compute:
        return
    if query2 is None and not query1.strip():
        st.error("Please upload an image or paste an image URL.")
        return
    prompts = _build_prompts(captions)
    if not prompts:
        st.error("Please enter at least one caption.")
        return

    st.markdown("""---""")
    with st.spinner("Computing..."):
        try:
            image = _load_image(query1, query2)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")
            return
        inputs = processor(
            text=prompts, images=image, return_tensors="jax", padding=True
        )
        # Moves axis 1 to the end, i.e. (N, C, H, W) -> (N, H, W, C) if the
        # processor emits channels-first pixels; the model presumably expects
        # channels-last input.
        inputs["pixel_values"] = jnp.transpose(
            inputs["pixel_values"], axes=[0, 2, 3, 1]
        )
        outputs = model(**inputs)
        logits = outputs.logits_per_image
        if normalize:
            name = "normalized prob"
            # Softmax over axis 1 (the caption axis) for each image row.
            probs = jax.nn.softmax(logits, axis=1)
        else:
            # NOTE(review): presumably temperature-scaled cosine similarity;
            # confirm against the model implementation behind load_model.
            name = "cosine sim"
            probs = logits
        # Row 0 holds the single image's scores against every prompt;
        # device_get yields a host array that pandas consumes directly.
        chart_data = pd.Series(jax.device_get(probs[0]), index=prompts, name=name)

    col1, col2 = columns(2)
    with col1:
        st.image(image)
    with col2:
        st.bar_chart(chart_data)