koclip / image2text.py
jaketae's picture
feature: add intro page, cleanup descriptions
a811816
raw
history blame contribute delete
No virus
2.8 kB
import io

import jax
import jax.numpy as jnp
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model
# Seconds to wait for a remote image before giving up.
_REQUEST_TIMEOUT = 10


def _load_image(url, uploaded_file):
    """Return the query image as an RGB PIL image.

    An uploaded file takes precedence; ``url`` is only fetched when
    ``uploaded_file`` is ``None``.

    Raises:
        requests.RequestException: If the URL is invalid, times out, or the
            server answers with an error status.
        OSError: If the data cannot be decoded as an image.
    """
    if uploaded_file is not None:
        source = uploaded_file
    else:
        response = requests.get(url, timeout=_REQUEST_TIMEOUT)
        response.raise_for_status()
        # Buffer the decoded body so PIL gets a seekable stream even when the
        # server applies a transfer encoding such as gzip.
        source = io.BytesIO(response.content)
    # Uploads may be PNGs with an alpha channel or a palette; normalise to
    # three channels. Assumes the processor expects RGB input - TODO confirm.
    return Image.open(source).convert("RGB")


def app(model_name):
    """Render the zero-shot image classification page.

    The page collects an image (URL or upload) and up to five candidate
    captions, scores every caption against the image and shows the image
    next to a bar chart of the scores.

    Args:
        model_name: Name appended to ``koclip/`` to form the identifier
            passed to ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    # The blank line before `---` keeps it a horizontal rule; directly under
    # a paragraph, Markdown would treat it as a heading underline.
    st.markdown(
        """
        This demo explores KoCLIP's zero-shot prediction capabilities. The model takes an image and a list of candidate captions from the user and predicts the most likely caption that best describes the given image.

        ---
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    # `beta_columns` is deprecated in newer Streamlit releases; prefer the
    # stable name when it exists and fall back for older installations.
    make_columns = getattr(st, "columns", None) or st.beta_columns

    col1, col2 = make_columns([3, 1])
    with col2:
        captions_count = st.selectbox("Number of labels", options=range(1, 6), index=2)
        normalize = st.checkbox("Apply Softmax")
        compute = st.button("Classify")
    with col1:
        captions = []
        defaults = ["귀여운 고양이", "멋있는 강아지", "포동포동한 햄스터"]
        for idx in range(captions_count):
            value = defaults[idx] if idx < len(defaults) else ""
            captions.append(st.text_input(f"Insert caption {idx+1}", value=value))

    if not compute:
        return

    if not (query1 or query2):
        st.error("Please upload an image or paste an image URL.")
        return

    # Ignore caption boxes that were left blank instead of scoring an empty
    # prompt.
    labels = [caption.strip() for caption in captions]
    labels = [label for label in labels if label]
    if not labels:
        st.error("Please enter at least one caption.")
        return

    st.markdown("""---""")
    with st.spinner("Computing..."):
        try:
            image = _load_image(query1, query2)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")
            return

        # Wrap every label in a Korean sentence template before scoring.
        prompts = [f"이것은 {label}이다." for label in labels]
        inputs = processor(
            text=prompts, images=image, return_tensors="jax", padding=True
        )
        # Move axis 1 to the end; presumably channels-first -> channels-last
        # for the Flax model - verify against the model's expected layout.
        inputs["pixel_values"] = jnp.transpose(
            inputs["pixel_values"], axes=[0, 2, 3, 1]
        )
        outputs = model(**inputs)

        if normalize:
            name = "normalized prob"
            probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
        else:
            name = "cosine sim"
            probs = outputs.logits_per_image
        # Only one image is passed, so row 0 holds the score of every prompt.
        chart_data = pd.Series(probs[0], index=prompts, name=name)

    col1, col2 = make_columns(2)
    with col1:
        st.image(image)
    with col2:
        st.bar_chart(chart_data)