Bingsu commited on
Commit
cefc98c
·
1 Parent(s): 173ff84

feat: include new model

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -14,32 +14,35 @@ st.markdown(
14
 
15
 
16
  @st.cache(allow_output_mutation=True, show_spinner=False)
17
- def get_dual_encoder_model():
18
  with st.spinner("Loading model..."):
19
- model = AutoModel.from_pretrained("Bingsu/vitB32_bert_ko_small_clip").eval()
20
- processor = AutoProcessor.from_pretrained("Bingsu/vitB32_bert_ko_small_clip")
21
  return model, processor
22
 
23
 
24
  @st.cache(allow_output_mutation=True, show_spinner=False)
25
- def get_clip_model():
26
  with st.spinner("Loading model..."):
27
- model = AutoModel.from_pretrained("Bingsu/clip-vit-base-patch32-ko").eval()
28
- processor = AutoProcessor.from_pretrained("Bingsu/clip-vit-base-patch32-ko")
29
  return model, processor
30
 
31
 
32
- model_type = st.radio(
33
- "Select model",
34
- ["Bingsu/clip-vit-base-patch32-ko", "Bingsu/vitB32_bert_ko_small_clip"],
35
- )
 
 
 
 
 
36
 
37
- if model_type == "Bingsu/clip-vit-base-patch32-ko":
38
- model, processor = get_clip_model()
39
- elif model_type == "Bingsu/vitB32_bert_ko_small_clip":
40
- model, processor = get_dual_encoder_model()
41
  else:
42
- raise ValueError("Invalid model type")
43
 
44
  info = pd.read_csv("info.csv")
45
  with open("img_id.pkl", "rb") as f:
@@ -52,7 +55,7 @@ tokens = processor(text=text, return_tensors="pt")
52
  with torch.no_grad():
53
  text_emb = model.get_text_features(**tokens)
54
 
55
- result = semantic_search(text_emb, img_emb, top_k=15)[0]
56
  _result = iter(result)
57
 
58
 
 
14
 
15
 
16
  @st.cache(allow_output_mutation=True, show_spinner=False)
17
+ def get_dual_encoder_model(name: str):
18
  with st.spinner("Loading model..."):
19
+ model = AutoModel.from_pretrained(name).eval()
20
+ processor = AutoProcessor.from_pretrained(name)
21
  return model, processor
22
 
23
 
24
  @st.cache(allow_output_mutation=True, show_spinner=False)
25
+ def get_clip_model(name: str):
26
  with st.spinner("Loading model..."):
27
+ model = AutoModel.from_pretrained(name).eval()
28
+ processor = AutoProcessor.from_pretrained(name)
29
  return model, processor
30
 
31
 
32
+ model_list = [
33
+ "Bingsu/clip-vit-base-patch32-ko",
34
+ "Bingsu/clip-vit-large-patch14-ko",
35
+ "openai/clip-vit-base-patch32",
36
+ "openai/clip-vit-base-patch16",
37
+ "openai/clip-vit-large-patch14",
38
+ "Bingsu/vitB32_bert_ko_small_clip",
39
+ ]
40
+ model_type = st.radio("Select model", model_list)
41
 
42
+ if model_type == "Bingsu/vitB32_bert_ko_small_clip":
43
+ model, processor = get_dual_encoder_model(model_type)
 
 
44
  else:
45
+ model, processor = get_clip_model(model_type)
46
 
47
  info = pd.read_csv("info.csv")
48
  with open("img_id.pkl", "rb") as f:
 
55
  with torch.no_grad():
56
  text_emb = model.get_text_features(**tokens)
57
 
58
+ result = semantic_search(text_emb, img_emb, top_k=16)[0]
59
  _result = iter(result)
60
 
61