import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import AutoModel, AutoProcessor

st.title("My CLIP Model Test")
st.markdown(
    "Searches the [Unsplash Lite dataset](https://unsplash.com/data) for the images most similar to the input text."
)


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(name: str):
    # AutoModel resolves both the plain CLIP checkpoints and the
    # VisionTextDualEncoder checkpoint, so a single cached loader suffices.
    with st.spinner("Loading model..."):
        model = AutoModel.from_pretrained(name).eval()
        processor = AutoProcessor.from_pretrained(name)
    return model, processor


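# Korean CLIP, the original OpenAI CLIP, and a Korean ViT-B/32 + BERT dual encoder.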
model_list = [
    "Bingsu/clip-vit-base-patch32-ko",
    "openai/clip-vit-base-patch32",
    "Bingsu/vitB32_bert_ko_small_clip",
]
model_type = st.radio("Select model", model_list)

model, processor = get_model(model_type)

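# Precomputed Unsplash Lite assets: photo metadata (info.csv), the photo id
# for each row of the embedding matrix (img_id.pkl), and the image
# embeddings themselves (img_emb.npy).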
info = pd.read_csv("info.csv")
with open("img_id.pkl", "rb") as f:
    img_id = pickle.load(f)
img_emb = np.load("img_emb.npy")

text = st.text_input("Input Text", value="검은 고양이")  # default query: "black cat"
tokens = processor(text=text, return_tensors="pt")

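# Embed the query text; no gradient tracking is needed at inference time.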
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)

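# Rank every image embedding against the text embedding (cosine similarity
# by default). top_k=16 over-fetches, since hits missing from info.csv are
# skipped below and only six images are displayed.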
result = semantic_search(text_emb, img_emb, top_k=16)[0]
_result = iter(result)


def get_url() -> str:
    # Some images are missing from the info.csv metadata.
    while True:
        r = next(_result)
        photo_id = img_id[r["corpus_id"]]
        target_series = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
        if len(target_series) == 0:
            continue
        img_url = target_series.iloc[0]
        return img_url


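# Two rows of three columns: fill a 3x2 grid with the top matches
# that have a known image URL.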
columns = st.columns(3) + st.columns(3)
for col in columns:
    img_url = get_url()
    col.image(img_url, use_column_width=True)