import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def app(model_name):
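    # load_model comes from utils.py in this repo; it is assumed to return the
    # Flax KoCLIP model together with its matching processor for the given hub id.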
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        This demo showcases KoCLIP's zero-shot prediction capability: given an image and a set of candidate captions, it predicts which caption best describes the image.

        KoCLIP is a retraining of OpenAI's CLIP model on 82,783 images from the [MSCOCO](https://cocodataset.org/#home) dataset paired with Korean caption annotations. The Korean translations of the captions were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). The base model `koclip` uses `klue/roberta` as its text encoder and `openai/clip-vit-base-patch32` as its image encoder, while the larger model `koclip-large` pairs `klue/roberta` with the bigger `google/vit-large-patch16-224` image encoder.
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    captions = st.text_input(
        "Enter candidate captions in comma-separated form.",
        value="κ·€μ—¬μš΄ 고양이,λ©‹μžˆλŠ” 강아지,ν¬λ™ν¬λ™ν•œ ν–„μŠ€ν„°",
    )
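    # The default captions above are Korean for "cute cat", "cool dog", and "chubby hamster".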

    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
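            # Prefer the uploaded file if present; otherwise stream the image from
            # the URL as a file-like object that PIL can open directly.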
            image_data = (
                query2 if query2 is not None else requests.get(query1, stream=True).raw
            )
            image = Image.open(image_data)
            st.image(image)
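            # Wrap each caption in a Korean prompt template, "이것은 {caption}이다."
            # ("This is a {caption}."), in the spirit of CLIP-style prompt engineering.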
            captions = [f"이것은 {caption.strip()}이닀." for caption in captions.split(",")]
            inputs = processor(
                text=captions, images=image, return_tensors="jax", padding=True
            )
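            # The processor returns pixel_values in NCHW order; the Flax vision
            # encoder expects NHWC, so move the channel axis to the end.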
            inputs["pixel_values"] = jnp.transpose(
                inputs["pixel_values"], axes=[0, 2, 3, 1]
            )
            outputs = model(**inputs)
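            # logits_per_image has shape (1, num_captions); a softmax over the
            # caption axis turns the logits into per-caption probabilities.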
            probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
            score_dict = {
                captions[idx]: float(prob) for idx, prob in enumerate(probs[0])
            }
            df = pd.DataFrame(
                list(score_dict.values()), index=list(score_dict.keys())
            )
            st.bar_chart(df)
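

# Hypothetical usage sketch (not part of the original file): this page function is
# assumed to be called from the demo's entry point with one of the model names
# described in the markdown above, e.g. "koclip" or "koclip-large".
#
#     import streamlit as st
#     from zero_shot import app  # assuming this module is saved as zero_shot.py
#
#     model_name = st.sidebar.selectbox("Model", ["koclip", "koclip-large"])
#     app(model_name)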