roy214 commited on
Commit
8736a4e
·
verified ·
1 Parent(s): 14d4cb7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +9 -1
src/streamlit_app.py CHANGED
@@ -68,16 +68,24 @@ def show_img(img_id, score=None, col=None):
68
  st.error(f"Error processing image: {e}")
69
 
70
  def search_faiss(model, processor, index, id_map, prompt, top_k=5, device='cpu'):
 
71
  inputs = processor(text=[prompt], return_tensors='pt', padding=True).to(device)
 
 
72
  with torch.no_grad():
73
  txt_emb = model.get_text_features(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
74
  txt_emb = txt_emb / txt_emb.norm(p=2, dim=-1, keepdim=True)
 
 
75
  q = txt_emb.cpu().numpy().astype('float32')
76
 
77
  D, I = index.search(q, top_k)
 
 
78
  return [(id_map[i], float(D[0][j])) for j, i in enumerate(I[0])]
79
 
80
  def running(prompt, top_k=5):
 
81
  results = search_faiss(
82
  model, processor,
83
  index, id_map,
@@ -175,5 +183,5 @@ st.subheader("Example Usage: Enter a Prompt to Retrieve Related Images")
175
  prompt_input = st.text_input("Enter a prompt", "a red Apparel dress")
176
  top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5)
177
 
178
- if st.button("Find top 5 related images:"):
179
  running(prompt_input, top_k_input)
 
68
  st.error(f"Error processing image: {e}")
69
 
70
  def search_faiss(model, processor, index, id_map, prompt, top_k=5, device='cpu'):
71
+ st.write(f"Running FAISS search for prompt: '{prompt}' with top_k={top_k}")
72
  inputs = processor(text=[prompt], return_tensors='pt', padding=True).to(device)
73
+ st.write("Prompt processed by tokenizer.")
74
+
75
  with torch.no_grad():
76
  txt_emb = model.get_text_features(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
77
  txt_emb = txt_emb / txt_emb.norm(p=2, dim=-1, keepdim=True)
78
+ st.write("Text embedding computed.")
79
+
80
  q = txt_emb.cpu().numpy().astype('float32')
81
 
82
  D, I = index.search(q, top_k)
83
+
84
+ st.write("FAISS search completed.")
85
  return [(id_map[i], float(D[0][j])) for j, i in enumerate(I[0])]
86
 
87
  def running(prompt, top_k=5):
88
+ st.write("Starting image retrieval...")
89
  results = search_faiss(
90
  model, processor,
91
  index, id_map,
 
183
  prompt_input = st.text_input("Enter a prompt", "a red Apparel dress")
184
  top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5)
185
 
186
+ if st.button(f"Find top {top_k_input} related images"):
187
  running(prompt_input, top_k_input)