feat: include new model
app.py CHANGED
@@ -14,32 +14,35 @@ st.markdown(
 
 
 @st.cache(allow_output_mutation=True, show_spinner=False)
-def get_dual_encoder_model():
+def get_dual_encoder_model(name: str):
     with st.spinner("Loading model..."):
-        model = AutoModel.from_pretrained(
-        processor = AutoProcessor.from_pretrained(
+        model = AutoModel.from_pretrained(name).eval()
+        processor = AutoProcessor.from_pretrained(name)
     return model, processor
 
 
 @st.cache(allow_output_mutation=True, show_spinner=False)
-def get_clip_model():
+def get_clip_model(name: str):
     with st.spinner("Loading model..."):
-        model = AutoModel.from_pretrained(
-        processor = AutoProcessor.from_pretrained(
+        model = AutoModel.from_pretrained(name).eval()
+        processor = AutoProcessor.from_pretrained(name)
     return model, processor
 
 
-
-"
-
-
+model_list = [
+    "Bingsu/clip-vit-base-patch32-ko",
+    "Bingsu/clip-vit-large-patch14-ko",
+    "openai/clip-vit-base-patch32",
+    "openai/clip-vit-base-patch16",
+    "openai/clip-vit-large-patch14",
+    "Bingsu/vitB32_bert_ko_small_clip",
+]
+model_type = st.radio("Select model", model_list)
 
-if model_type == "Bingsu/
-    model, processor =
-elif model_type == "Bingsu/vitB32_bert_ko_small_clip":
-    model, processor = get_dual_encoder_model()
+if model_type == "Bingsu/vitB32_bert_ko_small_clip":
+    model, processor = get_dual_encoder_model(model_type)
 else:
-    model, processor = get_clip_model()
+    model, processor = get_clip_model(model_type)
 
 info = pd.read_csv("info.csv")
 with open("img_id.pkl", "rb") as f:
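Parameterizing the two loaders by checkpoint name works because st.cache memoizes per argument value: each name in model_list is loaded at most once per session, and re-selecting an already-loaded checkpoint in the radio widget returns the cached pair. A minimal sketch of that behavior, with a hypothetical load function standing in for the real loaders:

import streamlit as st

@st.cache(allow_output_mutation=True, show_spinner=False)
def load(name: str) -> str:
    # Stand-in for loading (model, processor); the body only runs
    # on a cache miss for this particular value of `name`.
    print(f"loading {name}")
    return f"model for {name}"

load("Bingsu/clip-vit-base-patch32-ko")  # miss: body runs
load("openai/clip-vit-base-patch32")     # miss: different key, body runs
load("Bingsu/clip-vit-base-patch32-ko")  # hit: returned from cache

allow_output_mutation=True tells Streamlit not to hash the returned objects to detect mutation, which is what allows mutable model objects to be cached at all.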
@@ -52,7 +55,7 @@ tokens = processor(text=text, return_tensors="pt")
 with torch.no_grad():
     text_emb = model.get_text_features(**tokens)
 
-result = semantic_search(text_emb, img_emb, top_k=
+result = semantic_search(text_emb, img_emb, top_k=16)[0]
 _result = iter(result)
 
 
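For reference, the end-to-end query path after this change, as a standalone sketch. The checkpoint name, the query string, and the random img_emb are stand-ins for illustration; in the app, img_emb is the precomputed image-embedding matrix whose rows line up with img_id.pkl. semantic_search returns one list of hits per query row, so the new [0] selects the hit list for the single query before iter(result) walks it.

import torch
from transformers import AutoModel, AutoProcessor
from sentence_transformers.util import semantic_search

name = "Bingsu/clip-vit-base-patch32-ko"  # any entry from model_list
model = AutoModel.from_pretrained(name).eval()
processor = AutoProcessor.from_pretrained(name)

tokens = processor(text="바닷가의 강아지", return_tensors="pt")  # "a dog on the beach"
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)  # shape (1, dim)

img_emb = torch.randn(1000, text_emb.shape[-1])  # stand-in corpus embeddings
# One hit list per query row; each hit is a {"corpus_id", "score"} dict.
result = semantic_search(text_emb, img_emb, top_k=16)[0]
for hit in result[:3]:
    print(hit["corpus_id"], round(hit["score"], 3))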