File size: 2,612 Bytes
84c806e
 
 
 
 
6525b03
 
 
84c806e
 
 
6525b03
 
84c806e
6525b03
 
84c806e
6525b03
 
 
84c806e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6525b03
 
84c806e
 
 
 
 
6525b03
 
84c806e
 
 
6525b03
 
 
84c806e
 
 
 
 
 
 
 
 
6525b03
 
 
84c806e
 
 
 
 
6525b03
 
 
 
 
 
84c806e
 
6525b03
84c806e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os

import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a ``num_rows`` x ``num_cols`` grid of equal tiles.

    Args:
        im: A PIL image or anything ``np.array`` accepts; the first two axes
            are treated as height and width.
        num_rows: Number of tile rows.
        num_cols: Number of tile columns.

    Returns:
        A list of ``num_rows * num_cols`` numpy arrays in row-major order
        (left to right, then top to bottom). Every tile has shape
        ``(H // num_rows, W // num_cols, ...)``; leftover pixels at the bottom
        and right edges that do not fill a whole tile are dropped.

    Raises:
        ValueError: If ``num_rows`` or ``num_cols`` is not positive, or the
            image is too small to give every tile at least one pixel.
    """
    im = np.array(im)
    if num_rows < 1 or num_cols < 1:
        raise ValueError("num_rows and num_cols must be positive")
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    if row_size == 0 or col_size == 0:
        raise ValueError(
            f"image of size {im.shape[1]}x{im.shape[0]} is too small for a "
            f"{num_rows}x{num_cols} grid"
        )
    return [
        im[r * row_size : (r + 1) * row_size, c * col_size : (c + 1) * col_size]
        for r in range(num_rows)
        for c in range(num_cols)
    ]


# def split_image(X):
#     num_rows = X.shape[0] // 224
#     num_cols = X.shape[1] // 224
#     Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
#     patches = []
#     for j in range(num_rows):
#         for i in range(num_cols):
#             patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
#     return patches


def app(model_name):
    """Render the "Most Relevant Part of Image" Streamlit page.

    The user supplies an image (URL or upload) and a text query. The image is
    split into a 3x3 grid with ``split_image``, every tile is scored against
    the query with the model loaded from ``koclip/<model_name>``, and the
    tiles are displayed from the highest to the lowest score.

    Args:
        model_name: Checkpoint name, appended to ``"koclip/"`` and passed to
            ``load_model``.
    """
    import io  # only needed to buffer a downloaded image in memory

    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can
        1) Upload an image
        2) Explain a part of the image in text
        Which will yield the most relevant image tile from a 3x3 grid of the image 
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    if not st.button("질문 (Query)"):
        return
    if not any([query1, query2]):
        st.error("Please upload an image or paste an image URL.")
        return

    # An uploaded file takes precedence over the URL field.
    if query2 is not None:
        image_data = query2
    else:
        # A timeout keeps an unresponsive host from hanging the app, and
        # downloading the whole body (instead of reading `.raw`) lets requests
        # undo any Content-Encoding such as gzip before PIL sees the bytes.
        try:
            response = requests.get(query1, timeout=30)
            response.raise_for_status()
        except requests.RequestException as err:
            st.error(f"Could not download the image: {err}")
            return
        image_data = io.BytesIO(response.content)

    try:
        # Normalize to 3 channels: PNG uploads may be RGBA, palette or grayscale.
        image = Image.open(image_data).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the image: {err}")
        return
    st.image(image)

    images = split_image(image)

    inputs = processor(
        text=captions, images=images, return_tensors="jax", padding=True
    )
    # Reorder axes (N, C, H, W) -> (N, H, W, C); presumably the processor
    # returns channels-first pixels while the model expects channels-last.
    inputs["pixel_values"] = jnp.transpose(
        inputs["pixel_values"], axes=[0, 2, 3, 1]
    )
    outputs = model(**inputs)
    # Rows of logits_per_image correspond to tiles (they index `images`
    # below); softmax over axis 0 normalizes scores across tiles. Column 0 is
    # the single query text.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
    scores = np.asarray(probs[:, 0])
    # Stable sort of negated scores: descending order, ties keep tile order.
    for idx in np.argsort(-scores, kind="stable"):
        st.text(f"Score: {scores[idx]:.3f}")
        st.image(images[idx])