import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def app(model_name):
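    # load_model comes from the local utils module; it is assumed to return the
    # Flax KoCLIP model together with its matching Hugging Face processor.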
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        This demo explores KoCLIP's zero-shot prediction capability. It takes an image and a set of candidate captions from the user and predicts which caption is most likely to describe the image.

        KoCLIP is a retraining of OpenAI's CLIP model on 82,783 images from the [MSCOCO](https://cocodataset.org/#home) dataset paired with Korean caption annotations. The Korean translations of the captions were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). The base model, `koclip`, uses `klue/roberta` as its text encoder and `openai/clip-vit-base-patch32` as its image encoder. The larger model, `koclip-large`, pairs `klue/roberta` with the larger `google/vit-large-patch16-224` image encoder.
        """
    )

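    # Image source: either an image URL (defaults to a COCO val2017 image) or an uploaded file.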
    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    col1, col2 = st.columns([3, 1])

    with col2:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, 6), index=2
        )
        compute = st.button("Classify")

    with col1:
        captions = []
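        # Default Korean labels: "cute cat", "cool dog", "chubby hamster".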
        defaults = ["κ·€μ—¬μš΄ 고양이", "λ©‹μžˆλŠ” 강아지", "ν¬λ™ν¬λ™ν•œ ν–„μŠ€ν„°"]
        for idx in range(captions_count):
            value = defaults[idx] if idx < len(defaults) else ""
            captions.append(st.text_input(f"Insert label {idx+1}", value=value))

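    # Run classification once the user clicks the button, provided an image source is available.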
    if compute:
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            st.markdown("""---""")
            with st.spinner("Computing..."):
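                # Prefer the uploaded file; otherwise stream the image from the URL.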
                image_data = (
                    query2 if query2 is not None else requests.get(query1, stream=True).raw
                )
                image = Image.open(image_data)
                
                # Wrap each label in a prompt template; "이것은 {label}이닀." reads "This is a {label}."
                captions = [f"이것은 {caption.strip()}이닀." for caption in captions]
                inputs = processor(
                    text=captions, images=image, return_tensors="jax", padding=True
                )
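                # The processor returns pixel_values as NCHW; the Flax image encoder here
                # expects channels-last (NHWC), so move the channel axis to the end.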
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                outputs = model(**inputs)
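                # logits_per_image has shape (num_images, num_captions); softmax over the
                # caption axis yields one probability per label.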
                probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
                chart_data = pd.Series(np.asarray(probs[0]), index=captions)

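                # Show the input image next to a bar chart of the label probabilities.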
                col1, col2 = st.columns(2)
                with col1:
                    st.image(image)
                with col2:
                    st.bar_chart(chart_data)