"""Streamlit demo: search Unsplash Lite images by free-form text with a CLIP model.

The user picks one of two pretrained text/image models and types a query.
The app then shows the best-matching images from a precomputed embedding index.
The index consists of info.csv, img_id.pkl and img_emb.npy in the working directory.
"""
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import AutoModel, AutoProcessor

# Models selectable in the UI; the first entry is the default radio choice.
MODEL_NAMES = [
    "Bingsu/clip-vit-base-patch32-ko",
    "Bingsu/vitB32_bert_ko_small_clip",
]
# Number of nearest images to retrieve. It is deliberately larger than the 6
# image slots, so hits missing from info.csv can be skipped while the grid still fills.
TOP_K = 15

st.title("My CLIP Model Test")
st.markdown(
    "[Unsplash Lite dataset](https://unsplash.com/data)에서 입력 텍스트와 가장 유사한 이미지를 검색합니다."
)


# NOTE(review): st.cache is deprecated in recent Streamlit releases. When the
# pinned Streamlit version allows it, migrate these loaders to st.cache_resource.
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_name: str):
    """Load a pretrained model in eval mode, plus its processor; cached per name.

    The spinner lives inside the cached function, so it is shown only on a
    cache miss, i.e. the first time each model is loaded.
    """
    with st.spinner("Loading model..."):
        model = AutoModel.from_pretrained(model_name).eval()
        processor = AutoProcessor.from_pretrained(model_name)
    return model, processor


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_search_index():
    """Load the precomputed image index once instead of on every rerun.

    Returns:
        url_by_id: ``photo_image_url`` values from info.csv keyed by
            ``photo_id``. When an id appears more than once, the first row wins.
        img_id: object loaded from img_id.pkl, indexed by the ``corpus_id``
            values that ``semantic_search`` returns.
        img_emb: array loaded from img_emb.npy. Presumably row i is the
            embedding of img_id[i] (TODO confirm against the indexing script).

    The cache is not invalidated when these files change on disk. Restart the
    app after regenerating them.
    """
    info = pd.read_csv("info.csv")
    url_by_id = (
        info.drop_duplicates("photo_id")
        .set_index("photo_id")["photo_image_url"]
    )
    # pickle.load can execute arbitrary code: only deploy a trusted, locally built file.
    with open("img_id.pkl", "rb") as f:
        img_id = pickle.load(f)
    img_emb = np.load("img_emb.npy")
    return url_by_id, img_id, img_emb


def iter_image_urls(hits, img_id, url_by_id):
    """Yield image URLs for ``hits`` in the given order, skipping unusable ones.

    Some images in the embedding index have no row, or an empty URL, in
    info.csv. Those hits are skipped rather than shown as broken images.
    Exhausting ``hits`` simply ends iteration; no exception reaches the caller.
    """
    for hit in hits:
        photo_id = img_id[hit["corpus_id"]]
        img_url = url_by_id.get(photo_id)
        if img_url is None or pd.isna(img_url):
            continue
        yield img_url


model_type = st.radio("Select model", MODEL_NAMES)
if model_type not in MODEL_NAMES:
    raise ValueError("Invalid model type")
model, processor = load_model(model_type)

url_by_id, img_id, img_emb = load_search_index()

text = st.text_input("Input Text", value="검은 고양이")
if not text.strip():
    # An empty query has no meaningful embedding; skip the search entirely.
    st.stop()

tokens = processor(text=text, return_tensors="pt")
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)

# semantic_search returns one hit list per query; exactly one query is sent here.
result = semantic_search(text_emb, img_emb, top_k=TOP_K)[0]

# Two rows of three columns. zip stops at whichever side runs out first, so a
# short list of usable hits leaves trailing columns empty instead of crashing.
columns = st.columns(3) + st.columns(3)
for col, img_url in zip(columns, iter_image_urls(result, img_id, url_by_id)):
    col.image(img_url, use_column_width=True)