koclip / image2text.py
jaketae's picture
feature: show prob scores as bar chart
7326e2c
raw history blame
No virus
1.72 kB
import io

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model
def app(model_name):
    """Render the zero-shot image classification page.

    The user provides an image (file upload or URL) and a comma-separated list
    of candidate captions. When the query button is pressed, the image is
    scored against every caption and the softmax probabilities are shown as a
    bar chart.

    Args:
        model_name: Name appended to ``"koclip/"`` and passed to
            ``load_model``, which returns the ``(model, processor)`` pair.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        Some text goes in here.
        """
    )

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    image_url = st.text_input("or a URL to an image...")
    captions_text = st.text_input(
        "Enter candidate captions in comma-separated form.",
        value="귀여운 고양이,멋있는 강아지,트랜스포머",
    )

    if not st.button("질문 (Query)"):
        return
    if uploaded_file is None and not image_url.strip():
        st.error("Please upload an image or paste an image URL.")
        return

    captions = _parse_captions(captions_text)
    if not captions:
        st.error("Please enter at least one candidate caption.")
        return

    try:
        image = _load_image(uploaded_file, image_url)
    except (requests.RequestException, OSError) as err:
        # OSError covers PIL.UnidentifiedImageError (bytes that are not an image).
        st.error(f"Could not load the image: {err}")
        return
    st.image(image)

    probs = _caption_probabilities(model, processor, image, captions)
    # One row per caption; captions are unique, so the index has no duplicates.
    df = pd.DataFrame({"probability": probs}, index=captions)
    st.bar_chart(df)


def _parse_captions(text):
    """Split comma-separated caption text into an ordered, de-duplicated list.

    Surrounding whitespace is stripped and empty entries are dropped, so
    ``"a, b,,a"`` yields ``["a", "b"]``.
    """
    stripped = (part.strip() for part in text.split(","))
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(part for part in stripped if part))


def _load_image(uploaded_file, image_url, timeout=10):
    """Open the uploaded file if present, otherwise download ``image_url``.

    Returns an RGB ``PIL.Image``. Raises ``requests.RequestException`` on
    network/HTTP failures and ``OSError`` when the bytes cannot be decoded
    as an image.
    """
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
    else:
        response = requests.get(image_url.strip(), timeout=timeout)
        response.raise_for_status()
        # Response.content is already gzip/deflate-decoded, unlike .raw.
        image = Image.open(io.BytesIO(response.content))
    # convert() forces the lazy decode, so errors surface here. It also
    # normalises RGBA/palette/grayscale PNGs to 3 channels, which the
    # processor presumably expects (TODO confirm against the processor).
    return image.convert("RGB")


def _caption_probabilities(model, processor, image, captions):
    """Return a NumPy array of softmax probabilities, one entry per caption."""
    inputs = processor(text=captions, images=image, return_tensors="jax", padding=True)
    # Reorder pixel axes with (0, 2, 3, 1), i.e. NCHW -> NHWC if the processor
    # emits channels-first; the model is fed channels-last.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # Softmax over axis 1 (captions); there is a single image, so take row 0.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
    return np.asarray(probs[0])