Bingsu's picture
fix: revert large models
2949ccd
raw
history blame contribute delete
No virus
2.17 kB
import pickle
import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import AutoModel, AutoProcessor
# Page header. The Korean description reads: "Searches the Unsplash Lite
# dataset for the images most similar to the input text."
st.title("My CLIP Model Test")
st.markdown(
    "[Unsplash Lite dataset](https://unsplash.com/data)에서 입력 텍스트와 가장 유사한 이미지를 검색합니다."
)
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_dual_encoder_model(name: str):
    """Load the checkpoint ``name`` and its processor, showing a custom spinner.

    The model is switched to eval mode before being returned. ``st.cache`` keeps
    the pair across script reruns; ``allow_output_mutation=True`` skips hashing
    of the returned objects.

    Returns:
        A ``(model, processor)`` tuple.
    """
    with st.spinner("Loading model..."):
        encoder = AutoModel.from_pretrained(name)
        encoder.eval()  # same object .eval() would return; inference-only use
        text_processor = AutoProcessor.from_pretrained(name)
    return encoder, text_processor
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_clip_model(name: str):
    """Return ``(model, processor)`` for the checkpoint ``name``.

    The model is loaded first and put in eval mode, then the processor is
    loaded, all under a "Loading model..." spinner. Results are cached across
    Streamlit reruns by ``st.cache``.
    """
    with st.spinner("Loading model..."):
        # Tuple elements evaluate left to right: model first, then processor.
        loaded = (
            AutoModel.from_pretrained(name).eval(),
            AutoProcessor.from_pretrained(name),
        )
    return loaded
# Checkpoints offered in the UI. "Bingsu/vitB32_bert_ko_small_clip" is loaded
# through get_dual_encoder_model; the other two through get_clip_model.
model_list = [
    "Bingsu/clip-vit-base-patch32-ko",
    "openai/clip-vit-base-patch32",
    "Bingsu/vitB32_bert_ko_small_clip",
]
model_type = st.radio("Select model", model_list)
if model_type == "Bingsu/vitB32_bert_ko_small_clip":
    model, processor = get_dual_encoder_model(model_type)
else:
    model, processor = get_clip_model(model_type)


@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_search_index():
    """Load the search data files once instead of on every Streamlit rerun.

    Returns:
        ``(info, img_id, img_emb)``: the info.csv DataFrame, the object
        unpickled from img_id.pkl, and the array stored in img_emb.npy.
    """
    info = pd.read_csv("info.csv")
    # NOTE: unpickling can execute arbitrary code; img_id.pkl must be a trusted
    # artifact shipped with the app, never a user-supplied file.
    with open("img_id.pkl", "rb") as f:
        img_id = pickle.load(f)
    img_emb = np.load("img_emb.npy")
    return info, img_id, img_emb


# info: metadata used to map photo_id -> photo_image_url;
# img_id: maps a semantic_search corpus_id to a photo_id;
# img_emb: corpus embeddings searched against the query embedding.
info, img_id, img_emb = _load_search_index()

text = st.text_input("Input Text", value="검은 고양이")  # default query: "black cat"
if not text.strip():
    # A blank query has nothing meaningful to rank against; end this run
    # instead of rendering arbitrary images.
    st.stop()

tokens = processor(text=text, return_tensors="pt")
with torch.no_grad():  # inference only, no autograd graph needed
    text_emb = model.get_text_features(**tokens)
# semantic_search returns one ranked hit list per query; take the single query's list.
result = semantic_search(text_emb, img_emb, top_k=16)[0]
# Shared iterator consumed by get_url(). top_k is larger than the 6 grid cells
# so that hits missing from info.csv can be skipped.
_result = iter(result)
def get_url() -> "str | None":
    """Return the image URL of the next search hit that has a row in ``info``.

    Consumes hits from the module-level ``_result`` iterator, so successive
    calls walk down the ranked results. Some images are not present in the
    info.csv data; hits whose ``photo_id`` has no matching row are skipped.

    Returns:
        The first ``photo_image_url`` found for the next usable hit, or
        ``None`` once every remaining hit has been consumed.
    """
    for r in _result:
        photo_id = img_id[r["corpus_id"]]
        target_series = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
        if target_series.empty:
            continue
        return target_series.iloc[0]
    # Signal exhaustion explicitly instead of letting StopIteration from next()
    # escape and crash the app when too few hits have a known URL.
    return None


# 2 rows x 3 columns image grid.
columns = st.columns(3) + st.columns(3)
for col in columns:
    img_url = get_url()
    if img_url is None:
        # Fewer usable hits than grid cells: leave the remaining cells empty.
        break
    col.image(img_url, use_column_width=True)