import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image
from utils import load_model
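
# Illustrative sketch (not used by the app below): CLIP-style zero-shot
# classification boils down to a softmax over scaled cosine similarities
# between one image embedding and the candidate text embeddings. The helper
# name and the fixed logit scale are assumptions for illustration, not
# KoCLIP's actual internals; the real scores come from `model(**inputs)`
# inside `app` below.
def _zero_shot_scores(image_emb, text_embs, logit_scale=100.0):
    # Normalize so that dot products become cosine similarities.
    image_emb = image_emb / jnp.linalg.norm(image_emb)
    text_embs = text_embs / jnp.linalg.norm(text_embs, axis=-1, keepdims=True)
    # Scaled similarities play the role of `logits_per_image` in the model output.
    return jax.nn.softmax(logit_scale * text_embs @ image_emb)
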
def app(model_name):
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
    This demo explores KoCLIP's zero-shot prediction capability. Given an image
    and a set of candidate captions, it predicts which caption is the most
    likely label for the image.

    KoCLIP is a retraining of OpenAI's CLIP model on 82,783 images from the
    [MSCOCO](https://cocodataset.org/#home) dataset with Korean caption
    annotations. Korean translations of the caption annotations were obtained
    from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence).
    The base model, `koclip`, uses `klue/roberta` as its text encoder and
    `openai/clip-vit-base-patch32` as its image encoder. The larger model,
    `koclip-large`, uses `klue/roberta` as its text encoder and the bigger
    `google/vit-large-patch16-224` as its image encoder.
    """
    )
    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter candidate captions in comma-separated form.",
        value="귀여운 고양이,멋있는 강아지,포동포동한 햄스터",
    )
    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            # Prefer the uploaded file; otherwise stream the image from the URL.
            image_data = (
                query2 if query2 is not None else requests.get(query1, stream=True).raw
            )
            image = Image.open(image_data)
            st.image(image)

            # Wrap each candidate label in a prompt template ("이것은 {caption}이다.",
            # i.e. "This is a {caption}."), mirroring CLIP-style prompt engineering.
            # captions = [caption.strip() for caption in captions.split(",")]
            captions = [f"이것은 {caption.strip()}이다." for caption in captions.split(",")]

            inputs = processor(
                text=captions, images=image, return_tensors="jax", padding=True
            )
            # The processor emits NCHW pixel values; the Flax vision encoder
            # expects NHWC, so move the channel axis last.
            inputs["pixel_values"] = jnp.transpose(
                inputs["pixel_values"], axes=[0, 2, 3, 1]
            )

            outputs = model(**inputs)
            # Softmax over the image-to-text logits gives one probability per caption.
            probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
            score_dict = {captions[idx]: prob for idx, prob in enumerate(probs[0])}
            df = pd.DataFrame(list(score_dict.values()), index=list(score_dict.keys()))
            st.bar_chart(df)
            # for idx, prob in sorted(enumerate(probs[0]), key=lambda x: x[1], reverse=True):
            #     st.text(f"Score: `{prob}`, {captions[idx]}")
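
# A minimal standalone entry point, assuming this module is run directly with
# `streamlit run`. In the actual Space, a separate launcher presumably calls
# app() with the model chosen by the user; the model name here is an
# assumption for illustration (the description above names `koclip` and
# `koclip-large`).
if __name__ == "__main__":
    app("koclip")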