koclip / image2text.py
jaketae's picture
Merge branch 'main' of https://huggingface.co/spaces/flax-community/koclip
8375841
raw history blame
No virus
2.43 kB
import io

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def app(model_name):
    """Render the KoCLIP zero-shot image classification Streamlit page.

    The user supplies one image (file upload or URL) and a comma-separated
    list of candidate captions. On "Query", every caption is scored against
    the image and the softmax probabilities are drawn as a bar chart.

    Args:
        model_name: Checkpoint name; the model and processor are loaded via
            ``load_model(f"koclip/{model_name}")``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(_DESCRIPTION)

    query1 = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    query2 = st.text_input("or a URL to an image...")
    captions = st.text_input(
        "Enter candidate captions in comma-separated form.",
        value="귀여운 고양이,멋있는 강아지,트랜스포머",
    )

    if not st.button("질문 (Query)"):
        return

    image_url = query2.strip()
    if query1 is None and not image_url:
        st.error("Please upload an image or paste an image URL.")
        return

    caption_list = _parse_captions(captions)
    if not caption_list:
        st.error("Please enter at least one candidate caption.")
        return

    try:
        image = _load_image(query1, image_url)
    except requests.RequestException as err:
        # Must precede OSError: RequestException subclasses IOError/OSError.
        st.error(f"Could not download the image: {err}")
        return
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the image: {err}")
        return

    st.image(image)
    probs = _caption_probabilities(model, processor, image, caption_list)
    st.bar_chart(pd.DataFrame({"probability": probs}, index=caption_list))


# Kept at module level (not inside the function body) so source indentation
# never leaks into the rendered markdown.
_DESCRIPTION = """
This demonstration explores the capability of KoCLIP in the field of Zero-Shot Prediction. Given an image and a set of candidate captions, it predicts the most likely caption among those given.
KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and Korean caption annotations. Korean translations of the caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). Base model `koclip` uses `klue/roberta` as text encoder and `openai/clip-vit-base-patch32` as image encoder. Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
"""


def _parse_captions(raw_captions):
    """Split a comma-separated string into unique, non-empty, stripped captions.

    Empty entries (e.g. from a trailing comma) are dropped. Duplicates are
    removed, keeping first-seen order, because they would split probability
    mass and collide as chart labels.
    """
    stripped = (caption.strip() for caption in raw_captions.split(","))
    return list(dict.fromkeys(caption for caption in stripped if caption))


def _load_image(uploaded_file, image_url, timeout=10):
    """Open the user's image and return it as an RGB PIL image.

    The uploaded file takes precedence over the URL when both are given.

    Raises:
        requests.RequestException: the URL could not be fetched (including
            HTTP error statuses and timeouts after ``timeout`` seconds).
        OSError: the bytes are not a readable image
            (``PIL.UnidentifiedImageError``).
    """
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
    else:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        # .content is fully decoded (gzip/deflate), unlike the raw socket stream.
        image = Image.open(io.BytesIO(response.content))
    # PNG uploads may be RGBA, palette or grayscale. The image processor is
    # assumed to normalise exactly three channels, so force RGB.
    return image.convert("RGB")


def _caption_probabilities(model, processor, image, captions):
    """Return a 1-D numpy array of softmax probabilities, aligned with captions."""
    inputs = processor(text=captions, images=image, return_tensors="jax", padding=True)
    # Moves axis 1 to the end; presumably (N, C, H, W) -> (N, H, W, C) for the
    # Flax vision encoder. TODO confirm against the model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
    # Exactly one image was submitted, so row 0 holds its caption scores.
    return np.asarray(probs[0])